Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -38,16 +38,19 @@ def get_default_config():
|
|
| 38 |
"index_directory": "./index",
|
| 39 |
},
|
| 40 |
"models": {
|
| 41 |
-
#
|
| 42 |
"embedding": "sentence-transformers/all-MiniLM-L6-v2",
|
| 43 |
-
|
|
|
|
| 44 |
},
|
| 45 |
"chunking": {
|
| 46 |
-
|
| 47 |
-
"
|
|
|
|
| 48 |
},
|
| 49 |
"thresholds": {
|
| 50 |
-
|
|
|
|
| 51 |
},
|
| 52 |
"messages": {
|
| 53 |
"welcome": "Ask me anything about the documents in the knowledge base!",
|
|
@@ -178,7 +181,7 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
|
|
| 178 |
class RAGIndex:
|
| 179 |
def __init__(self):
|
| 180 |
self.embedder = None
|
| 181 |
-
self.qa_pipeline = None
|
| 182 |
self.chunks: List[str] = []
|
| 183 |
self.chunk_sources: List[str] = []
|
| 184 |
self.index = None
|
|
@@ -200,12 +203,12 @@ class RAGIndex:
|
|
| 200 |
print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
|
| 201 |
self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
|
| 202 |
|
| 203 |
-
print(f"Loading QA model: {QA_MODEL_NAME}")
|
|
|
|
| 204 |
self.qa_pipeline = pipeline(
|
| 205 |
-
"
|
| 206 |
-
model=
|
| 207 |
-
tokenizer=
|
| 208 |
-
handle_impossible_answer=True,
|
| 209 |
)
|
| 210 |
except Exception as e:
|
| 211 |
print(f"Error loading models: {e}")
|
|
@@ -327,7 +330,7 @@ class RAGIndex:
|
|
| 327 |
return []
|
| 328 |
|
| 329 |
def answer(self, question: str) -> str:
|
| 330 |
-
"""Answer a question using RAG"""
|
| 331 |
if not self.initialized:
|
| 332 |
return "❌ Assistant not properly initialized. Please check the logs."
|
| 333 |
|
|
@@ -350,45 +353,51 @@ class RAGIndex:
|
|
| 350 |
f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
|
| 351 |
)
|
| 352 |
|
| 353 |
-
#
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
# Truncate context if too long (max 512 tokens for most QA models)
|
| 357 |
-
max_context_length = 2000 # characters, roughly 512 tokens
|
| 358 |
-
truncated_ctx = ctx[:max_context_length]
|
| 359 |
-
|
| 360 |
-
qa_input = {"question": question, "context": truncated_ctx}
|
| 361 |
-
|
| 362 |
-
try:
|
| 363 |
-
result = self.qa_pipeline(qa_input)
|
| 364 |
-
answer_text = result.get("answer", "").strip()
|
| 365 |
-
answer_score = result.get("score", 0.0)
|
| 366 |
-
|
| 367 |
-
if answer_text and answer_score > 0.01: # Minimum confidence threshold
|
| 368 |
-
answers.append((answer_text, source, answer_score, score))
|
| 369 |
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
return (
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
f"💡 Try asking a more specific question."
|
| 382 |
)
|
| 383 |
|
| 384 |
-
|
| 385 |
-
answers.sort(key=lambda x: x[2] * x[3], reverse=True)
|
| 386 |
-
best_answer, best_source, qa_score, retrieval_score = answers[0]
|
| 387 |
|
| 388 |
return (
|
| 389 |
-
f"**Answer:** {
|
| 390 |
-
f"**
|
| 391 |
-
f"**Confidence:** {qa_score:.2%}"
|
| 392 |
)
|
| 393 |
|
| 394 |
|
|
@@ -482,7 +491,7 @@ def rebuild_index():
|
|
| 482 |
)
|
| 483 |
|
| 484 |
|
| 485 |
-
# Description +
|
| 486 |
description = WELCOME_MSG
|
| 487 |
if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
|
| 488 |
description += (
|
|
@@ -497,9 +506,9 @@ examples = [
|
|
| 497 |
]
|
| 498 |
if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
|
| 499 |
examples = [
|
| 500 |
-
"What is
|
| 501 |
-
"
|
| 502 |
-
"
|
| 503 |
]
|
| 504 |
|
| 505 |
|
|
|
|
| 38 |
"index_directory": "./index",
|
| 39 |
},
|
| 40 |
"models": {
|
| 41 |
+
# Embedding model for FAISS
|
| 42 |
"embedding": "sentence-transformers/all-MiniLM-L6-v2",
|
| 43 |
+
# Abstractive generation model (can upgrade to flan-t5-base if resources allow)
|
| 44 |
+
"qa": "google/flan-t5-small",
|
| 45 |
},
|
| 46 |
"chunking": {
|
| 47 |
+
# Larger chunks -> better conceptual coverage
|
| 48 |
+
"chunk_size": 1200,
|
| 49 |
+
"overlap": 200,
|
| 50 |
},
|
| 51 |
"thresholds": {
|
| 52 |
+
# More permissive to not miss relevant chunks
|
| 53 |
+
"similarity": 0.1,
|
| 54 |
},
|
| 55 |
"messages": {
|
| 56 |
"welcome": "Ask me anything about the documents in the knowledge base!",
|
|
|
|
| 181 |
class RAGIndex:
|
| 182 |
def __init__(self):
|
| 183 |
self.embedder = None
|
| 184 |
+
self.qa_pipeline = None # now a generative pipeline
|
| 185 |
self.chunks: List[str] = []
|
| 186 |
self.chunk_sources: List[str] = []
|
| 187 |
self.index = None
|
|
|
|
| 203 |
print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
|
| 204 |
self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
|
| 205 |
|
| 206 |
+
print(f"Loading QA (generation) model: {QA_MODEL_NAME}")
|
| 207 |
+
# Abstractive generation pipeline (Flan-T5)
|
| 208 |
self.qa_pipeline = pipeline(
|
| 209 |
+
"text2text-generation",
|
| 210 |
+
model=QA_MODEL_NAME,
|
| 211 |
+
tokenizer=QA_MODEL_NAME,
|
|
|
|
| 212 |
)
|
| 213 |
except Exception as e:
|
| 214 |
print(f"Error loading models: {e}")
|
|
|
|
| 330 |
return []
|
| 331 |
|
| 332 |
def answer(self, question: str) -> str:
|
| 333 |
+
"""Answer a question using RAG + abstractive generation"""
|
| 334 |
if not self.initialized:
|
| 335 |
return "❌ Assistant not properly initialized. Please check the logs."
|
| 336 |
|
|
|
|
| 353 |
f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
|
| 354 |
)
|
| 355 |
|
| 356 |
+
# Combine contexts into a single block and track sources
|
| 357 |
+
combined_context = []
|
| 358 |
+
used_sources = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
+
for ctx, source, score in contexts:
|
| 361 |
+
used_sources.add(source)
|
| 362 |
+
combined_context.append(f"[Source: {source}]\n{ctx}")
|
| 363 |
+
|
| 364 |
+
combined_text = "\n\n".join(combined_context)
|
| 365 |
+
|
| 366 |
+
# Limit context length to keep it manageable for the model
|
| 367 |
+
max_context_chars = 4000
|
| 368 |
+
if len(combined_text) > max_context_chars:
|
| 369 |
+
combined_text = combined_text[:max_context_chars]
|
| 370 |
+
|
| 371 |
+
# Prompt for the generative model
|
| 372 |
+
prompt = (
|
| 373 |
+
"You are an AI assistant that answers questions using only the provided context. "
|
| 374 |
+
"If the answer cannot be found in the context, reply exactly with: "
|
| 375 |
+
"\"I don't know based on the provided documents.\"\n\n"
|
| 376 |
+
f"Context:\n{combined_text}\n\n"
|
| 377 |
+
f"Question: {question}\n\n"
|
| 378 |
+
"Answer:"
|
| 379 |
+
)
|
| 380 |
|
| 381 |
+
try:
|
| 382 |
+
result = self.qa_pipeline(
|
| 383 |
+
prompt,
|
| 384 |
+
max_new_tokens=256,
|
| 385 |
+
do_sample=False,
|
| 386 |
+
)
|
| 387 |
+
# text2text-generation returns list of dicts with 'generated_text'
|
| 388 |
+
answer_text = result[0]["generated_text"].strip()
|
| 389 |
+
except Exception as e:
|
| 390 |
+
print(f"Generation error: {e}")
|
| 391 |
return (
|
| 392 |
+
"There was an error while generating the answer. "
|
| 393 |
+
"Please try again with a shorter question or different wording."
|
|
|
|
| 394 |
)
|
| 395 |
|
| 396 |
+
sources_str = ", ".join(sorted(used_sources)) if used_sources else "N/A"
|
|
|
|
|
|
|
| 397 |
|
| 398 |
return (
|
| 399 |
+
f"**Answer:** {answer_text}\n\n"
|
| 400 |
+
f"**Sources:** {sources_str}"
|
|
|
|
| 401 |
)
|
| 402 |
|
| 403 |
|
|
|
|
| 491 |
)
|
| 492 |
|
| 493 |
|
| 494 |
+
# Description + optional examples
|
| 495 |
description = WELCOME_MSG
|
| 496 |
if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
|
| 497 |
description += (
|
|
|
|
| 506 |
]
|
| 507 |
if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
|
| 508 |
examples = [
|
| 509 |
+
"What is a knowledge base?",
|
| 510 |
+
"What are best practices for maintaining a KB?",
|
| 511 |
+
"How should I structure knowledge base articles?",
|
| 512 |
]
|
| 513 |
|
| 514 |
|