sofzcc commited on
Commit
7494e47
·
verified ·
1 Parent(s): 2d28f5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -102
app.py CHANGED
@@ -1,8 +1,8 @@
1
- import re
2
  import os
3
  import glob
4
  import yaml
5
  import shutil
 
6
  from typing import List, Tuple
7
 
8
  import faiss
@@ -35,22 +35,18 @@ def get_default_config():
35
  """Provide default configuration"""
36
  return {
37
  "kb": {
38
- "directory": "./knowledge_base",
39
  "index_directory": "./index",
40
  },
41
  "models": {
42
- # Embedding model for FAISS
43
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
44
- # Abstractive generation model
45
  "qa": "google/flan-t5-small",
46
  },
47
  "chunking": {
48
- # Larger chunks -> better conceptual coverage
49
  "chunk_size": 1200,
50
  "overlap": 200,
51
  },
52
  "thresholds": {
53
- # More permissive to not miss relevant chunks
54
  "similarity": 0.1,
55
  },
56
  "messages": {
@@ -69,7 +65,7 @@ CONFIG = load_config()
69
  KB_DIR = CONFIG["kb"]["directory"]
70
  INDEX_DIR = CONFIG["kb"]["index_directory"]
71
  EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
72
- QA_MODEL_NAME = CONFIG["models"]["qa"]
73
  CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
74
  CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
75
  SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
@@ -103,45 +99,7 @@ def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
103
  start += chunk_size - overlap
104
 
105
  return chunks
106
-
107
-
108
- def clean_context_text(text: str) -> str:
109
- """
110
- Clean raw document context before sending to the generator:
111
- - Remove markdown headings (#, ##, ###)
112
- - Remove list markers (1., 2), -, *)
113
- - Remove duplicate lines
114
- """
115
- lines = text.splitlines()
116
- cleaned = []
117
- seen = set()
118
-
119
- for line in lines:
120
- l = line.strip()
121
- if not l:
122
- continue
123
-
124
- # Remove markdown headings like "# 1. Title", "## Section"
125
- l = re.sub(r"^#+\s*", "", l)
126
-
127
- # Remove ordered list prefixes like "1. ", "2) "
128
- l = re.sub(r"^\d+[\.\)]\s*", "", l)
129
-
130
- # Remove bullet markers like "- ", "* "
131
- l = re.sub(r"^[-*]\s*", "", l)
132
-
133
- # Skip very short "noise" lines
134
- if len(l) < 5:
135
- continue
136
 
137
- # Avoid exact duplicates
138
- if l in seen:
139
- continue
140
- seen.add(l)
141
-
142
- cleaned.append(l)
143
-
144
- return "\n".join(cleaned)
145
 
146
  def load_file_text(path: str) -> str:
147
  """Load text from various file formats with error handling"""
@@ -213,6 +171,45 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
213
  return docs
214
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # -----------------------------
217
  # KB INDEX (FAISS)
218
  # -----------------------------
@@ -365,10 +362,10 @@ class RAGIndex:
365
  print(f"Retrieval error: {e}")
366
  return []
367
 
368
- def _generate_from_context(self, prompt: str) -> str:
369
  """Run Flan-T5 on the given prompt and return the decoded answer."""
370
  if self.qa_model is None or self.qa_tokenizer is None:
371
- return "Model not loaded."
372
 
373
  inputs = self.qa_tokenizer(
374
  prompt,
@@ -377,23 +374,21 @@ class RAGIndex:
377
  max_length=768,
378
  )
379
 
380
- output_ids = self.qa_model.generate(
381
  **inputs,
382
- max_new_tokens=128,
383
  do_sample=False,
384
- top_p=0.9,
385
- temperature=0.7,
386
  )
387
 
388
  answer = self.qa_tokenizer.decode(
389
- output_ids[0],
390
  skip_special_tokens=True,
391
  ).strip()
392
 
393
  return answer
394
 
395
  def answer(self, question: str) -> str:
396
- """Answer a question using RAG + abstractive generation"""
397
  if not self.initialized:
398
  return "❌ Assistant not properly initialized. Please check the logs."
399
 
@@ -407,7 +402,7 @@ class RAGIndex:
407
  f"Supported formats: .txt, .md, .pdf, .docx"
408
  )
409
 
410
- # Retrieve relevant contexts
411
  contexts = self.retrieve(question, top_k=3)
412
 
413
  if not contexts:
@@ -416,62 +411,55 @@ class RAGIndex:
416
  f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
417
  )
