Shubham170793 committed
Commit 641185f · verified · 1 Parent(s): 1e62275

Create qa.py

Files changed (1):
  1. src/qa.py +63 -0
src/qa.py ADDED
@@ -0,0 +1,63 @@
+ import os
+
+ # Redirect the Hugging Face caches to /tmp. This must happen before the
+ # transformers/sentence_transformers imports below, which read these
+ # variables at import time.
+ CACHE_DIR = "/tmp/huggingface"
+ os.environ["HF_HOME"] = CACHE_DIR
+ os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
+ os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
+
+ from sentence_transformers import SentenceTransformer
+ from transformers import pipeline
+ from vectorstore import search_faiss
+
+ print("✅ qa.py loaded from:", __file__)
+
+ # ----------------------------
+ # Embedding model for queries
+ # ----------------------------
+ _query_model = SentenceTransformer(
+     "sentence-transformers/all-MiniLM-L6-v2",
+     cache_folder=CACHE_DIR,
+ )
+
+ # ----------------------------
+ # LLM for answers
+ # ----------------------------
+ MODEL_NAME = "google/flan-t5-small"
+ MODEL_PATH = os.path.join(CACHE_DIR, MODEL_NAME)
+
+ if not os.path.exists(MODEL_PATH):
+     print(f"⬇️ Downloading {MODEL_NAME} to {MODEL_PATH}")
+     # `cache_dir` is a from_pretrained argument, so pass it via model_kwargs.
+     _answer_model = pipeline(
+         "text2text-generation",
+         model=MODEL_NAME,
+         model_kwargs={"cache_dir": CACHE_DIR},
+     )
+     # Save the model and tokenizer locally so later runs skip the download.
+     _answer_model.model.save_pretrained(MODEL_PATH)
+     _answer_model.tokenizer.save_pretrained(MODEL_PATH)
+ else:
+     print(f"✅ Loading {MODEL_NAME} from {MODEL_PATH}")
+     _answer_model = pipeline(
+         "text2text-generation",
+         model=MODEL_PATH,
+     )
+
+ # ----------------------------
+ # Functions
+ # ----------------------------
+ def retrieve_chunks(query, index, chunks, top_k=3):
+     """Embed the query and return the top_k most similar chunks from the index."""
+     q_emb = _query_model.encode([query], convert_to_numpy=True)[0]
+     return search_faiss(q_emb, index, chunks, top_k)
+
+ def generate_answer(query, retrieved_chunks):
+     """Answer the query using only the retrieved chunks as context."""
+     if not retrieved_chunks:
+         return "Sorry, I could not find relevant information."
+
+     context = " ".join(retrieved_chunks)
+     prompt = (
+         "You are an assistant. Use the context to answer the question clearly.\n"
+         f"Context:\n{context}\n\nQuestion:\n{query}\n\nAnswer:"
+     )
+     result = _answer_model(prompt, max_length=300, do_sample=False)
+     return result[0]["generated_text"].strip()
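
For reference, a minimal end-to-end sketch of how these helpers might be wired up. The toy chunks, the plain faiss.IndexFlatL2 index, and the reuse of _query_model are illustrative assumptions; the real index and chunks come from vectorstore.py, whose search_faiss signature this file relies on but whose internals are not shown in this commit.

# demo.py — hypothetical driver, not part of this commit.
import faiss  # assumes vectorstore is backed by a flat FAISS index

from qa import retrieve_chunks, generate_answer, _query_model

# Toy corpus standing in for the real ingested document chunks.
chunks = [
    "Refunds are issued within 30 days of purchase.",
    "Standard shipping takes 5 to 7 business days.",
    "Support is reachable by email at any time.",
]

# Embed the chunks with the same MiniLM model qa.py uses for queries,
# so query and chunk vectors live in the same space.
embeddings = _query_model.encode(chunks, convert_to_numpy=True)

# Plain L2 index over the chunk embeddings; the real project may build this differently.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

question = "How long do refunds take?"
hits = retrieve_chunks(question, index, chunks, top_k=2)
print(generate_answer(question, hits))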