Shubham170793 committed on
Commit
cd266a5
·
verified ·
1 Parent(s): 4da661f

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +90 -49
src/qa.py CHANGED
@@ -1,3 +1,13 @@
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from sentence_transformers import SentenceTransformer
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
@@ -5,93 +15,124 @@ from vectorstore import search_faiss
5
 
6
  print("✅ qa.py loaded from:", __file__)
7
 
8
- # ----------------------------
9
- # Hugging Face cache setup
10
- # ----------------------------
11
  CACHE_DIR = "/tmp/hf_cache"
12
  os.makedirs(CACHE_DIR, exist_ok=True)
13
 
14
- os.environ["HF_HOME"] = CACHE_DIR
15
- os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
16
- os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
17
- os.environ["HF_MODULES_CACHE"] = CACHE_DIR
 
 
18
 
19
- # ----------------------------
20
- # Query embedding model
21
- # ----------------------------
22
  _query_model = SentenceTransformer(
23
  "sentence-transformers/all-MiniLM-L6-v2",
24
  cache_folder=CACHE_DIR
25
  )
 
 
 
 
 
 
 
26
 
27
- # ----------------------------
28
- # LLM for answers (FLAN)
29
- # ----------------------------
30
- MODEL_NAME = "google/flan-t5-large" # you can switch to flan-t5-base if Codespace is low on RAM
31
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
32
  _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
33
 
 
34
  _answer_model = pipeline(
35
  "text2text-generation",
36
  model=_model,
37
- tokenizer=_tokenizer
 
38
  )
39
 
40
- # ----------------------------
41
- # Prompt Template
42
- # ----------------------------
43
- PROMPT_CONCISE = """
44
- You are an expert analyst. Using ONLY the CONTEXT below, answer the QUESTION clearly and concisely.
45
- If the answer cannot be found in the context, reply exactly: "I don't know based on the provided document."
46
-
47
- Instructions:
48
- • Start with a one-sentence answer.
49
- • Then give up to 3 short numbered supporting points (each ≤ 25 words).
50
- • After that, list the sources referenced as [Chunk N].
51
 
 
52
  Context:
53
  {context}
54
-
55
  Question:
56
  {query}
 
 
57
 
58
- Answer:
59
- """
60
-
61
- # ----------------------------
62
- # Functions
63
- # ----------------------------
64
- def retrieve_chunks(query, index, chunks, top_k=3):
65
  """
66
- Embed the query and retrieve top-k chunks from FAISS.
67
  """
68
- q_emb = _query_model.encode([query], convert_to_numpy=True)[0]
69
- return search_faiss(q_emb, index, chunks, top_k)
 
 
 
 
 
 
 
 
70
 
71
 
72
- def generate_answer(query, retrieved_chunks):
 
 
 
73
  """
74
- Generate an answer using FLAN and the retrieved chunks as context.
75
  """
76
  if not retrieved_chunks:
77
  return "Sorry, I couldn’t find relevant information in the document."
78
 
79
- # Format chunks for context clarity
80
  context = "\n\n".join([f"[Chunk {i+1}]: {chunk}" for i, chunk in enumerate(retrieved_chunks)])
81
 
82
- # Build prompt using the concise structured template
83
- prompt = PROMPT_CONCISE.format(context=context, query=query)
84
 
85
  try:
86
  result = _answer_model(
87
  prompt,
88
- max_new_tokens=300,
89
  do_sample=False,
90
- temperature=0.2
91
  )
92
- answer = result[0]["generated_text"].strip()
93
  except Exception as e:
94
- print("⚠️ FLAN generation failed:", e)
95
- answer = "Sorry, I couldn’t generate an answer at the moment."
96
-
97
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ qa.py — Retrieval + Generation Layer
3
+ -------------------------------------
4
+ Handles:
5
+ • Query embedding (SentenceTransformer)
6
+ • Chunk retrieval (FAISS)
7
+ • Answer generation (Flan-T5)
8
+ Optimized for Hugging Face Spaces & Streamlit.
9
+ """
10
+
11
  import os
12
  from sentence_transformers import SentenceTransformer
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
15
 
16
  print("✅ qa.py loaded from:", __file__)
17
 
18
+ # ==========================================================
19
+ # 1️⃣ Cache Configuration (Hugging Face safe /tmp folder)
20
+ # ==========================================================
21
  CACHE_DIR = "/tmp/hf_cache"
22
  os.makedirs(CACHE_DIR, exist_ok=True)
23
 
24
+ os.environ.update({
25
+ "HF_HOME": CACHE_DIR,
26
+ "TRANSFORMERS_CACHE": CACHE_DIR,
27
+ "HF_DATASETS_CACHE": CACHE_DIR,
28
+ "HF_MODULES_CACHE": CACHE_DIR
29
+ })
30
 
31
+ # ==========================================================
32
+ # 2️⃣ Embedding Model (for Query Encoding)
33
+ # ==========================================================
34
  _query_model = SentenceTransformer(
35
  "sentence-transformers/all-MiniLM-L6-v2",
36
  cache_folder=CACHE_DIR
37
  )
38
+ print("✅ Loaded embedding model: all-MiniLM-L6-v2")
39
+
40
+ # ==========================================================
41
+ # 3️⃣ LLM for Answers (Google FLAN-T5)
42
+ # ==========================================================
43
+ MODEL_NAME = "google/flan-t5-base" # lighter & faster; can switch to 'large' for higher accuracy
44
+ print(f"✅ Loading LLM: {MODEL_NAME}")
45
 
 
 
 
 
46
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
47
  _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
48
 
49
+ # Efficient text2text generation pipeline
50
  _answer_model = pipeline(
51
  "text2text-generation",
52
  model=_model,
53
+ tokenizer=_tokenizer,
54
+ device=-1 # ensures CPU-safe execution (avoid GPU dependency)
55
  )
56
 
57
+ # ==========================================================
58
+ # 4️⃣ Prompt Template
59
+ # ==========================================================
60
+ PROMPT_TEMPLATE = """You are an expert enterprise assistant.
61
+ Using ONLY the context provided below, answer the question clearly and factually.
62
+ If the context doesn’t contain the answer, reply exactly:
63
+ "I don't know based on the provided document."
 
 
 
 
64
 
