sofzcc committed on
Commit
27759ba
·
verified ·
1 Parent(s): 93b82c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -48
app.py CHANGED
@@ -38,16 +38,19 @@ def get_default_config():
38
  "index_directory": "./index",
39
  },
40
  "models": {
41
- # You can also use "all-MiniLM-L6-v2" here, but this path works well on HF
42
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
43
- "qa": "deepset/roberta-base-squad2",
 
44
  },
45
  "chunking": {
46
- "chunk_size": 500,
47
- "overlap": 50,
 
48
  },
49
  "thresholds": {
50
- "similarity": 0.3,
 
51
  },
52
  "messages": {
53
  "welcome": "Ask me anything about the documents in the knowledge base!",
@@ -178,7 +181,7 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
178
  class RAGIndex:
179
  def __init__(self):
180
  self.embedder = None
181
- self.qa_pipeline = None
182
  self.chunks: List[str] = []
183
  self.chunk_sources: List[str] = []
184
  self.index = None
@@ -200,12 +203,12 @@ class RAGIndex:
200
  print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
201
  self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
202
 
203
- print(f"Loading QA model: {QA_MODEL_NAME}")
 
204
  self.qa_pipeline = pipeline(
205
- "question-answering",
206
- model=AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME),
207
- tokenizer=AutoTokenizer.from_pretrained(QA_MODEL_NAME),
208
- handle_impossible_answer=True,
209
  )
210
  except Exception as e:
211
  print(f"Error loading models: {e}")
@@ -327,7 +330,7 @@ class RAGIndex:
327
  return []
328
 
329
  def answer(self, question: str) -> str:
330
- """Answer a question using RAG"""
331
  if not self.initialized:
332
  return "❌ Assistant not properly initialized. Please check the logs."
333
 
@@ -350,45 +353,51 @@ class RAGIndex:
350
  f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
351
  )
352
 
353
- # Try to extract answer from each context
354
- answers = []
355
- for ctx, source, score in contexts:
356
- # Truncate context if too long (max 512 tokens for most QA models)
357
- max_context_length = 2000 # characters, roughly 512 tokens
358
- truncated_ctx = ctx[:max_context_length]
359
-
360
- qa_input = {"question": question, "context": truncated_ctx}
361
-
362
- try:
363
- result = self.qa_pipeline(qa_input)
364
- answer_text = result.get("answer", "").strip()
365
- answer_score = result.get("score", 0.0)
366
-
367
- if answer_text and answer_score > 0.01: # Minimum confidence threshold
368
- answers.append((answer_text, source, answer_score, score))
369
 
370
- except Exception as e:
371
- print(f"QA error on context from {source}: {e}")
372
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
- if not answers:
375
- # Provide context even if no specific answer found
376
- best_ctx, best_src, best_score = contexts[0]
377
- preview = best_ctx[:300] + "..." if len(best_ctx) > 300 else best_ctx
 
 
 
 
 
 
378
  return (
379
- f"I found relevant information but couldn't extract a specific answer.\n\n"
380
- f"**Relevant context from {best_src}:**\n{preview}\n\n"
381
- f"💡 Try asking a more specific question."
382
  )
383
 
384
- # Pick best answer (weighted by both retrieval and QA scores)
385
- answers.sort(key=lambda x: x[2] * x[3], reverse=True)
386
- best_answer, best_source, qa_score, retrieval_score = answers[0]
387
 
388
  return (
389
- f"**Answer:** {best_answer}\n\n"
390
- f"**Source:** {best_source}\n"
391
- f"**Confidence:** {qa_score:.2%}"
392
  )
393
 
394
 
@@ -482,7 +491,7 @@ def rebuild_index():
482
  )
483
 
484
 
485
- # Description + (optional) examples
486
  description = WELCOME_MSG
487
  if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
488
  description += (
@@ -497,9 +506,9 @@ examples = [
497
  ]
498
  if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
499
  examples = [
500
- "What is this document about?",
501
- "Can you summarize the main points?",
502
- "What are the key findings?",
503
  ]
504
 
505
 
 
38
  "index_directory": "./index",
39
  },
