Harry_potter_wiki / chatbot_rag.py
Subha95's picture
Update chatbot_rag.py
cff40c6 verified
raw
history blame
3.33 kB
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import traceback
def build_qa():
"""Builds and returns the RAG QA pipeline (rag_chain style)."""
print("πŸš€ Starting QA pipeline...")
# 1. Embeddings
print("πŸ”Ή Loading embeddings...")
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
# 2. Load vector DB
print("πŸ”Ή Loading Chroma DB...")
vectorstore = Chroma(
persist_directory="db",
collection_name="rag-docs",
embedding_function=embeddings,
)
print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
# 3. Load LLM (Phi-3 mini)
print("πŸ”Ή Loading LLM...")
model_id = "declare-lab/flan-alpaca-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
pipe = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=300,
do_sample=False,
)
llm = HuggingFacePipeline(pipeline=pipe)
# 4. Retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# 5. Prompt
prompt = PromptTemplate(
input_variables=["context", "question"],
template="""
Use the following context to answer the question at the end.
If you don't know the answer, just say "I don't know" β€” do not make up an answer.
Context:
{context}
Question: {question}
Answer (one short sentence):
""",
)
# 6. Helper functions
def format_docs(docs):
return "\n".join(doc.page_content for doc in docs)
def hf_to_str(x):
"""Convert Hugging Face pipeline output to plain string"""
if isinstance(x, list) and "generated_text" in x[0]:
return x[0]["generated_text"]
return str(x)
# 7. RAG chain
rag_chain = (
{
"context": retriever | format_docs,
"question": RunnablePassthrough(),
}
| prompt
| (lambda x: str(x)) # convert PromptTemplate value to str
| llm
| (lambda x: hf_to_str(x)) # clean HF output
| StrOutputParser()
)
print("βœ… QA pipeline ready.")
return rag_chain
# Build once
try:
qa_pipeline = build_qa()
print("βœ… qa_pipeline built successfully:", type(qa_pipeline))
except Exception as e:
qa_pipeline = None
print("❌ Failed to build QA pipeline")
print("Error message:", str(e))
traceback.print_exc()
def get_answer(query: str) -> str:
"""
Run a query against the QA pipeline and return the answer text.
"""
if qa_pipeline is None:
return "⚠️ QA pipeline not initialized."
try:
result = qa_pipeline.invoke(query) # for LCEL chain
return result
except Exception as e:
return f"❌ QA run failed: {e}"