Zubaish committed on
Commit ·
cf1df19
1
Parent(s): 2194516
update
Browse files
rag.py
CHANGED
|
@@ -1,38 +1,64 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from transformers import pipeline
|
| 3 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 4 |
from langchain_chroma import Chroma
|
| 5 |
from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR, LLM_TASK
|
| 6 |
|
|
|
|
| 7 |
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
|
| 8 |
|
|
|
|
| 9 |
if os.path.exists(CHROMA_DIR) and any(os.scandir(CHROMA_DIR)):
|
| 10 |
vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
|
| 11 |
-
print("✅ Vector DB loaded")
|
| 12 |
else:
|
| 13 |
vectordb = None
|
| 14 |
-
print("⚠️ Vector DB missing")
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def ask_rag_with_status(question: str):
|
| 19 |
if vectordb is None:
|
| 20 |
return "Knowledge base not ready.", "ERROR"
|
| 21 |
|
|
|
|
| 22 |
docs = vectordb.similarity_search(question, k=3)
|
| 23 |
-
context = "\n
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
{"role": "system", "content": "You are a Gandhi ji expert. Answer the question using ONLY the provided context."},
|
| 28 |
-
{"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
|
| 29 |
-
]
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""RAG bootstrap: embeddings, Chroma vector store, and the CPU text-generation pipeline."""
import os

import torch
from transformers import pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR, LLM_TASK

# 1. Initialize Embeddings (model name comes from config; downloaded/cached by HF hub).
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# 2. Load Vector DB — only attach to an existing, non-empty Chroma directory.
#    `any(os.scandir(...))` distinguishes an empty directory from a populated one.
if os.path.exists(CHROMA_DIR) and any(os.scandir(CHROMA_DIR)):
    vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
    print("✅ Vector DB loaded successfully")
else:
    # Downstream code (ask_rag_with_status) checks for None and degrades gracefully.
    vectordb = None
    print("⚠️ Vector DB folder missing or empty")

# 3. LLM Pipeline - Optimized for CPU stability
qa_pipeline = pipeline(
    LLM_TASK,
    model=LLM_MODEL,
    device_map="cpu",
    max_new_tokens=256,  # Sufficient for detailed answers
    # NOTE(review): trust_remote_code executes Python shipped with the model
    # repo — acceptable only because LLM_MODEL is a trusted, pinned model.
    trust_remote_code=True,
    model_kwargs={"torch_dtype": torch.float32}  # Safer for CPU
)
|
| 28 |
|
| 29 |
def ask_rag_with_status(question: str):
    """Answer *question* via retrieval-augmented generation over the Chroma store.

    Retrieves the top-3 most similar chunks, builds a plain
    "Context / Question / Answer:" prompt, and runs it through the
    module-level ``qa_pipeline`` with greedy decoding.

    Args:
        question: The user's natural-language question.

    Returns:
        ``(answer, status)`` tuple. NOTE(review): the status shape is
        inconsistent by design of the original API — a *list* of progress
        notes on success, but a *string* code ("ERROR" / "TIMEOUT") on
        failure. Callers must handle both; left unchanged for
        backward compatibility.
    """
    if vectordb is None:
        return "Knowledge base not ready.", "ERROR"

    # Search for context (k=3 keeps the prompt small enough for a CPU model).
    docs = vectordb.similarity_search(question, k=3)
    if not docs:
        # Fail fast rather than prompting with an empty context, which
        # invites the model to hallucinate an unsupported answer.
        return "I could not find anything relevant in the documents.", "ERROR"

    context = "\n".join(doc.page_content for doc in docs)

    # Simple, clear prompt for Qwen
    prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"

    try:
        # Greedy decoding (do_sample=False) is deterministic and fastest on
        # CPU. `temperature` is deliberately NOT passed: transformers ignores
        # it (with a warning) whenever sampling is disabled.
        result = qa_pipeline(
            prompt,
            do_sample=False,
            pad_token_id=qa_pipeline.tokenizer.eos_token_id
        )

        full_output = result[0]["generated_text"]

        # text-generation pipelines echo the prompt by default; keep only
        # the text after the final "Answer:" marker.
        if "Answer:" in full_output:
            answer = full_output.split("Answer:")[-1].strip()
        else:
            answer = full_output.strip()

        if not answer:
            answer = "I found context in the documents but could not generate a coherent summary. Please rephrase."

        return answer, ["Context retrieved", "Qwen generated answer"]

    except Exception as e:
        # Any generation failure (OOM, tokenizer error, genuine timeout)
        # lands here; log the real cause since the user-facing message is
        # generic.
        print(f"❌ Generation error: {e}")
        return "The model timed out while thinking. Try a shorter question.", "TIMEOUT"