Shubham170793 committed
Commit fea3890 · verified · 1 Parent(s): 28e4d2b

Update src/qa.py

Files changed (1): src/qa.py +111 -85
src/qa.py CHANGED
@@ -3,22 +3,31 @@ qa.py — Retrieval + Generation Layer
 -------------------------------------
 Handles:
 • Query embedding (SentenceTransformer / E5-compatible)
-• Chunk retrieval (FAISS with neighborhood merging + re-ranking)
-• Answer generation (Flan-T5, tuned for factual completeness)
+• Chunk retrieval (FAISS)
+• Answer generation (OpenAI or Flan-T5 fallback)
 Optimized for Hugging Face Spaces & Streamlit.
 """
 
 import os
 import numpy as np
 from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from vectorstore import search_faiss
 from sklearn.metrics.pairwise import cosine_similarity
 
+# Optional: use OpenAI if API key available
+USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
+if USE_OPENAI:
+    from openai import OpenAI
+    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+    print("✅ Using OpenAI for answer generation")
+else:
+    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+    print("⚙️ Using fallback FLAN-T5 model (local)")
+
 print("✅ qa.py loaded from:", __file__)
 
 # ==========================================================
-# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
+# 1️⃣ Hugging Face Cache Setup
 # ==========================================================
 CACHE_DIR = "/tmp/hf_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
@@ -33,111 +42,128 @@ os.environ.update({
 # 2️⃣ Query Embedding Model
 # ==========================================================
 try:
-    _query_model = SentenceTransformer(
-        "intfloat/e5-small-v2",
-        cache_folder=CACHE_DIR
-    )
+    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
     print("✅ Loaded query model: intfloat/e5-small-v2")
 except Exception as e:
-    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
-    _query_model = SentenceTransformer(
-        "sentence-transformers/all-MiniLM-L6-v2",
-        cache_folder=CACHE_DIR
-    )
-    print("✅ Loaded fallback model: all-MiniLM-L6-v2")
+    print(f"⚠️ Query model load failed ({e}), using fallback MiniLM.")
+    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
 
 # ==========================================================
-# 3️⃣ LLM for Answer Generation (OpenAI GPT with Flan fallback)
+# 3️⃣ Fallback LLM (if no OpenAI key)
 # ==========================================================
-from openai import OpenAI
-client = None
-
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-if OPENAI_API_KEY:
-    client = OpenAI(api_key=OPENAI_API_KEY)
-    LLM_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
-    print(f"✅ Using OpenAI model: {LLM_MODEL}")
-else:
-    # Fallback to Flan if no API key is provided
+if not USE_OPENAI:
     MODEL_NAME = "google/flan-t5-base"
-    print(f"⚠️ No OpenAI key found. Using fallback model: {MODEL_NAME}")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-    _answer_model = pipeline(
-        "text2text-generation",
-        model=_model,
-        tokenizer=_tokenizer,
-        device=-1
-    )
-
+    _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
 
 # ==========================================================
-# 6️⃣ Answer Generation Function (GPT or Flan fallback)
+# 4️⃣ Prompt Template
 # ==========================================================
-def generate_answer(query: str, retrieved_chunks: list):
-    """
-    Generates grounded, context-only answers.
-    Uses GPT (preferred) or Flan-T5 (fallback) for response synthesis.
-    """
-    if not retrieved_chunks:
-        return "Sorry, I couldn’t find relevant information in the document."
-
-    # Combine retrieved chunks
-    context = "\n\n".join([
-        f"[Chunk {i+1}]: {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks)
-    ])
-
-    # --- PROMPT TEMPLATE ---
-    system_prompt = """You are an enterprise knowledge assistant.
-    Use ONLY the provided context to answer the user's question accurately.
-    If the answer is not explicitly in the context, reply exactly:
+PROMPT_TEMPLATE = """
+You are an enterprise knowledge assistant.
+Use ONLY the CONTEXT below to answer the QUESTION clearly and factually.
+If the context doesn’t contain the answer, reply exactly:
 "I don't know based on the provided document."
-    Be factual, concise, and structured when relevant.
-    """
 
