Shubham170793 commited on
Commit
b41f253
·
verified ·
1 Parent(s): c220dec

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +36 -21
src/qa.py CHANGED
@@ -3,8 +3,8 @@ qa.py — Retrieval + Generation Layer
3
  -------------------------------------
4
  Handles:
5
  • Query embedding (SentenceTransformer / E5-compatible)
6
- • Chunk retrieval (FAISS)
7
- • Answer generation (OpenAI or Flan-T5 fallback)
8
  Optimized for Hugging Face Spaces & Streamlit.
9
  """
10
 
@@ -14,16 +14,6 @@ from sentence_transformers import SentenceTransformer
14
  from vectorstore import search_faiss
15
  from sklearn.metrics.pairwise import cosine_similarity
16
 
17
- # Optional: use OpenAI if API key available
18
- USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
19
- if USE_OPENAI:
20
- from openai import OpenAI
21
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
22
- print("✅ Using OpenAI for answer generation")
23
- else:
24
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
25
- print("⚙️ Using fallback FLAN-T5 model (local)")
26
-
27
  print("✅ qa.py loaded from:", __file__)
28
 
29
  # ==========================================================
@@ -39,7 +29,24 @@ os.environ.update({
39
  })
40
 
41
  # ==========================================================
42
- # 2️⃣ Query Embedding Model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # ==========================================================
44
  try:
45
  _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
@@ -49,16 +56,18 @@ except Exception as e:
49
  _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
50
 
51
  # ==========================================================
52
- # 3️⃣ Fallback LLM (if no OpenAI key)
53
  # ==========================================================
54
  if not USE_OPENAI:
 
55
  MODEL_NAME = "google/flan-t5-base"
 
56
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
57
  _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
58
  _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
59
 
60
  # ==========================================================
61
- # 4️⃣ Prompt Template
62
  # ==========================================================
63
  PROMPT_TEMPLATE = """
64
  You are an enterprise knowledge assistant.
@@ -77,7 +86,7 @@ Answer:
77
  """
78
 
79
  # ==========================================================
80
- # 5️⃣ Chunk Retrieval Function
81
  # ==========================================================
82
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
83
  """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by cosine similarity."""
@@ -111,10 +120,10 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
111
  return []
112
 
113
  # ==========================================================
114
- # 6️⃣ Answer Generation Function
115
  # ==========================================================
116
  def generate_answer(query: str, retrieved_chunks: list):
117
- """Generate factual, complete answers using OpenAI or FLAN."""
118
  if not retrieved_chunks:
119
  return "Sorry, I couldn’t find relevant information in the document."
120
 
@@ -133,7 +142,7 @@ def generate_answer(query: str, retrieved_chunks: list):
133
  {"role": "user", "content": prompt},
134
  ],
135
  temperature=0.4,
136
- max_tokens=600,
137
  )
138
  return completion.choices[0].message.content.strip()
139
 
@@ -144,11 +153,17 @@ def generate_answer(query: str, retrieved_chunks: list):
144
 
145
  except Exception as e:
146
  print(f"⚠️ Generation failed: {e}")
 
 
 
 
 
 
 
147
  return "⚠️ Error: Could not generate an answer at the moment."
148
 
149
-
150
  # ==========================================================
151
- # 7️⃣ Local Test
152
  # ==========================================================
153
  if __name__ == "__main__":
154
  dummy_chunks = [
 
3
  -------------------------------------
4
  Handles:
5
  • Query embedding (SentenceTransformer / E5-compatible)
6
+ • Chunk retrieval (FAISS with neighborhood merging + re-ranking)
7
+ • Answer generation (OpenAI GPT-4o-mini or fallback to Flan-T5)
8
  Optimized for Hugging Face Spaces & Streamlit.
9
  """
10
 
 
14
  from vectorstore import search_faiss
15
  from sklearn.metrics.pairwise import cosine_similarity
16
 
 
 
 
 
 
 
 
 
 
 
17
  print("✅ qa.py loaded from:", __file__)
18
 
19
  # ==========================================================
 
29
  })
30
 
31
  # ==========================================================
32
+ # 2️⃣ OpenAI Integration (with safe fallback)
33
+ # ==========================================================
34
+ # ⚠️ TEMPORARY: You can hardcode your key here for testing
35
+ os.environ["OPENAI_API_KEY"] = "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"
36
+
37
+ USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
38
+
39
+ if USE_OPENAI:
40
+ try:
41
+ from openai import OpenAI
42
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
43
+ print("✅ Using OpenAI GPT-4o-mini for answer generation")
44
+ except Exception as e:
45
+ print(f"⚠️ OpenAI client initialization failed: {e}")
46
+ USE_OPENAI = False
47
+
48
+ # ==========================================================
49
+ # 3️⃣ Query Embedding Model
50
  # ==========================================================
51
  try:
52
  _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
 
56
  _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
57
 
58
  # ==========================================================
59
+ # 4️⃣ Fallback LLM (if no OpenAI key or quota exhausted)
60
  # ==========================================================
61
  if not USE_OPENAI:
62
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
63
  MODEL_NAME = "google/flan-t5-base"
64
+ print(f"⚙️ Using fallback model: {MODEL_NAME}")
65
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
66
  _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
67
  _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
68
 
69
  # ==========================================================
70
+ # 5️⃣ Prompt Template
71
  # ==========================================================
72
  PROMPT_TEMPLATE = """
73
  You are an enterprise knowledge assistant.
 
86
  """
87
 
88
  # ==========================================================
89
+ # 6️⃣ Chunk Retrieval Function
90
  # ==========================================================
91
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
92
  """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by cosine similarity."""
 
120
  return []
121
 
122
  # ==========================================================
123
+ # 7️⃣ Answer Generation Function
124
  # ==========================================================
125
  def generate_answer(query: str, retrieved_chunks: list):
126
+ """Generate factual, complete answers using OpenAI (or Flan-T5 fallback)."""
127
  if not retrieved_chunks:
128
  return "Sorry, I couldn’t find relevant information in the document."
129
 
 
142
  {"role": "user", "content": prompt},
143
  ],
144
  temperature=0.4,
145
+ max_tokens=800,
146
  )
147
  return completion.choices[0].message.content.strip()
148
 
 
153
 
154
  except Exception as e:
155
  print(f"⚠️ Generation failed: {e}")
156
+ # Auto fallback to Flan-T5 if OpenAI fails mid-session
157
+ if USE_OPENAI:
158
+ try:
159
+ result = _answer_model(prompt, max_new_tokens=600, do_sample=False, temperature=0.3)
160
+ return result[0]["generated_text"].strip()
161
+ except Exception as e2:
162
+ print(f"⚠️ Fallback model also failed: {e2}")
163
  return "⚠️ Error: Could not generate an answer at the moment."
164
 
 
165
  # ==========================================================
166
+ # 8️⃣ Local Test
167
  # ==========================================================
168
  if __name__ == "__main__":
169
  dummy_chunks = [