# Harry_potter_wiki / chatbot_rag.py
# (NOTE: the lines that followed — "Subha95's picture", "Update chatbot_rag.py",
# "89a0ac2 verified", "raw / history blame", "2.14 kB" — were Hugging Face
# file-viewer page chrome captured by scraping, not part of the source file.)
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.chains import RetrievalQA
import traceback # βœ… added
def build_qa():
    """Build and return the RAG RetrievalQA pipeline.

    Loads MiniLM sentence-transformer embeddings, opens the persisted
    Chroma collection ("rag-docs" under ./db), wraps a local Phi-3-mini
    text-generation pipeline as the LLM, and wires everything into a
    RetrievalQA chain over the top-3 retrieved documents.

    Returns:
        RetrievalQA: a chain whose ``run(query)`` yields an answer string.
    """
    print("🚀 Starting QA pipeline...")

    # 1. Embeddings — must be the same model used when the DB was built,
    # otherwise retrieval similarity scores are meaningless.
    print("🔹 Loading embeddings...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # 2. Load the persisted vector DB.
    print("🔹 Loading Chroma DB...")
    vectorstore = Chroma(
        persist_directory="db",
        collection_name="rag-docs",
        embedding_function=embeddings,
    )
    # NOTE(review): `_collection` is a private Chroma attribute, used here
    # only as a diagnostic doc count — may break on a Chroma upgrade.
    print("📂 Docs in DB:", vectorstore._collection.count())

    # 3. LLM
    print("🔹 Loading LLM...")
    model_id = "microsoft/phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
    )
    print("✅ LLM loaded.")

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        # FIX: temperature only applies when sampling is enabled; without
        # do_sample=True transformers decodes greedily and ignores it.
        do_sample=True,
        temperature=0.2,
        # FIX: by default text-generation returns prompt + completion, so
        # the chain's "answer" would echo the entire stuffed RAG prompt.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # 4. QA Chain — "stuff" inserts the top-k docs directly into the prompt.
    print("🔹 Building RetrievalQA...")
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
    )
    print("✅ QA pipeline ready.")
    return qa
# Eagerly construct the pipeline at import time so the app can answer
# immediately; on any failure fall back to None so callers can degrade
# gracefully instead of crashing the whole app on import.
qa_pipeline = None
try:
    qa_pipeline = build_qa()
except Exception as exc:
    print("❌ Failed to build QA pipeline:", exc)
    traceback.print_exc()  # full stack for debugging in the Space logs
def get_answer(query: str) -> str:
    """Return the chatbot's response for *query*.

    Falls back to a warning string when the module-level QA pipeline
    failed to initialize at import time.
    """
    if qa_pipeline is not None:
        return qa_pipeline.run(query)
    return "⚠️ QA pipeline not initialized."