Zubaish committed on
Commit
cf1df19
·
1 Parent(s): 2194516
Files changed (1) hide show
  1. rag.py +42 -16
rag.py CHANGED
@@ -1,38 +1,64 @@
import os

from transformers import pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR, LLM_TASK

# Embedding model shared by indexing and query-time similarity search.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# Load the persisted Chroma store only when the directory exists and is
# non-empty. os.listdir() is used instead of any(os.scandir(...)) because a
# short-circuited scandir iterator is never closed and leaks the directory
# handle (ResourceWarning); isdir() is the correct check for a directory.
if os.path.isdir(CHROMA_DIR) and os.listdir(CHROMA_DIR):
    vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
    print("✅ Vector DB loaded")
else:
    vectordb = None
    print("⚠️ Vector DB missing")

# Text-generation pipeline pinned to CPU.
# NOTE(review): trust_remote_code executes code shipped with the model repo —
# acceptable only for trusted models (here presumably Qwen); confirm.
qa_pipeline = pipeline(
    task=LLM_TASK,
    model=LLM_MODEL,
    device_map="cpu",
    max_new_tokens=512,
    trust_remote_code=True,
)
 
def ask_rag_with_status(question: str):
    """Answer *question* from the vector store via RAG.

    Returns a ``(answer, status)`` tuple: ``(message, "ERROR")`` when the
    vector DB is unavailable, otherwise ``(answer_text, ["Success"])``
    (shapes kept for backward compatibility with existing callers).
    """
    if vectordb is None:
        return "Knowledge base not ready.", "ERROR"

    # Retrieve the top-3 most similar chunks as grounding context.
    docs = vectordb.similarity_search(question, k=3)
    context = "\n\n".join(d.page_content for d in docs)

    messages = [
        {"role": "system", "content": "You are a Gandhi ji expert. Answer the question using ONLY the provided context."},
        {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
    ]

    prompt = qa_pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    result = qa_pipeline(prompt, pad_token_id=qa_pipeline.tokenizer.eos_token_id)

    full_text = result[0]["generated_text"]

    # text-generation pipelines echo the prompt: strip the exact prompt
    # prefix when present (template-agnostic), falling back to the
    # Qwen-specific chat markers the original relied on.
    if full_text.startswith(prompt):
        answer = full_text[len(prompt):].strip().replace("<|im_end|>", "")
    else:
        answer = full_text.split("<|im_start|>assistant")[-1].strip().replace("<|im_end|>", "")

    # Guard against an empty generation so callers never render a blank answer.
    if not answer:
        answer = "No answer could be generated from the retrieved context."

    return answer, ["Success"]
 
import os

import torch
from transformers import pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR, LLM_TASK

# 1. Initialize embeddings (shared by indexing and query-time search).
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# 2. Load the persisted vector DB only when the directory exists and is
# non-empty. os.listdir() replaces any(os.scandir(...)): a short-circuited
# scandir iterator is never closed and leaks the directory handle
# (ResourceWarning); isdir() is the correct directory check.
if os.path.isdir(CHROMA_DIR) and os.listdir(CHROMA_DIR):
    vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
    print("✅ Vector DB loaded successfully")
else:
    vectordb = None
    print("⚠️ Vector DB folder missing or empty")

# 3. LLM pipeline - optimized for CPU stability.
# NOTE(review): trust_remote_code executes code from the model repo — only
# safe for trusted models; confirm the model source.
qa_pipeline = pipeline(
    LLM_TASK,
    model=LLM_MODEL,
    device_map="cpu",
    max_new_tokens=256,  # Sufficient for detailed answers
    trust_remote_code=True,
    model_kwargs={"torch_dtype": torch.float32}  # Safer for CPU
)
29
  def ask_rag_with_status(question: str):
30
  if vectordb is None:
31
  return "Knowledge base not ready.", "ERROR"
32
 
33
+ # Search for context
34
  docs = vectordb.similarity_search(question, k=3)
35
+ context = "\n".join([d.page_content for d in docs])
36
 
37
+ # Simple, clear prompt for Qwen
38
+ prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
 
 
 
39
 
40
+ try:
41
+ # Generate with specific stopping criteria to prevent "looping"
42
+ result = qa_pipeline(
43
+ prompt,
44
+ do_sample=False, # Use greedy decoding for faster, consistent answers
45
+ temperature=0.0,
46
+ pad_token_id=qa_pipeline.tokenizer.eos_token_id
47
+ )
48
+
49
+ full_output = result[0]["generated_text"]
50
+
51
+ # Extract everything after the word "Answer:"
52
+ if "Answer:" in full_output:
53
+ answer = full_output.split("Answer:")[-1].strip()
54
+ else:
55
+ answer = full_output.strip()
56
+
57
+ if not answer:
58
+ answer = "I found context in the documents but could not generate a coherent summary. Please rephrase."
59
 
60
+ return answer, ["Context retrieved", "Qwen generated answer"]
61
+
62
+ except Exception as e:
63
+ print(f"❌ Generation error: {e}")
64
+ return "The model timed out while thinking. Try a shorter question.", "TIMEOUT"