418
 
419
- # Combine contexts into a single block and track sources
420
- combined_context = []
421
  used_sources = set()
422
 
 
 
423
  for ctx, source, score in contexts:
424
  used_sources.add(source)
425
- # Only include the pure text as context
426
- combined_context.append(ctx)
427
 
428
- combined_text = "\n\n".join(combined_context)
 
 
429
 
430
- # STEP 1 — Summarize each chunk individually
431
- summaries = []
432
- for ctx in combined_context:
433
- prompt_summary = (
434
- "Summarize the following text in one concise sentence, keeping only the core idea:\n\n"
435
- f"{ctx}\n\nSummary:"
436
  )
437
-
438
- inputs = self.qa_tokenizer(prompt_summary, return_tensors="pt", truncation=True).to(self.qa_model.device)
439
- output = self.qa_model.generate(
440
- **inputs,
441
- max_new_tokens=64,
442
- do_sample=False
 
 
 
 
 
 
443
  )
444
- summary_text = self.qa_tokenizer.decode(output[0], skip_special_tokens=True).strip()
445
- summaries.append(summary_text)
446
-
447
- # STEP 2 — Combine all summaries into a clean evidence pool
448
  evidence = " ".join(summaries)
449
-
450
- # STEP 3 Ask model to answer based on summaries only
451
- prompt_answer = (
452
  "You are an AI assistant that answers questions using only the summarized evidence below.\n"
453
- "Write a clear and complete answer in 1–3 sentences.\n"
454
- "Do NOT repeat numbers, headings, markdown, or irrelevant text.\n"
455
- "Do NOT say where the information came from.\n"
456
- "If the answer cannot be found in the evidence, reply:\n"
457
  "\"I don't know based on the provided documents.\"\n\n"
458
  f"Evidence:\n{evidence}\n\n"
459
  f"Question: {question}\n\n"
460
  "Answer:"
461
  )
462
-
463
- inputs = self.qa_tokenizer(prompt_answer, return_tensors="pt", truncation=True).to(self.qa_model.device)
464
- output = self.qa_model.generate(
465
- **inputs,
466
- max_new_tokens=128,
467
- do_sample=False
468
- )
469
- answer_text = self.qa_tokenizer.decode(output[0], skip_special_tokens=True).strip()
470
-
471
-
472
 
473
  try:
474
- answer_text = self._generate_from_context(prompt)
475
  except Exception as e:
476
  print(f"Generation error: {e}")
