Shubham170793 commited on
Commit
a610ce4
·
verified ·
1 Parent(s): 6dc0a8b

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +21 -29
src/qa.py CHANGED
@@ -2,7 +2,7 @@
2
  qa.py — Retrieval + Generation Layer
3
  -------------------------------------
4
  Handles:
5
- • Query embedding (MPNet / E5 / MiniLM fallback)
6
  • Chunk retrieval (FAISS)
7
  • Answer generation (Flan-T5)
8
  Optimized for Hugging Face Spaces & Streamlit.
@@ -29,36 +29,27 @@ os.environ.update({
29
  })
30
 
31
  # ==========================================================
32
- # 2️⃣ Query Embedding Model (MPNet → E5 → MiniLM)
33
  # ==========================================================
34
- # Try best retrieval model first, then gracefully degrade
35
  try:
36
  _query_model = SentenceTransformer(
37
- "sentence-transformers/all-mpnet-base-v2", # ✅ Best for QA and reasoning-heavy text
38
  cache_folder=CACHE_DIR
39
  )
40
- print("✅ Loaded query model: all-mpnet-base-v2")
41
-
42
- except Exception as e1:
43
- print(f"⚠️ MPNet load failed ({e1}), trying E5-small-v2...")
44
- try:
45
- _query_model = SentenceTransformer(
46
- "intfloat/e5-small-v2",
47
- cache_folder=CACHE_DIR
48
- )
49
- print("✅ Loaded fallback model: e5-small-v2")
50
- except Exception as e2:
51
- print(f"⚠️ E5 load failed ({e2}), falling back to MiniLM...")
52
- _query_model = SentenceTransformer(
53
- "sentence-transformers/all-MiniLM-L6-v2",
54
- cache_folder=CACHE_DIR
55
- )
56
- print("✅ Loaded fallback model: MiniLM-L6-v2")
57
 
58
  # ==========================================================
59
  # 3️⃣ LLM for Answer Generation (FLAN-T5)
60
  # ==========================================================
61
- MODEL_NAME = "google/flan-t5-base" # use 'large' if you have enough memory
62
  print(f"✅ Loading LLM: {MODEL_NAME}")
63
 
64
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
@@ -68,7 +59,7 @@ _answer_model = pipeline(
68
  "text2text-generation",
69
  model=_model,
70
  tokenizer=_tokenizer,
71
- device=-1 # CPU-safe for Hugging Face Spaces
72
  )
73
 
74
  # ==========================================================
@@ -96,16 +87,15 @@ Answer:
96
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
97
  """
98
  Encodes the user query and retrieves top-k relevant chunks via FAISS.
99
- Uses 'query:' prefix (E5 / instruction-tuned) for semantic alignment.
100
  """
101
  if not index or not chunks:
102
  return []
103
 
104
  try:
105
- # Prefix improves intent understanding (esp. for E5 / MPNet)
106
- prefix = "query: " if "e5" in _query_model.name_or_path.lower() else ""
107
  query_emb = _query_model.encode(
108
- [f"{prefix}{query.strip()}"],
109
  convert_to_numpy=True,
110
  normalize_embeddings=True
111
  )[0]
@@ -128,8 +118,10 @@ def generate_answer(query: str, retrieved_chunks: list):
128
  if not retrieved_chunks:
129
  return "Sorry, I couldn’t find relevant information in the document."
130
 
131
- # Combine chunks as structured context
132
  context = "\n\n".join([f"[Chunk {i+1}]: {chunk}" for i, chunk in enumerate(retrieved_chunks)])
 
 
133
  prompt = PROMPT_TEMPLATE.format(context=context, query=query)
134
 
135
  try:
@@ -146,7 +138,7 @@ def generate_answer(query: str, retrieved_chunks: list):
146
 
147
 
148
  # ==========================================================
149
- # 7️⃣ Optional: Local Test Run
150
  # ==========================================================
151
  if __name__ == "__main__":
152
  dummy_chunks = [
 
2
  qa.py — Retrieval + Generation Layer
3
  -------------------------------------
4
  Handles:
5
+ • Query embedding (SentenceTransformer / E5-compatible)
6
  • Chunk retrieval (FAISS)
7
  • Answer generation (Flan-T5)
8
  Optimized for Hugging Face Spaces & Streamlit.
 
29
  })
30
 
31
  # ==========================================================
32
+ # 2️⃣ Query Embedding Model
33
  # ==========================================================
34
+ # Use E5-small-v2 for retrieval consistency with embeddings.py
35
  try:
36
  _query_model = SentenceTransformer(
37
+ "intfloat/e5-small-v2",
38
  cache_folder=CACHE_DIR
39
  )
40
+ print("✅ Loaded query model: intfloat/e5-small-v2")
41
+ except Exception as e:
42
+ print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
43
+ _query_model = SentenceTransformer(
44
+ "sentence-transformers/all-MiniLM-L6-v2",
45
+ cache_folder=CACHE_DIR
46
+ )
47
+ print("✅ Loaded fallback model: all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
 
48
 
49
  # ==========================================================
50
  # 3️⃣ LLM for Answer Generation (FLAN-T5)
51
  # ==========================================================
52
+ MODEL_NAME = "google/flan-t5-base" # switch to 'large' if RAM allows
53
  print(f"✅ Loading LLM: {MODEL_NAME}")
54
 
55
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
 
59
  "text2text-generation",
60
  model=_model,
61
  tokenizer=_tokenizer,
62
+ device=-1 # CPU-safe for Spaces
63
  )
64
 
65
  # ==========================================================
 
87
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
88
  """
89
  Encodes the user query and retrieves top-k relevant chunks via FAISS.
90
+ Uses 'query:' prefix (E5 training style) for semantic alignment.
91
  """
92
  if not index or not chunks:
93
  return []
94
 
95
  try:
96
+ # E5 expects 'query:' prefix for better retrieval accuracy
 
97
  query_emb = _query_model.encode(
98
+ [f"query: {query.strip()}"],
99
  convert_to_numpy=True,
100
  normalize_embeddings=True
101
  )[0]
 
118
  if not retrieved_chunks:
119
  return "Sorry, I couldn’t find relevant information in the document."
120
 
121
+ # Merge retrieved chunks for context
122
  context = "\n\n".join([f"[Chunk {i+1}]: {chunk}" for i, chunk in enumerate(retrieved_chunks)])
123
+
124
+ # Build structured prompt
125
  prompt = PROMPT_TEMPLATE.format(context=context, query=query)
126
 
127
  try:
 
138
 
139
 
140
  # ==========================================================
141
+ # 7️⃣ Optional Local Test
142
  # ==========================================================
143
  if __name__ == "__main__":
144
  dummy_chunks = [