Zubaish
committed on
Commit
·
1afe1ea
1
Parent(s):
19be3af
update
Browse files
rag.py
CHANGED
|
@@ -1,35 +1,48 @@
|
|
| 1 |
-
# rag.py
|
| 2 |
import os
|
| 3 |
from transformers import pipeline
|
| 4 |
from langchain_huggingface import HuggingFaceEmbeddings
|
| 5 |
from langchain_chroma import Chroma
|
| 6 |
from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR
|
| 7 |
|
|
|
|
| 8 |
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
|
| 9 |
|
| 10 |
-
#
|
| 11 |
-
if os.path.exists(CHROMA_DIR) and
|
| 12 |
vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
|
| 13 |
-
print("✅ Vector DB
|
| 14 |
else:
|
| 15 |
vectordb = None
|
| 16 |
-
print("⚠️ Vector DB
|
| 17 |
|
|
|
|
| 18 |
qa_pipeline = pipeline(
|
| 19 |
-
|
| 20 |
model=LLM_MODEL,
|
| 21 |
-
max_new_tokens=
|
| 22 |
-
|
| 23 |
)
|
| 24 |
|
| 25 |
def ask_rag_with_status(question: str):
|
| 26 |
if vectordb is None:
|
| 27 |
-
return "The knowledge base is not initialized
|
| 28 |
|
| 29 |
-
docs
|
| 30 |
-
|
| 31 |
-
prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# rag.py — RAG bootstrap: embedding model, persisted vector store, LLM pipeline.
import os

from transformers import pipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

from config import EMBEDDING_MODEL, LLM_MODEL, CHROMA_DIR

# 1. Initialize the embedding model used to query the vector store.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# 2. Load the persisted Chroma vector DB if its directory is present.
#    os.path.isdir already returns False for a missing path, so the extra
#    os.path.exists check in the original was redundant.
if os.path.isdir(CHROMA_DIR):
    vectordb = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
    print("✅ Vector DB loaded successfully")
else:
    # ask_rag_with_status() treats None as "knowledge base unavailable".
    vectordb = None
    print("⚠️ Vector DB folder missing")

# 3. LLM pipeline. The "text2text-generation" task name is passed explicitly
#    because T5-family models need it; task auto-detection can fail.
qa_pipeline = pipeline(
    "text2text-generation",
    model=LLM_MODEL,
    max_new_tokens=128,  # kept small so answers stay concise
    model_kwargs={"torch_dtype": "auto"},
)
|
| 25 |
|
| 26 |
def ask_rag_with_status(question: str, k: int = 2, max_chars: int = 400):
    """Answer *question* from the vector DB and report pipeline status.

    Args:
        question: Natural-language question to answer.
        k: Number of documents to retrieve. Defaults to 2 to keep the
           prompt under the model's 512-token input limit.
        max_chars: Per-document character cap when assembling the context.

    Returns:
        A ``(answer, status)`` tuple. On success ``status`` is a list of
        progress strings; on any failure it is the string ``"ERROR"``.
        NOTE(review): the two status shapes differ (list vs. str) — confirm
        how callers consume it before unifying the type.
    """
    if vectordb is None:
        return "The knowledge base is not initialized properly.", "ERROR"

    # Retrieve only a few documents so the prompt stays under the token limit.
    docs = vectordb.similarity_search(question, k=k)

    # Truncate each document; a generator avoids building a throwaway list.
    context = " ".join(d.page_content[:max_chars] for d in docs)

    # T5-style prompt format: "question: ... context: ..."
    prompt = f"question: {question} context: {context}"

    try:
        result = qa_pipeline(prompt)
        answer = result[0]["generated_text"].strip()

        if not answer:
            answer = "I couldn't find a specific answer in the documents provided."

        return answer, ["Context retrieved", "T5 generating"]
    except Exception as e:  # boundary: surface any pipeline failure to the caller
        return f"Error generating answer: {e}", "ERROR"
|