# Harry_potter_wiki / chatbot_rag.py
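"""RAG chatbot over a Harry Potter wiki.

Retrieves passages from a persisted Chroma index and answers with
microsoft/Phi-3.5-mini-instruct through a LangChain (LCEL) chain.
"""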
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import re
import os
import traceback
from huggingface_hub import login
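
# Authenticate with the Hugging Face Hub if a token is provided. Only gated or
# private models strictly need this; public checkpoints download without it.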
token = os.getenv("HF_TOKEN")
print("🔑 HF_TOKEN available?", token is not None)
if token:
    login(token=token)
else:
    print("❌ No HF_TOKEN found in environment")

def build_qa():
    """Builds and returns the RAG QA pipeline (rag_chain style)."""
    print("🚀 Starting QA pipeline...")

    # 1. Embeddings
    print("🔹 Loading embeddings...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # 2. Load vector DB
    print("🔹 Loading Chroma DB...")
    vectorstore = Chroma(
        persist_directory="db",
        collection_name="rag-docs",
        embedding_function=embeddings,
    )
    print("📂 Docs in DB:", vectorstore._collection.count())

    # 3. Load LLM (Phi-3.5-mini-instruct)
    print("🔹 Loading LLM...")
    model_id = "microsoft/Phi-3.5-mini-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    model.config.use_cache = False
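    # Disabling the KV cache is presumably a workaround for the cache
    # incompatibilities some Phi-3 revisions hit on newer transformers
    # versions; it does make generation slower.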
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=80,        # keep answers short
        do_sample=False,          # greedy decoding -> deterministic output
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        return_full_text=False,   # return only the completion, not the prompt
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # 4. Retriever
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
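    # k=3: the three most similar chunks are stuffed into the prompt below.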

    # 5. Prompt
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""Answer the question using the context below.
Respond in ONE short factual sentence only.
If you don't know, say "I don't know."

Context:
{context}

Question:
{question}

Answer:""",
    )

    # 6. Helpers
    def format_docs(docs):
        texts = [doc.page_content.strip() for doc in docs if doc.page_content]
        return "\n".join(texts)

    def hf_to_str(x):
        # HuggingFacePipeline usually yields a plain string, but guard against
        # the raw [{"generated_text": ...}] shape as well.
        if isinstance(x, list) and x and isinstance(x[0], dict) and "generated_text" in x[0]:
            text = x[0]["generated_text"]
        else:
            text = str(x)
        text = re.sub(r"\s+", " ", text).strip()
        # ✅ Keep only the first sentence
        return re.split(r"(?<=[.!?])\s+", text)[0]

    # 7. Chain (LCEL): the dict fans the question out to the retriever
    # (context) and passes it through unchanged (question).
    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | hf_to_str
        | StrOutputParser()
    )
print("βœ… QA pipeline ready.")
return rag_chain

# Build once at import time so callers (e.g. a Gradio app) reuse the chain.
try:
    qa_pipeline = build_qa()
    print("✅ qa_pipeline built successfully:", type(qa_pipeline))
except Exception as e:
    qa_pipeline = None
    print("❌ Failed to build QA pipeline")
    print("Error message:", str(e))
    traceback.print_exc()

def get_answer(query: str) -> str:
    """Run a query against the QA pipeline and return the answer text."""
    if qa_pipeline is None:
        return "⚠️ QA pipeline not initialized."
    try:
        result = qa_pipeline.invoke(query)
        return result
    except Exception as e:
        return f"❌ QA run failed: {e}"