477
  return (
@@ -503,27 +491,22 @@ def rag_respond(message, history):
503
  history = []
504
 
505
  if not message or not str(message).strip():
506
- # Keep history unchanged, just clear input
507
  return "", history
508
 
509
  user_msg = str(message)
510
 
511
- # Append user message
512
  history.append({
513
  "role": "user",
514
  "content": user_msg,
515
  })
516
 
517
- # Get bot reply
518
  bot_reply = rag_index.answer(user_msg)
519
 
520
- # Append assistant message
521
  history.append({
522
  "role": "assistant",
523
  "content": bot_reply,
524
  })
525
 
526
- # Clear textbox, return updated history
527
  return "", history
528
 
529
 
@@ -539,7 +522,6 @@ def upload_to_kb(files):
539
  saved_files = []
540
 
541
  for f in files:
542
- # Gradio File object or temp file path
543
  src_path = getattr(f, "name", None) or str(f)
544
  if not os.path.exists(src_path):
545
  continue
@@ -603,20 +585,19 @@ with gr.Blocks(title=CONFIG["client"]["name"]) as demo:
603
  gr.Markdown(description)
604
 
605
  with gr.Tab("Chat"):
606
- chatbot = gr.Chatbot(label="RAG Chat") # messages-format by default
607
 
608
  with gr.Row():
609
  txt = gr.Textbox(
610
  show_label=False,
611
  placeholder="Ask a question about your documents and press Enter to send...",
612
- lines=1,
613
  )
614
 
615
  with gr.Row():
616
  send_btn = gr.Button("Send")
617
  clear_btn = gr.Button("Clear")
618
 
619
- # Enter submits, Send button also submits
620
  txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
621
  send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
622
  clear_btn.click(lambda: ([], ""), None, [chatbot, txt])
 
 
1
  import os
2
  import glob
3
  import yaml
4
  import shutil
5
+ import re
6
  from typing import List, Tuple
7
 
8
  import faiss
 
35
  """Provide default configuration"""
36
  return {
37
  "kb": {
38
+ "directory": "./knowledge_base", # can be overridden in config.yaml (e.g., ./kb)
39
  "index_directory": "./index",
40
  },
41
  "models": {
 
42
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
 
43
  "qa": "google/flan-t5-small",
44
  },
45
  "chunking": {
 
46
  "chunk_size": 1200,
47
  "overlap": 200,
48
  },
49
  "thresholds": {
 
50
  "similarity": 0.1,
51
  },
52
  "messages": {
 
65
  KB_DIR = CONFIG["kb"]["directory"]
66
  INDEX_DIR = CONFIG["kb"]["index_directory"]
67
  EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
68
+ QA_MODEL_NAME = CONFIG["models"].get("qa", "google/flan-t5-small")
69
  CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
70
  CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
71
  SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
 
99
  start += chunk_size - overlap
100
 
101
  return chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
 
 
 
 
 
103
 
104
  def load_file_text(path: str) -> str:
105
  """Load text from various file formats with error handling"""
 
171
  return docs
172
 
173
 
174
+ def clean_context_text(text: str) -> str:
175
+ """
176
+ Clean raw document context before sending to the generator:
177
+ - Remove markdown headings (#, ##, ###)
178
+ - Remove list markers (1., 2), -, *)
179
+ - Remove duplicate lines
180
+ """
181
+ lines = text.splitlines()
182
+ cleaned = []
183
+ seen = set()
184
+
185
+ for line in lines:
186
+ l = line.strip()
187
+ if not l:
188
+ continue
189
+
190
+ # Remove markdown headings like "# 1. Title", "## Section"
191
+ l = re.sub(r"^#+\s*", "", l)
192
+
193
+ # Remove ordered list prefixes like "1. ", "2) "
194
+ l = re.sub(r"^\d+[\.\)]\s*", "", l)
195
+
196
+ # Remove bullet markers like "- ", "* "
197
+ l = re.sub(r"^[-*]\s*", "", l)
198
+
199
+ # Skip very short "noise" lines
200
+ if len(l) < 5:
201
+ continue
202
+
203
+ # Avoid exact duplicates
204
+ if l in seen:
205
+ continue
206
+ seen.add(l)
207
+
208
+ cleaned.append(l)
209
+
210
+ return "\n".join(cleaned)
211
+
212
+
213
  # -----------------------------
214
  # KB INDEX (FAISS)
215
  # -----------------------------
 
362
  print(f"Retrieval error: {e}")
363
  return []
364
 
365
+ def _generate_from_context(self, prompt: str, max_new_tokens: int = 128) -> str:
366
  """Run Flan-T5 on the given prompt and return the decoded answer."""
367
  if self.qa_model is None or self.qa_tokenizer is None:
368
+ raise RuntimeError("QA model not loaded.")
369
 
370
  inputs = self.qa_tokenizer(
371
  prompt,
 
374
  max_length=768,
375
  )
376
 
377
+ outputs = self.qa_model.generate(
378
  **inputs,
379
+ max_new_tokens=max_new_tokens,
380
  do_sample=False,
 
 
381
  )
382
 
383
  answer = self.qa_tokenizer.decode(
384
+ outputs[0],
385
  skip_special_tokens=True,
386
  ).strip()
387
 
388
  return answer
389
 
390
  def answer(self, question: str) -> str:
391
+ """Answer a question using RAG + two-step summarization + generation."""
392
  if not self.initialized:
393
  return "❌ Assistant not properly initialized. Please check the logs."
394
 
 
402
  f"Supported formats: .txt, .md, .pdf, .docx"
403
  )
404
 
405
+ # 1) Retrieve relevant contexts
406
  contexts = self.retrieve(question, top_k=3)
407
 
408
  if not contexts:
 
411
  f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
412
  )
413
 
 
 
414
  used_sources = set()
415
 
416
+ # 2) Summarize each retrieved chunk into 1 sentence
417
+ summaries = []
418
  for ctx, source, score in contexts:
419
  used_sources.add(source)
 
 
420
 
421
+ cleaned_ctx = clean_context_text(ctx)
422
+ if not cleaned_ctx.strip():
423
+ continue
424
 
425
+ summary_prompt = (
426
+ "Summarize the following text in ONE concise sentence, keeping only the main idea. "
427
+ "Do not include headings, numbers, or bullet markers.\n\n"
428
+ f"{cleaned_ctx}\n\n"
429
+ "Summary:"
 
430
  )
431
+
432
+ try:
433
+ summary = self._generate_from_context(summary_prompt, max_new_tokens=64)
434
+ summaries.append(summary)
435
+ except Exception as e:
436
+ print(f"Summary generation error: {e}")
437
+ continue
438
+
439
+ if not summaries:
440
+ return (
441
+ f"{NO_ANSWER_MSG}\n\n"
442
+ f"💡 Try rephrasing your question or adding more detailed documents to the knowledge base."
443
  )
444
+
445
+ # 3) Combine summaries into an evidence pool
 
 
446
  evidence = " ".join(summaries)
447
+
448
+ # 4) Ask the model to answer using only the summaries
449
+ answer_prompt = (
450
  "You are an AI assistant that answers questions using only the summarized evidence below.\n"
451
+ "Write a clear, helpful answer in 1–3 sentences, in your own words.\n"
452
+ "- Do NOT include headings, section numbers, markdown, or bullet symbols.\n"
453
+ "- Do NOT mention file names or sources in the answer.\n"
454
+ "- If the answer cannot be found in the evidence, reply exactly: "
455
  "\"I don't know based on the provided documents.\"\n\n"
456
  f"Evidence:\n{evidence}\n\n"
457
  f"Question: {question}\n\n"
458
  "Answer:"
459
  )
 
 
 
 
 
 
 
 
 
 
460
 
461
  try:
462
+ answer_text = self._generate_from_context(answer_prompt, max_new_tokens=128)
463
  except Exception as e:
464
  print(f"Generation error: {e}")
465
  return (
 
491
  history = []
492
 
493
  if not message or not str(message).strip():
 
494
  return "", history
495
 
496
  user_msg = str(message)
497
 
 
498
  history.append({
499
  "role": "user",
500
  "content": user_msg,
501
  })
502
 
 
503
  bot_reply = rag_index.answer(user_msg)
504
 
 
505
  history.append({
506
  "role": "assistant",
507
  "content": bot_reply,
508
  })
509
 
 
510
  return "", history
511
 
512
 
 
522
  saved_files = []
523
 
524
  for f in files:
 
525
  src_path = getattr(f, "name", None) or str(f)
526
  if not os.path.exists(src_path):
527
  continue
 
585
  gr.Markdown(description)
586
 
587
  with gr.Tab("Chat"):
588
+ chatbot = gr.Chatbot(label="RAG Chat")
589
 
590
  with gr.Row():
591
  txt = gr.Textbox(
592
  show_label=False,
593
  placeholder="Ask a question about your documents and press Enter to send...",
594
+ lines=1, # single line so Enter submits
595
  )
596
 
597
  with gr.Row():
598
  send_btn = gr.Button("Send")
599
  clear_btn = gr.Button("Clear")
600
 
 
601
  txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
602
  send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
603
  clear_btn.click(lambda: ([], ""), None, [chatbot, txt])