Shubham170793 committed
Commit 43b802c · verified · 1 Parent(s): 3a56dbd

Update src/qa.py

Files changed (1):
src/qa.py (+48 −26)
src/qa.py CHANGED
@@ -2,7 +2,7 @@
 qa.py — Retrieval + Generation Layer
 -------------------------------------
 Handles:
-• Query embedding (SentenceTransformer)
+• Query embedding (SentenceTransformer / E5-compatible)
 • Chunk retrieval (FAISS)
 • Answer generation (Flan-T5)
 Optimized for Hugging Face Spaces & Streamlit.
@@ -16,7 +16,7 @@ from vectorstore import search_faiss
 print("✅ qa.py loaded from:", __file__)
 
 # ==========================================================
-# 1️⃣ Cache Configuration (Hugging Face safe /tmp folder)
+# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
 # ==========================================================
 CACHE_DIR = "/tmp/hf_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
@@ -29,37 +29,46 @@ os.environ.update({
 })
 
 # ==========================================================
-# 2️⃣ Embedding Model (for Query Encoding)
+# 2️⃣ Query Embedding Model
 # ==========================================================
-_query_model = SentenceTransformer(
-    "sentence-transformers/all-MiniLM-L6-v2",
-    cache_folder=CACHE_DIR
-)
-print("✅ Loaded embedding model: all-MiniLM-L6-v2")
+# Use E5-small-v2 for retrieval consistency with embeddings.py
+try:
+    _query_model = SentenceTransformer(
+        "intfloat/e5-small-v2",
+        cache_folder=CACHE_DIR
+    )
+    print("✅ Loaded query model: intfloat/e5-small-v2")
+except Exception as e:
+    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
+    _query_model = SentenceTransformer(
+        "sentence-transformers/all-MiniLM-L6-v2",
+        cache_folder=CACHE_DIR
+    )
+    print("✅ Loaded fallback model: all-MiniLM-L6-v2")
 
 # ==========================================================
-# 3️⃣ LLM for Answers (Google FLAN-T5)
+# 3️⃣ LLM for Answer Generation (FLAN-T5)
 # ==========================================================
-MODEL_NAME = "google/flan-t5-base"  # lighter & faster; can switch to 'large' for higher accuracy
+MODEL_NAME = "google/flan-t5-base"  # switch to 'large' if RAM allows
 print(f"✅ Loading LLM: {MODEL_NAME}")
 
 _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
 _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
 
-# Efficient text2text generation pipeline
 _answer_model = pipeline(
     "text2text-generation",
     model=_model,
     tokenizer=_tokenizer,
-    device=-1  # ensures CPU-safe execution (avoid GPU dependency)
+    device=-1  # CPU-safe for Spaces
 )
 
 # ==========================================================
-# 4️⃣ Prompt Template
+# 4️⃣ Prompt Template (concise and factual)
 # ==========================================================
-PROMPT_TEMPLATE = """You are an expert enterprise assistant.
-Using ONLY the context provided below, answer the question clearly and factually.
-If the context doesn’t contain the answer, reply exactly:
+PROMPT_TEMPLATE = """
+You are an expert enterprise assistant.
+Using ONLY the CONTEXT below, answer the QUESTION clearly and factually.
+If the context doesn’t contain the answer, reply exactly:
 "I don't know based on the provided document."
 
 ---
@@ -69,22 +78,31 @@ Context:
 Question:
 {query}
 ---
-Answer:"""
+Answer:
+"""
 
 # ==========================================================
-# 5️⃣ Retrieval Function
+# 5️⃣ Chunk Retrieval Function
 # ==========================================================
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
     """
-    Encodes the user query and retrieves top-k most relevant chunks from FAISS.
+    Encodes the user query and retrieves top-k relevant chunks via FAISS.
+    Uses 'query:' prefix (E5 training style) for semantic alignment.
     """
     if not index or not chunks:
        return []
 
     try:
-        q_emb = _query_model.encode([query], convert_to_numpy=True)[0]
-        results = search_faiss(q_emb, index, chunks, top_k)
+        # E5 expects 'query:' prefix for better retrieval accuracy
+        query_emb = _query_model.encode(
+            [f"query: {query.strip()}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
+
+        results = search_faiss(query_emb, index, chunks, top_k)
        return results
+
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
@@ -100,17 +118,18 @@ def generate_answer(query: str, retrieved_chunks: list):
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."
 
-    # Merge top chunks into one context block
+    # Merge retrieved chunks for context
     context = "\n\n".join([f"[Chunk {i+1}]: {chunk}" for i, chunk in enumerate(retrieved_chunks)])
 
+    # Build structured prompt
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)
 
     try:
         result = _answer_model(
             prompt,
-            max_new_tokens=250,
+            max_new_tokens=300,
             do_sample=False,
-            temperature=0.3
+            temperature=0.2
         )
         return result[0]["generated_text"].strip()
     except Exception as e:
@@ -119,7 +138,7 @@ def generate_answer(query: str, retrieved_chunks: list):
 
 
 # ==========================================================
-# 7️⃣ Optional: Test Run
+# 7️⃣ Optional Local Test
 # ==========================================================
 if __name__ == "__main__":
     dummy_chunks = [
@@ -128,10 +147,13 @@ if __name__ == "__main__":
         "Integration with SAP ERP allows for seamless data synchronization."
     ]
     from vectorstore import build_faiss_index
+    import numpy as np
+
     index = build_faiss_index([
-        _query_model.encode([chunk], convert_to_numpy=True)[0]
+        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
         for chunk in dummy_chunks
     ])
+
     query = "What is SAP Ariba used for?"
     retrieved = retrieve_chunks(query, index, dummy_chunks)
     print("🔍 Retrieved:", retrieved)
 
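One caveat in the new fallback path: e5-small-v2 and all-MiniLM-L6-v2 both emit 384-dimensional vectors, so if the E5 download fails and queries silently come from MiniLM while the index was built from E5 "passage:" embeddings, FAISS raises no error; the similarity scores are simply meaningless. A sketch of one way to keep the two encoders in lockstep; load_query_model and USES_E5_PREFIX are illustrative names, not code from this repo:

from sentence_transformers import SentenceTransformer

def load_query_model(cache_dir="/tmp/hf_cache"):
    """Try E5 first, then MiniLM, and report which encoder actually loaded."""
    for name in ("intfloat/e5-small-v2",
                 "sentence-transformers/all-MiniLM-L6-v2"):
        try:
            return name, SentenceTransformer(name, cache_folder=cache_dir)
        except Exception as exc:  # e.g. network or cache failure on Spaces
            print(f"⚠️ Could not load {name}: {exc}")
    raise RuntimeError("No query embedding model could be loaded")

_MODEL_NAME, _query_model = load_query_model()
# Only E5-family models are trained with 'query: ' / 'passage: ' prefixes.
USES_E5_PREFIX = _MODEL_NAME.startswith("intfloat/e5")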
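In generate_answer, the new temperature=0.2 is passed together with do_sample=False. With sampling disabled the pipeline decodes greedily, so the temperature is never used (recent transformers releases warn about exactly this combination). If deterministic output is the intent, the equivalent call simply drops the parameter:

# Greedy decoding: deterministic output; temperature has no effect here.
result = _answer_model(prompt, max_new_tokens=300, do_sample=False)
answer = result[0]["generated_text"].strip()

# Alternatively, to actually apply the low temperature:
# result = _answer_model(prompt, max_new_tokens=300, do_sample=True, temperature=0.2)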
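Because the prompt pins an exact refusal string, callers can branch on the "not found" case reliably. A sketch of a caller-side policy built on that sentinel; the retry logic is hypothetical, not something this commit implements:

from qa import retrieve_chunks, generate_answer

# Sentinel copied verbatim from PROMPT_TEMPLATE above.
SENTINEL = "I don't know based on the provided document."

def answer_with_retry(query, index, chunks):
    """Widen top_k once if the model reports the context was insufficient."""
    answer = generate_answer(query, retrieve_chunks(query, index, chunks, top_k=3))
    if answer == SENTINEL:
        answer = generate_answer(query, retrieve_chunks(query, index, chunks, top_k=5))
    return answer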