-    user_prompt = f"""
+---
 Context:
 {context}
-
+---
 Question:
 {query}
-
+---
 Answer:
 """
 
-    # --- Use OpenAI GPT if key available ---
-    if client:
-        try:
-            response = client.chat.completions.create(
-                model=LLM_MODEL,
+# ==========================================================
+# 5️⃣ Chunk Retrieval Function
+# ==========================================================
+def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
+    """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by cosine similarity."""
+    if not index or not chunks:
+        return []
+
+    try:
+        query_emb = _query_model.encode(
+            [f"query: {query.strip()}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
+
+        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
+        merged_chunks = []
+        for idx in indices[0]:
+            neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
+            merged_chunks.append(" ".join(neighbors))
+
+        chunk_vecs = np.array([
+            _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0]
+            for c in merged_chunks
+        ])
+        scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0]
+        sorted_indices = np.argsort(scores)[::-1]
+
+        return [merged_chunks[i] for i in sorted_indices[:top_k]]
+
+    except Exception as e:
+        print(f"⚠️ Retrieval error: {e}")
+        return []
+
+# ==========================================================
+# 6️⃣ Answer Generation Function
+# ==========================================================
+def generate_answer(query: str, retrieved_chunks: list):
+    """Generate factual, complete answers using OpenAI or FLAN."""
+    if not retrieved_chunks:
+        return "Sorry, I couldn’t find relevant information in the document."
+
+    context = "\n\n".join([
+        f"[Chunk {i+1}]: {chunk.strip()}"
+        for i, chunk in enumerate(retrieved_chunks)
+    ])
+    prompt = PROMPT_TEMPLATE.format(context=context, query=query)
+
+    try:
+        if USE_OPENAI:
+            completion = client.chat.completions.create(
+                model="gpt-4o-mini",
                 messages=[
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": user_prompt},
+                    {"role": "system", "content": "You are a precise enterprise document assistant."},
+                    {"role": "user", "content": prompt},
                 ],
-                temperature=0.2,  # factual, low creativity
-                max_tokens=500,
-                presence_penalty=0,
-                frequency_penalty=0
+                temperature=0.4,
+                max_tokens=600,
             )
-            answer = response.choices[0].message.content.strip()
+            return completion.choices[0].message.content.strip()
+
+        else:
+            result = _answer_model(prompt, max_new_tokens=600, do_sample=False, temperature=0.3)
+            answer = result[0]["generated_text"].strip()
            return answer
-        except Exception as e:
-            print(f"⚠️ OpenAI generation failed: {e}")
-            return "⚠️ Error: Could not generate an answer at the moment."
 
-    # --- Otherwise, use Flan-T5 fallback ---
-    try:
-        result = _answer_model(
-            PROMPT_TEMPLATE.format(context=context, query=query),
-            max_new_tokens=600,
-            do_sample=False,
-            temperature=0.3,
-            repetition_penalty=1.1
-        )
-        answer = result[0]["generated_text"].strip()
-        if "I don't know" in answer:
-            return "I don't know based on the provided document."
-        return answer
    except Exception as e:
-        print(f"⚠️ Flan generation failed: {e}")
+        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."
+
+
+# ==========================================================
+# 7️⃣ Local Test
+# ==========================================================
+if __name__ == "__main__":
+    dummy_chunks = [
+        "Step 1: Open the dashboard and navigate to reports.",
+        "Step 2: Click 'Export' to download a CSV summary.",
+        "Step 3: Review the generated report in your downloads folder."
+    ]
+    from vectorstore import build_faiss_index
+
+    index = build_faiss_index([
+        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+        for chunk in dummy_chunks
+    ])
+
+    query = "What are the steps to export a report?"
+    retrieved = retrieve_chunks(query, index, dummy_chunks)
+    print("🔍 Retrieved:", retrieved)
+    print("💬 Answer:", generate_answer(query, retrieved))