Shubham170793 committed on
Commit 5491531 · verified · 1 Parent(s): e2d3059

Update src/qa.py

Files changed (1)
  1. src/qa.py +40 -51
src/qa.py CHANGED
@@ -3,21 +3,34 @@ qa.py — Retrieval + Generation Layer
  -------------------------------------
  Handles:
  • Query embedding (SentenceTransformer / E5-compatible)
- • Chunk retrieval (FAISS with neighborhood merging + re-ranking)
- • Answer generation (OpenAI GPT-4o-mini or fallback to Flan-T5)
+ • Chunk retrieval (FAISS + cosine re-ranking)
+ • Answer generation (OpenAI GPT-4o-mini or FLAN-T5 fallback)
  Optimized for Hugging Face Spaces & Streamlit.
  """

  import os
  import numpy as np
  from sentence_transformers import SentenceTransformer
- from vectorstore import search_faiss
  from sklearn.metrics.pairwise import cosine_similarity
+ from vectorstore import search_faiss

- print("✅ qa.py loaded from:", __file__)
+ # ==========================================================
+ # 1️⃣ Load OpenAI if key available
+ # ==========================================================
+ USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
+
+ if USE_OPENAI:
+     from openai import OpenAI
+     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+     print("✅ Using OpenAI GPT-4o-mini for answer generation")
+ else:
+     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+     print("⚙️ No OpenAI key found — using fallback FLAN-T5 model")
+
+ print("✅ qa.py loaded successfully")

  # ==========================================================
- # 1️⃣ Hugging Face Cache Setup
+ # 2️⃣ Hugging Face Cache Setup (Safe for Spaces)
  # ==========================================================
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
@@ -29,39 +42,21 @@ os.environ.update({
  })

  # ==========================================================
- # 2️⃣ OpenAI Integration (with safe fallback)
- # ==========================================================
- # ⚠️ TEMPORARY: You can hardcode your key here for testing
- os.environ["OPENAI_API_KEY"] = "sk-proj-…"
-
- USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
-
- if USE_OPENAI:
-     try:
-         from openai import OpenAI
-         client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-         print("✅ Using OpenAI GPT-4o-mini for answer generation")
-     except Exception as e:
-         print(f"⚠️ OpenAI client initialization failed: {e}")
-         USE_OPENAI = False
-
- # ==========================================================
- # 3️⃣ Query Embedding Model
+ # 3️⃣ Embedding Model (E5 for better retrieval)
  # ==========================================================
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
-     print("✅ Loaded query model: intfloat/e5-small-v2")
+     print("✅ Loaded embedding model: intfloat/e5-small-v2")
  except Exception as e:
-     print(f"⚠️ Query model load failed ({e}), using fallback MiniLM.")
+     print(f"⚠️ Failed to load e5-small-v2 ({e}), switching to MiniLM.")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
+     print("✅ Loaded fallback: all-MiniLM-L6-v2")

  # ==========================================================
- # 4️⃣ Fallback LLM (if no OpenAI key or quota exhausted)
+ # 4️⃣ Fallback Model (FLAN-T5)
  # ==========================================================
  if not USE_OPENAI:
-     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
      MODEL_NAME = "google/flan-t5-base"
-     print(f"⚙️ Using fallback model: {MODEL_NAME}")
      _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
      _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
      _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
@@ -71,8 +66,8 @@ if not USE_OPENAI:
  # ==========================================================
  PROMPT_TEMPLATE = """
  You are an enterprise knowledge assistant.
- Use ONLY the CONTEXT below to answer the QUESTION clearly and factually.
- If the context doesn’t contain the answer, reply exactly:
+ Use ONLY the context below to answer the question clearly, precisely, and factually.
+ If the context doesn’t contain the answer, say exactly:
  "I don't know based on the provided document."

  ---
@@ -86,10 +81,10 @@ Answer:
  """

  # ==========================================================
- # 6️⃣ Chunk Retrieval Function
+ # 6️⃣ Chunk Retrieval
  # ==========================================================
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-     """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by cosine similarity."""
+     """Retrieve top-K relevant chunks and re-rank by semantic similarity."""
      if not index or not chunks:
          return []

@@ -100,12 +95,14 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
          normalize_embeddings=True
      )[0]

+     # Retrieve more and then re-rank
      distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
      merged_chunks = []
      for idx in indices[0]:
          neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
          merged_chunks.append(" ".join(neighbors))

+     # Re-rank by cosine similarity
      chunk_vecs = np.array([
          _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0]
          for c in merged_chunks
@@ -120,46 +117,38 @@
      return []

  # ==========================================================
- # 7️⃣ Answer Generation Function
+ # 7️⃣ Answer Generation
  # ==========================================================
  def generate_answer(query: str, retrieved_chunks: list):
-     """Generate factual, complete answers using OpenAI (or Flan-T5 fallback)."""
+     """Generate factual answer using OpenAI GPT-4o-mini (preferred) or FLAN fallback."""
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."

-     context = "\n\n".join([
-         f"[Chunk {i+1}]: {chunk.strip()}"
-         for i, chunk in enumerate(retrieved_chunks)
-     ])
+     # Merge retrieved chunks
+     context = "\n\n".join(
+         [f"[Chunk {i+1}]: {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks)]
+     )
      prompt = PROMPT_TEMPLATE.format(context=context, query=query)

      try:
          if USE_OPENAI:
-             completion = client.chat.completions.create(
-                 model="gpt-3.5-turbo",
+             response = client.chat.completions.create(
+                 model="gpt-4o-mini",
                  messages=[
-                     {"role": "system", "content": "You are a precise enterprise document assistant."},
+                     {"role": "system", "content": "You are a precise enterprise assistant that answers only from the provided context."},
                      {"role": "user", "content": prompt},
                  ],
                  temperature=0.4,
                  max_tokens=800,
              )
-             return completion.choices[0].message.content.strip()
+             return response.choices[0].message.content.strip()

          else:
              result = _answer_model(prompt, max_new_tokens=600, do_sample=False, temperature=0.3)
-             answer = result[0]["generated_text"].strip()
-             return answer
+             return result[0]["generated_text"].strip()

      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
-         # Auto fallback to Flan-T5 if OpenAI fails mid-session
-         if USE_OPENAI:
-             try:
-                 result = _answer_model(prompt, max_new_tokens=600, do_sample=False, temperature=0.3)
-                 return result[0]["generated_text"].strip()
-             except Exception as e2:
-                 print(f"⚠️ Fallback model also failed: {e2}")
          return "⚠️ Error: Could not generate an answer at the moment."

  # ==========================================================
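
A quick end-to-end check of the updated module. The sketch below is not part of the commit: the sample chunks, the question, and the direct faiss.IndexFlatIP construction are illustrative assumptions (the repo's vectorstore helpers are not shown in this diff). Inner product over unit vectors matches the normalize_embeddings=True convention retrieve_chunks uses, so scores behave like cosine similarity.

# Hypothetical smoke test for the updated qa.py (assumes faiss-cpu installed).
import faiss

from qa import _query_model, retrieve_chunks, generate_answer

# Stand-in document chunks; a real run would use the app's parsed document.
chunks = [
    "Refunds are processed within 14 business days of approval.",
    "Support is available Monday to Friday, 9am to 5pm CET.",
    "Enterprise plans include a dedicated account manager.",
    "All invoices are issued in EUR on the first of the month.",
    "Data is retained for 90 days after contract termination.",
]

# Embed the chunks and build a flat inner-product index over unit vectors.
vecs = _query_model.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
index = faiss.IndexFlatIP(vecs.shape[1])
index.add(vecs.astype("float32"))

question = "How long do refunds take?"
top = retrieve_chunks(question, index, chunks, top_k=2)  # searches top_k * 2, then re-ranks
print(generate_answer(question, top))

With no OPENAI_API_KEY set, this exercises the FLAN-T5 path; with a key, the gpt-4o-mini path.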
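
One follow-up on the embedding choice: the intfloat/e5 model cards specify that inputs should start with "query: " or "passage: ", and unprefixed text hurts retrieval quality. The start of the encode call inside retrieve_chunks is elided in this diff, so it may already apply the prefix; if not, a minimal sketch of the convention:

from sentence_transformers import SentenceTransformer

_query_model = SentenceTransformer("intfloat/e5-small-v2")

def embed_query(query: str):
    # E5 models expect "query: " on search queries and "passage: " on
    # indexed chunks; both sides of retrieval must follow the convention.
    return _query_model.encode(
        [f"query: {query}"],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]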