Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
import re
|
| 2 |
import os
|
| 3 |
import glob
|
| 4 |
import yaml
|
| 5 |
import shutil
|
|
|
|
| 6 |
from typing import List, Tuple
|
| 7 |
|
| 8 |
import faiss
|
|
@@ -35,22 +35,18 @@ def get_default_config():
|
|
| 35 |
"""Provide default configuration"""
|
| 36 |
return {
|
| 37 |
"kb": {
|
| 38 |
-
"directory": "./knowledge_base",
|
| 39 |
"index_directory": "./index",
|
| 40 |
},
|
| 41 |
"models": {
|
| 42 |
-
# Embedding model for FAISS
|
| 43 |
"embedding": "sentence-transformers/all-MiniLM-L6-v2",
|
| 44 |
-
# Abstractive generation model
|
| 45 |
"qa": "google/flan-t5-small",
|
| 46 |
},
|
| 47 |
"chunking": {
|
| 48 |
-
# Larger chunks -> better conceptual coverage
|
| 49 |
"chunk_size": 1200,
|
| 50 |
"overlap": 200,
|
| 51 |
},
|
| 52 |
"thresholds": {
|
| 53 |
-
# More permissive to not miss relevant chunks
|
| 54 |
"similarity": 0.1,
|
| 55 |
},
|
| 56 |
"messages": {
|
|
@@ -69,7 +65,7 @@ CONFIG = load_config()
|
|
| 69 |
KB_DIR = CONFIG["kb"]["directory"]
|
| 70 |
INDEX_DIR = CONFIG["kb"]["index_directory"]
|
| 71 |
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
|
| 72 |
-
QA_MODEL_NAME = CONFIG["models"]
|
| 73 |
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
|
| 74 |
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
|
| 75 |
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
|
|
@@ -103,45 +99,7 @@ def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
|
|
| 103 |
start += chunk_size - overlap
|
| 104 |
|
| 105 |
return chunks
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def clean_context_text(text: str) -> str:
|
| 109 |
-
"""
|
| 110 |
-
Clean raw document context before sending to the generator:
|
| 111 |
-
- Remove markdown headings (#, ##, ###)
|
| 112 |
-
- Remove list markers (1., 2), -, *)
|
| 113 |
-
- Remove duplicate lines
|
| 114 |
-
"""
|
| 115 |
-
lines = text.splitlines()
|
| 116 |
-
cleaned = []
|
| 117 |
-
seen = set()
|
| 118 |
-
|
| 119 |
-
for line in lines:
|
| 120 |
-
l = line.strip()
|
| 121 |
-
if not l:
|
| 122 |
-
continue
|
| 123 |
-
|
| 124 |
-
# Remove markdown headings like "# 1. Title", "## Section"
|
| 125 |
-
l = re.sub(r"^#+\s*", "", l)
|
| 126 |
-
|
| 127 |
-
# Remove ordered list prefixes like "1. ", "2) "
|
| 128 |
-
l = re.sub(r"^\d+[\.\)]\s*", "", l)
|
| 129 |
-
|
| 130 |
-
# Remove bullet markers like "- ", "* "
|
| 131 |
-
l = re.sub(r"^[-*]\s*", "", l)
|
| 132 |
-
|
| 133 |
-
# Skip very short "noise" lines
|
| 134 |
-
if len(l) < 5:
|
| 135 |
-
continue
|
| 136 |
|
| 137 |
-
# Avoid exact duplicates
|
| 138 |
-
if l in seen:
|
| 139 |
-
continue
|
| 140 |
-
seen.add(l)
|
| 141 |
-
|
| 142 |
-
cleaned.append(l)
|
| 143 |
-
|
| 144 |
-
return "\n".join(cleaned)
|
| 145 |
|
| 146 |
def load_file_text(path: str) -> str:
|
| 147 |
"""Load text from various file formats with error handling"""
|
|
@@ -213,6 +171,45 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
|
|
| 213 |
return docs
|
| 214 |
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
# -----------------------------
|
| 217 |
# KB INDEX (FAISS)
|
| 218 |
# -----------------------------
|
|
@@ -365,10 +362,10 @@ class RAGIndex:
|
|
| 365 |
print(f"Retrieval error: {e}")
|
| 366 |
return []
|
| 367 |
|
| 368 |
-
def _generate_from_context(self, prompt: str) -> str:
|
| 369 |
"""Run Flan-T5 on the given prompt and return the decoded answer."""
|
| 370 |
if self.qa_model is None or self.qa_tokenizer is None:
|
| 371 |
-
|
| 372 |
|
| 373 |
inputs = self.qa_tokenizer(
|
| 374 |
prompt,
|
|
@@ -377,23 +374,21 @@ class RAGIndex:
|
|
| 377 |
max_length=768,
|
| 378 |
)
|
| 379 |
|
| 380 |
-
|
| 381 |
**inputs,
|
| 382 |
-
max_new_tokens=
|
| 383 |
do_sample=False,
|
| 384 |
-
top_p=0.9,
|
| 385 |
-
temperature=0.7,
|
| 386 |
)
|
| 387 |
|
| 388 |
answer = self.qa_tokenizer.decode(
|
| 389 |
-
|
| 390 |
skip_special_tokens=True,
|
| 391 |
).strip()
|
| 392 |
|
| 393 |
return answer
|
| 394 |
|
| 395 |
def answer(self, question: str) -> str:
|
| 396 |
-
"""Answer a question using RAG +
|
| 397 |
if not self.initialized:
|
| 398 |
return "❌ Assistant not properly initialized. Please check the logs."
|
| 399 |
|
|
@@ -407,7 +402,7 @@ class RAGIndex:
|
|
| 407 |
f"Supported formats: .txt, .md, .pdf, .docx"
|
| 408 |
)
|
| 409 |
|
| 410 |
-
# Retrieve relevant contexts
|
| 411 |
contexts = self.retrieve(question, top_k=3)
|
| 412 |
|
| 413 |
if not contexts:
|
|
@@ -416,62 +411,55 @@ class RAGIndex:
|
|
| 416 |
f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
|
| 417 |
)
|
| 418 |
|
| 419 |
-
# Combine contexts into a single block and track sources
|
| 420 |
-
combined_context = []
|
| 421 |
used_sources = set()
|
| 422 |
|
|
|
|
|
|
|
| 423 |
for ctx, source, score in contexts:
|
| 424 |
used_sources.add(source)
|
| 425 |
-
# Only include the pure text as context
|
| 426 |
-
combined_context.append(ctx)
|
| 427 |
|
| 428 |
-
|
|
|
|
|
|
|
| 429 |
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
"
|
| 435 |
-
f"{ctx}\n\nSummary:"
|
| 436 |
)
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
)
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
# STEP 2 — Combine all summaries into a clean evidence pool
|
| 448 |
evidence = " ".join(summaries)
|
| 449 |
-
|
| 450 |
-
#
|
| 451 |
-
|
| 452 |
"You are an AI assistant that answers questions using only the summarized evidence below.\n"
|
| 453 |
-
"Write a clear
|
| 454 |
-
"Do NOT
|
| 455 |
-
"Do NOT
|
| 456 |
-
"If the answer cannot be found in the evidence, reply
|
| 457 |
"\"I don't know based on the provided documents.\"\n\n"
|
| 458 |
f"Evidence:\n{evidence}\n\n"
|
| 459 |
f"Question: {question}\n\n"
|
| 460 |
"Answer:"
|
| 461 |
)
|
| 462 |
-
|
| 463 |
-
inputs = self.qa_tokenizer(prompt_answer, return_tensors="pt", truncation=True).to(self.qa_model.device)
|
| 464 |
-
output = self.qa_model.generate(
|
| 465 |
-
**inputs,
|
| 466 |
-
max_new_tokens=128,
|
| 467 |
-
do_sample=False
|
| 468 |
-
)
|
| 469 |
-
answer_text = self.qa_tokenizer.decode(output[0], skip_special_tokens=True).strip()
|
| 470 |
-
|
| 471 |
-
|
| 472 |
|
| 473 |
try:
|
| 474 |
-
answer_text = self._generate_from_context(
|
| 475 |
except Exception as e:
|
| 476 |
print(f"Generation error: {e}")
|
| 477 |
return (
|
|
@@ -503,27 +491,22 @@ def rag_respond(message, history):
|
|
| 503 |
history = []
|
| 504 |
|
| 505 |
if not message or not str(message).strip():
|
| 506 |
-
# Keep history unchanged, just clear input
|
| 507 |
return "", history
|
| 508 |
|
| 509 |
user_msg = str(message)
|
| 510 |
|
| 511 |
-
# Append user message
|
| 512 |
history.append({
|
| 513 |
"role": "user",
|
| 514 |
"content": user_msg,
|
| 515 |
})
|
| 516 |
|
| 517 |
-
# Get bot reply
|
| 518 |
bot_reply = rag_index.answer(user_msg)
|
| 519 |
|
| 520 |
-
# Append assistant message
|
| 521 |
history.append({
|
| 522 |
"role": "assistant",
|
| 523 |
"content": bot_reply,
|
| 524 |
})
|
| 525 |
|
| 526 |
-
# Clear textbox, return updated history
|
| 527 |
return "", history
|
| 528 |
|
| 529 |
|
|
@@ -539,7 +522,6 @@ def upload_to_kb(files):
|
|
| 539 |
saved_files = []
|
| 540 |
|
| 541 |
for f in files:
|
| 542 |
-
# Gradio File object or temp file path
|
| 543 |
src_path = getattr(f, "name", None) or str(f)
|
| 544 |
if not os.path.exists(src_path):
|
| 545 |
continue
|
|
@@ -603,20 +585,19 @@ with gr.Blocks(title=CONFIG["client"]["name"]) as demo:
|
|
| 603 |
gr.Markdown(description)
|
| 604 |
|
| 605 |
with gr.Tab("Chat"):
|
| 606 |
-
chatbot = gr.Chatbot(label="RAG Chat")
|
| 607 |
|
| 608 |
with gr.Row():
|
| 609 |
txt = gr.Textbox(
|
| 610 |
show_label=False,
|
| 611 |
placeholder="Ask a question about your documents and press Enter to send...",
|
| 612 |
-
lines=1,
|
| 613 |
)
|
| 614 |
|
| 615 |
with gr.Row():
|
| 616 |
send_btn = gr.Button("Send")
|
| 617 |
clear_btn = gr.Button("Clear")
|
| 618 |
|
| 619 |
-
# Enter submits, Send button also submits
|
| 620 |
txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
|
| 621 |
send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
|
| 622 |
clear_btn.click(lambda: ([], ""), None, [chatbot, txt])
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import glob
|
| 3 |
import yaml
|
| 4 |
import shutil
|
| 5 |
+
import re
|
| 6 |
from typing import List, Tuple
|
| 7 |
|
| 8 |
import faiss
|
|
|
|
| 35 |
"""Provide default configuration"""
|
| 36 |
return {
|
| 37 |
"kb": {
|
| 38 |
+
"directory": "./knowledge_base", # can be overridden in config.yaml (e.g., ./kb)
|
| 39 |
"index_directory": "./index",
|
| 40 |
},
|
| 41 |
"models": {
|
|
|
|
| 42 |
"embedding": "sentence-transformers/all-MiniLM-L6-v2",
|
|
|
|
| 43 |
"qa": "google/flan-t5-small",
|
| 44 |
},
|
| 45 |
"chunking": {
|
|
|
|
| 46 |
"chunk_size": 1200,
|
| 47 |
"overlap": 200,
|
| 48 |
},
|
| 49 |
"thresholds": {
|
|
|
|
| 50 |
"similarity": 0.1,
|
| 51 |
},
|
| 52 |
"messages": {
|
|
|
|
| 65 |
KB_DIR = CONFIG["kb"]["directory"]
|
| 66 |
INDEX_DIR = CONFIG["kb"]["index_directory"]
|
| 67 |
EMBEDDING_MODEL_NAME = CONFIG["models"]["embedding"]
|
| 68 |
+
QA_MODEL_NAME = CONFIG["models"].get("qa", "google/flan-t5-small")
|
| 69 |
CHUNK_SIZE = CONFIG["chunking"]["chunk_size"]
|
| 70 |
CHUNK_OVERLAP = CONFIG["chunking"]["overlap"]
|
| 71 |
SIM_THRESHOLD = CONFIG["thresholds"]["similarity"]
|
|
|
|
| 99 |
start += chunk_size - overlap
|
| 100 |
|
| 101 |
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
def load_file_text(path: str) -> str:
|
| 105 |
"""Load text from various file formats with error handling"""
|
|
|
|
| 171 |
return docs
|
| 172 |
|
| 173 |
|
| 174 |
+
def clean_context_text(text: str) -> str:
    """Clean raw document context before sending it to the generator.

    Each line is stripped of surrounding whitespace and then normalized:

    - Markdown heading markers (``#``, ``##``, ``###``) are removed.
    - Ordered-list prefixes (``1. ``, ``2) ``) are removed.
    - Bullet markers (``- ``, ``* ``) are removed.
    - Lines shorter than 5 characters after normalization are dropped as noise.
    - Exact duplicate lines are dropped, keeping the first occurrence.

    List and bullet markers are only removed when they are followed by
    whitespace, so ordinary content such as ``3.14 is pi``, ``-5 degrees`` or
    ``*emphasis*`` is passed through intact instead of losing its first
    characters.

    Args:
        text: Raw chunk text, possibly containing markdown formatting.

    Returns:
        The surviving lines joined with ``"\\n"``, in their original order.
    """
    # Compiled once per call, outside the per-line loop.
    heading_marker = re.compile(r"^#+\s*")
    # Require whitespace after the marker so decimals ("3.14") are not treated
    # as list numbering.
    ordered_marker = re.compile(r"^\d+[.)]\s+")
    # Require whitespace after the marker so negative numbers ("-5") and
    # emphasis ("*word*") are not treated as bullets.
    bullet_marker = re.compile(r"^[-*]\s+")

    cleaned = []
    seen = set()

    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue

        # Order matters: "# 1. Title" loses the heading marker first, then
        # the numbering that is exposed at the start of the line.
        stripped = heading_marker.sub("", stripped)
        stripped = ordered_marker.sub("", stripped)
        stripped = bullet_marker.sub("", stripped)

        # Skip very short "noise" lines.
        if len(stripped) < 5:
            continue

        # Avoid exact duplicates, keeping the first occurrence.
        if stripped in seen:
            continue
        seen.add(stripped)

        cleaned.append(stripped)

    return "\n".join(cleaned)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
# -----------------------------
|
| 214 |
# KB INDEX (FAISS)
|
| 215 |
# -----------------------------
|
|
|
|
| 362 |
print(f"Retrieval error: {e}")
|
| 363 |
return []
|
| 364 |
|
| 365 |
+
def _generate_from_context(self, prompt: str, max_new_tokens: int = 128) -> str:
|
| 366 |
"""Run Flan-T5 on the given prompt and return the decoded answer."""
|
| 367 |
if self.qa_model is None or self.qa_tokenizer is None:
|
| 368 |
+
raise RuntimeError("QA model not loaded.")
|
| 369 |
|
| 370 |
inputs = self.qa_tokenizer(
|
| 371 |
prompt,
|
|
|
|
| 374 |
max_length=768,
|
| 375 |
)
|
| 376 |
|
| 377 |
+
outputs = self.qa_model.generate(
|
| 378 |
**inputs,
|
| 379 |
+
max_new_tokens=max_new_tokens,
|
| 380 |
do_sample=False,
|
|
|
|
|
|
|
| 381 |
)
|
| 382 |
|
| 383 |
answer = self.qa_tokenizer.decode(
|
| 384 |
+
outputs[0],
|
| 385 |
skip_special_tokens=True,
|
| 386 |
).strip()
|
| 387 |
|
| 388 |
return answer
|
| 389 |
|
| 390 |
def answer(self, question: str) -> str:
|
| 391 |
+
"""Answer a question using RAG + two-step summarization + generation."""
|
| 392 |
if not self.initialized:
|
| 393 |
return "❌ Assistant not properly initialized. Please check the logs."
|
| 394 |
|
|
|
|
| 402 |
f"Supported formats: .txt, .md, .pdf, .docx"
|
| 403 |
)
|
| 404 |
|
| 405 |
+
# 1) Retrieve relevant contexts
|
| 406 |
contexts = self.retrieve(question, top_k=3)
|
| 407 |
|
| 408 |
if not contexts:
|
|
|
|
| 411 |
f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
|
| 412 |
)
|
| 413 |
|
|
|
|
|
|
|
| 414 |
used_sources = set()
|
| 415 |
|
| 416 |
+
# 2) Summarize each retrieved chunk into 1 sentence
|
| 417 |
+
summaries = []
|
| 418 |
for ctx, source, score in contexts:
|
| 419 |
used_sources.add(source)
|
|
|
|
|
|
|
| 420 |
|
| 421 |
+
cleaned_ctx = clean_context_text(ctx)
|
| 422 |
+
if not cleaned_ctx.strip():
|
| 423 |
+
continue
|
| 424 |
|
| 425 |
+
summary_prompt = (
|
| 426 |
+
"Summarize the following text in ONE concise sentence, keeping only the main idea. "
|
| 427 |
+
"Do not include headings, numbers, or bullet markers.\n\n"
|
| 428 |
+
f"{cleaned_ctx}\n\n"
|
| 429 |
+
"Summary:"
|
|
|
|
| 430 |
)
|
| 431 |
+
|
| 432 |
+
try:
|
| 433 |
+
summary = self._generate_from_context(summary_prompt, max_new_tokens=64)
|
| 434 |
+
summaries.append(summary)
|
| 435 |
+
except Exception as e:
|
| 436 |
+
print(f"Summary generation error: {e}")
|
| 437 |
+
continue
|
| 438 |
+
|
| 439 |
+
if not summaries:
|
| 440 |
+
return (
|
| 441 |
+
f"{NO_ANSWER_MSG}\n\n"
|
| 442 |
+
f"💡 Try rephrasing your question or adding more detailed documents to the knowledge base."
|
| 443 |
)
|
| 444 |
+
|
| 445 |
+
# 3) Combine summaries into an evidence pool
|
|
|
|
|
|
|
| 446 |
evidence = " ".join(summaries)
|
| 447 |
+
|
| 448 |
+
# 4) Ask the model to answer using only the summaries
|
| 449 |
+
answer_prompt = (
|
| 450 |
"You are an AI assistant that answers questions using only the summarized evidence below.\n"
|
| 451 |
+
"Write a clear, helpful answer in 1–3 sentences, in your own words.\n"
|
| 452 |
+
"- Do NOT include headings, section numbers, markdown, or bullet symbols.\n"
|
| 453 |
+
"- Do NOT mention file names or sources in the answer.\n"
|
| 454 |
+
"- If the answer cannot be found in the evidence, reply exactly: "
|
| 455 |
"\"I don't know based on the provided documents.\"\n\n"
|
| 456 |
f"Evidence:\n{evidence}\n\n"
|
| 457 |
f"Question: {question}\n\n"
|
| 458 |
"Answer:"
|
| 459 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
try:
|
| 462 |
+
answer_text = self._generate_from_context(answer_prompt, max_new_tokens=128)
|
| 463 |
except Exception as e:
|
| 464 |
print(f"Generation error: {e}")
|
| 465 |
return (
|
|
|
|
| 491 |
history = []
|
| 492 |
|
| 493 |
if not message or not str(message).strip():
|
|
|
|
| 494 |
return "", history
|
| 495 |
|
| 496 |
user_msg = str(message)
|
| 497 |
|
|
|
|
| 498 |
history.append({
|
| 499 |
"role": "user",
|
| 500 |
"content": user_msg,
|
| 501 |
})
|
| 502 |
|
|
|
|
| 503 |
bot_reply = rag_index.answer(user_msg)
|
| 504 |
|
|
|
|
| 505 |
history.append({
|
| 506 |
"role": "assistant",
|
| 507 |
"content": bot_reply,
|
| 508 |
})
|
| 509 |
|
|
|
|
| 510 |
return "", history
|
| 511 |
|
| 512 |
|
|
|
|
| 522 |
saved_files = []
|
| 523 |
|
| 524 |
for f in files:
|
|
|
|
| 525 |
src_path = getattr(f, "name", None) or str(f)
|
| 526 |
if not os.path.exists(src_path):
|
| 527 |
continue
|
|
|
|
| 585 |
gr.Markdown(description)
|
| 586 |
|
| 587 |
with gr.Tab("Chat"):
|
| 588 |
+
chatbot = gr.Chatbot(label="RAG Chat")
|
| 589 |
|
| 590 |
with gr.Row():
|
| 591 |
txt = gr.Textbox(
|
| 592 |
show_label=False,
|
| 593 |
placeholder="Ask a question about your documents and press Enter to send...",
|
| 594 |
+
lines=1, # single line so Enter submits
|
| 595 |
)
|
| 596 |
|
| 597 |
with gr.Row():
|
| 598 |
send_btn = gr.Button("Send")
|
| 599 |
clear_btn = gr.Button("Clear")
|
| 600 |
|
|
|
|
| 601 |
txt.submit(rag_respond, [txt, chatbot], [txt, chatbot])
|
| 602 |
send_btn.click(rag_respond, [txt, chatbot], [txt, chatbot])
|
| 603 |
clear_btn.click(lambda: ([], ""), None, [chatbot, txt])
|