Zubaish committed on
Commit
1afe1ea
·
1 Parent(s): 19be3af
Files changed (1) hide show
  1. rag.py +28 -15
rag.py CHANGED
@@ -1,35 +1,48 @@
1
- # rag.py
2
  import os
3
  from transformers import pipeline
4
  from langchain_huggingface import HuggingFaceEmbeddings
5
  from langchain_chroma import Chroma
6
  from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR
7
 
 
8
# Embedding model used both for indexing and for query-time similarity search.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# Only attach to the persisted Chroma store when the directory exists and
# actually contains files; otherwise fall back to None so callers can degrade.
has_index = os.path.exists(CHROMA_DIR) and any(os.scandir(CHROMA_DIR))
if has_index:
    vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
    print("✅ Vector DB ready")
else:
    vectordb = None
    print("⚠️ Vector DB not found or empty")

# Causal-LM pipeline for answer generation.
# NOTE(review): trust_remote_code=True executes model-repo Python on load —
# only safe if LLM_MODEL points at a trusted repository; confirm.
qa_pipeline = pipeline(
    task="text-generation",
    model=LLM_MODEL,
    max_new_tokens=256,
    trust_remote_code=True,
)
24
 
25
def ask_rag_with_status(question: str):
    """Answer *question* via retrieval-augmented generation.

    Returns a ``(answer, status)`` pair. On success ``status`` is the list
    ``["Success"]``; when the vector store is unavailable it is the string
    ``"ERROR"`` (pre-existing asymmetry kept for caller compatibility).
    """
    if vectordb is None:
        return "The knowledge base is not initialized. Please check deployment logs.", "ERROR"

    # Retrieve the three nearest chunks and stitch them into one context blob.
    matches = vectordb.similarity_search(question, k=3)
    context = "\n\n".join(m.page_content for m in matches)

    prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
    generated = qa_pipeline(prompt)

    # text-generation echoes the prompt, so keep only what follows "Answer:".
    reply = generated[0]["generated_text"].split("Answer:")[-1].strip()
    return reply, ["Success"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from transformers import pipeline
3
  from langchain_huggingface import HuggingFaceEmbeddings
4
  from langchain_chroma import Chroma
5
  from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR
6
 
7
# 1. Initialize Embeddings (used for both indexing and query-time search)
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# 2. Load Vector DB.
# os.path.isdir already implies existence, and we additionally require the
# directory to be non-empty: an empty persist dir would silently yield a
# Chroma store with zero documents instead of the intended None fallback.
if os.path.isdir(CHROMA_DIR) and any(os.scandir(CHROMA_DIR)):
    vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
    print("✅ Vector DB loaded successfully")
else:
    vectordb = None
    print("⚠️ Vector DB folder missing")

# 3. LLM Pipeline - Using the explicit task name to avoid task errors
qa_pipeline = pipeline(
    "text2text-generation",  # T5 specifically needs this task name
    model=LLM_MODEL,
    max_new_tokens=128,  # Reduced to keep responses concise
    model_kwargs={"torch_dtype": "auto"},
)
25
 
26
def ask_rag_with_status(question: str):
    """Answer *question* using retrieval-augmented generation over the Chroma store.

    Returns a ``(answer, status)`` tuple.
    NOTE(review): status is a list on success but the string ``"ERROR"`` on
    failure — pre-existing asymmetry kept so existing callers keep working.
    """
    if vectordb is None:
        return "The knowledge base is not initialized properly.", "ERROR"

    # Search for only 2 docs (k=2) to stay under the 512 token limit
    docs = vectordb.similarity_search(question, k=2)

    # Fix: with zero retrieval hits the old code prompted the model with an
    # empty context and returned whatever it hallucinated. Bail out early.
    if not docs:
        return "I couldn't find a specific answer in the documents provided.", ["No context found"]

    # Extract text and keep each chunk short to bound the prompt length.
    context = " ".join(d.page_content[:400] for d in docs)

    # Specific T5 prompt format: "question: ... context: ..."
    prompt = f"question: {question} context: {context}"

    try:
        result = qa_pipeline(prompt)
        # text2text-generation returns only the generated answer (no prompt echo).
        answer = result[0]["generated_text"].strip()

        if not answer:
            answer = "I couldn't find a specific answer in the documents provided."

        return answer, ["Context retrieved", "T5 generating"]
    except Exception as e:
        # Top-level boundary for the web app: surface the failure as text
        # rather than crashing the request handler.
        return f"Error generating answer: {str(e)}", "ERROR"