Shubham170793 committed
Commit 9f0da7b · verified · 1 Parent(s): 6944855

Update src/qa.py

Files changed (1)
  1. src/qa.py +21 -16
src/qa.py CHANGED
@@ -1,8 +1,13 @@
-# ----------------------------
-# Hugging Face cache bootstrap
-# ----------------------------
 import os
+from sentence_transformers import SentenceTransformer
+from transformers import pipeline
+from vectorstore import search_faiss
 
+print("✅ qa.py loaded from:", __file__)
+
+# ----------------------------
+# Hugging Face cache setup
+# ----------------------------
 CACHE_DIR = "/tmp/hf_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
 
@@ -11,15 +16,6 @@ os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
 os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
 os.environ["HF_MODULES_CACHE"] = CACHE_DIR
 
-print(f"✅ Using Hugging Face cache at {CACHE_DIR}")
-
-# ----------------------------
-# Imports AFTER cache bootstrap
-# ----------------------------
-from sentence_transformers import SentenceTransformer
-from transformers import pipeline
-from vectorstore import search_faiss
-
 # ----------------------------
 # Query embedding model
 # ----------------------------
@@ -32,7 +28,6 @@ _query_model = SentenceTransformer(
 # LLM for answers
 # ----------------------------
 MODEL_NAME = "google/flan-t5-small"
-
 _answer_model = pipeline(
     "text2text-generation",
     model=MODEL_NAME,
@@ -43,17 +38,27 @@ _answer_model = pipeline(
 # Functions
 # ----------------------------
 def retrieve_chunks(query, index, chunks, top_k=3):
+    """Embed the query and retrieve top-k chunks from FAISS."""
     q_emb = _query_model.encode([query], convert_to_numpy=True)[0]
     return search_faiss(q_emb, index, chunks, top_k)
 
 def generate_answer(query, retrieved_chunks):
+    """Generate an answer using retrieved chunks as context."""
     if not retrieved_chunks:
         return "Sorry, I could not find relevant information."
 
     context = " ".join(retrieved_chunks)
     prompt = (
-        "You are an assistant. Use the context to answer the question clearly.\n"
-        f"Context:\n{context}\n\nQuestion:\n{query}\n\nAnswer:"
+        "You are an assistant. Use the context below to answer the question clearly.\n\n"
+        f"Context:\n{context}\n\n"
+        f"Question:\n{query}\n\n"
+        "Answer:"
+    )
+
+    # ✅ Use max_new_tokens instead of max_length to avoid version mismatch errors
+    result = _answer_model(
+        prompt,
+        max_new_tokens=300,
+        do_sample=False
     )
-    result = _answer_model(prompt, max_length=300, do_sample=False)
     return result[0]["generated_text"].strip()
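
For readers trying this change out, here is a minimal end-to-end sketch of how the two public functions in src/qa.py compose. It is not part of the commit: the corpus, the `all-MiniLM-L6-v2` model name, and the FAISS construction are illustrative assumptions, since index building presumably lives in the repo's `vectorstore` module (not shown), and the exact embedding model passed to `SentenceTransformer` in qa.py is truncated out of this diff.

# Illustrative only: corpus, embedding model, and FAISS setup are assumptions;
# the real index building presumably lives in the repo's vectorstore module.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from qa import retrieve_chunks, generate_answer  # loads models at import time

# Hypothetical document chunks standing in for real ingested text.
chunks = [
    "The warranty covers manufacturing defects for 12 months.",
    "Returns are accepted within 30 days of purchase.",
]

# Embed the corpus. The model here must match whatever qa.py passes to
# SentenceTransformer for queries (the name is cut off in this diff).
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embedder.encode(chunks, convert_to_numpy=True).astype(np.float32)

# Flat L2 index over the chunk embeddings.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

retrieved = retrieve_chunks("How long is the warranty?", index, chunks, top_k=2)
print(generate_answer("How long is the warranty?", retrieved))

On the generation change itself: recent transformers releases recommend `max_new_tokens`, which caps only the newly generated tokens, because an explicit `max_length` can conflict with the model's generation-config defaults and trigger warnings or errors across versions, which is presumably what the new comment means by "version mismatch errors".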