40
  "models": {
41
+ # Embedding model for FAISS
42
  "embedding": "sentence-transformers/all-MiniLM-L6-v2",
43
+ # Abstractive generation model (can upgrade to flan-t5-base if resources allow)
44
+ "qa": "google/flan-t5-small",
45
  },
46
  "chunking": {
47
+ # Larger chunks -> better conceptual coverage
48
+ "chunk_size": 1200,
49
+ "overlap": 200,
50
  },
51
  "thresholds": {
52
+ # More permissive to not miss relevant chunks
53
+ "similarity": 0.1,
54
  },
55
  "messages": {
56
  "welcome": "Ask me anything about the documents in the knowledge base!",
 
181
  class RAGIndex:
182
  def __init__(self):
183
  self.embedder = None
184
+ self.qa_pipeline = None # now a generative pipeline
185
  self.chunks: List[str] = []
186
  self.chunk_sources: List[str] = []
187
  self.index = None
 
203
  print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
204
  self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
205
 
206
+ print(f"Loading QA (generation) model: {QA_MODEL_NAME}")
207
+ # Abstractive generation pipeline (Flan-T5)
208
  self.qa_pipeline = pipeline(
209
+ "text2text-generation",
210
+ model=QA_MODEL_NAME,
211
+ tokenizer=QA_MODEL_NAME,
 
212
  )
213
  except Exception as e:
214
  print(f"Error loading models: {e}")
 
330
  return []
331
 
332
  def answer(self, question: str) -> str:
333
+ """Answer a question using RAG + abstractive generation"""
334
  if not self.initialized:
335
  return "❌ Assistant not properly initialized. Please check the logs."
336
 
 
353
  f"💡 Try rephrasing your question or check if relevant documents exist in the knowledge base."
354
  )
355
 
356
+ # Combine contexts into a single block and track sources
357
+ combined_context = []
358
+ used_sources = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
+ for ctx, source, score in contexts:
361
+ used_sources.add(source)
362
+ combined_context.append(f"[Source: {source}]\n{ctx}")
363
+
364
+ combined_text = "\n\n".join(combined_context)
365
+
366
+ # Limit context length to keep it manageable for the model
367
+ max_context_chars = 4000
368
+ if len(combined_text) > max_context_chars:
369
+ combined_text = combined_text[:max_context_chars]
370
+
371
+ # Prompt for the generative model
372
+ prompt = (
373
+ "You are an AI assistant that answers questions using only the provided context. "
374
+ "If the answer cannot be found in the context, reply exactly with: "
375
+ "\"I don't know based on the provided documents.\"\n\n"
376
+ f"Context:\n{combined_text}\n\n"
377
+ f"Question: {question}\n\n"
378
+ "Answer:"
379
+ )
380
 
381
+ try:
382
+ result = self.qa_pipeline(
383
+ prompt,
384
+ max_new_tokens=256,
385
+ do_sample=False,
386
+ )
387
+ # text2text-generation returns list of dicts with 'generated_text'
388
+ answer_text = result[0]["generated_text"].strip()
389
+ except Exception as e:
390
+ print(f"Generation error: {e}")
391
  return (
392
+ "There was an error while generating the answer. "
393
+ "Please try again with a shorter question or different wording."
 
394
  )
395
 
396
+ sources_str = ", ".join(sorted(used_sources)) if used_sources else "N/A"
 
 
397
 
398
  return (
399
+ f"**Answer:** {answer_text}\n\n"
400
+ f"**Sources:** {sources_str}"
 
401
  )
402
 
403
 
 
491
  )
492
 
493
 
494
+ # Description + optional examples
495
  description = WELCOME_MSG
496
  if not rag_index.initialized or rag_index.index is None or not rag_index.chunks:
497
  description += (
 
506
  ]
507
  if not examples and rag_index.initialized and rag_index.index is not None and rag_index.chunks:
508
  examples = [
509
+ "What is a knowledge base?",
510
+ "What are best practices for maintaining a KB?",
511
+ "How should I structure knowledge base articles?",
512
  ]
513
 
514