65
+ ---
66
  Context:
67
  {context}
68
+ ---
69
  Question:
70
  {query}
71
+ ---
72
+ Answer:"""
73
 
74
+ # ==========================================================
75
+ # 5️⃣ Retrieval Function
76
+ # ==========================================================
77
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encode the user query and retrieve the top-k most relevant chunks from FAISS.

    Args:
        query: Natural-language question to embed.
        index: Vector index handed to ``search_faiss``; ``None`` means "no index".
        chunks: Text chunks passed to ``search_faiss`` alongside the index
            (presumably aligned with the indexed vectors — verify against caller).
        top_k: Maximum number of chunks to request.

    Returns:
        Whatever ``search_faiss`` returns for the query, or an empty list when
        the inputs are unusable or retrieval raises.
    """
    # Compare against None explicitly: the index is a foreign object, and its
    # truthiness (``__len__``/``__bool__``) is not something this module controls.
    if index is None or not chunks:
        return []

    # An empty or whitespace-only query has nothing meaningful to embed.
    if not query or not query.strip():
        return []

    # Never request more neighbours than there are chunks: FAISS pads missing
    # results with label -1, and how search_faiss treats those is not visible here.
    k = min(top_k, len(chunks))
    if k <= 0:
        return []

    try:
        q_emb = _query_model.encode([query], convert_to_numpy=True)[0]
        return search_faiss(q_emb, index, chunks, k)
    except Exception as e:
        # Best-effort boundary: report the failure and let the caller fall back
        # to its "no relevant information" path.
        print(f"⚠️ Retrieval error: {e}")
        return []
91
 
92
 
93
+ # ==========================================================
94
+ # 6️⃣ Answer Generation Function
95
+ # ==========================================================
96
def generate_answer(query: str, retrieved_chunks: list):
    """
    Generate an answer with FLAN-T5, using the retrieved chunks as context.

    Args:
        query: The user's question, substituted into ``PROMPT_TEMPLATE``.
        retrieved_chunks: Chunks to present as context; each is labelled
            ``[Chunk N]`` (1-based) in the prompt.

    Returns:
        The stripped generated text, a fixed apology string when no chunks
        were supplied, or a fixed error string when generation raises.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Merge the chunks into one labelled context block.
    context = "\n\n".join(
        f"[Chunk {number}]: {chunk}"
        for number, chunk in enumerate(retrieved_chunks, start=1)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    try:
        # Greedy decoding (do_sample=False) is deterministic. `temperature` only
        # applies when sampling, so passing it here had no effect on the output
        # and merely triggered a transformers warning; it is omitted.
        result = _answer_model(
            prompt,
            max_new_tokens=250,
            do_sample=False,
        )
        return result[0]["generated_text"].strip()
    except Exception as e:
        # Best-effort boundary: surface a user-facing message instead of crashing.
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."
119
+
120
+
121
+ # ==========================================================
122
+ # 7️⃣ Optional: Test Run
123
+ # ==========================================================
124
+ if __name__ == "__main__":
125
+ dummy_chunks = [
126
+ "SAP Ariba is a cloud-based procurement solution.",
127
+ "It helps companies manage suppliers and sourcing processes efficiently.",
128
+ "Integration with SAP ERP allows for seamless data synchronization."
129
+ ]
130
+ from vectorstore import build_faiss_index
131
+ index = build_faiss_index([
132
+ _query_model.encode([chunk], convert_to_numpy=True)[0]
133
+ for chunk in dummy_chunks
134
+ ])
135
+ query = "What is SAP Ariba used for?"
136
+ retrieved = retrieve_chunks(query, index, dummy_chunks)
137
+ print("🔍 Retrieved:", retrieved)
138
+ print("💬 Answer:", generate_answer(query, retrieved))