# Harry_potter_wiki / chatbot_rag.py
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import traceback
import re
import os
from huggingface_hub import login
token = os.getenv("HF_TOKEN")
print("🔑 HF_TOKEN available?", token is not None)
if token:
    login(token=token)
else:
    print("❌ No HF_TOKEN found in environment")
def build_qa():
    """Builds and returns the RAG QA pipeline (rag_chain style)."""
    print("🚀 Starting QA pipeline...")

    # 1. Embeddings
    # all-MiniLM-L6-v2 produces compact 384-dim sentence embeddings; it must be
    # the same model that was used when the Chroma collection was built.
    print("🔹 Loading embeddings...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    # 2. Load vector DB
    print("🔹 Loading Chroma DB...")
    vectorstore = Chroma(
        persist_directory="db",
        collection_name="rag-docs",
        embedding_function=embeddings,
    )
    print("📂 Docs in DB:", vectorstore._collection.count())
    # 3. Load LLM (Llama-3.2-1B-Instruct)
    print("🔹 Loading LLM...")
    # meta-llama repos are gated, hence the HF_TOKEN login above
    model_id = "meta-llama/Llama-3.2-1B-Instruct"  # or the larger "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",       # put weights on GPU if available, else CPU (requires `accelerate`)
        trust_remote_code=True,  # allow custom modeling code from the Hub (not needed for Llama, but harmless)
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        temperature=0.4,        # low temperature: focused answers, but not fully greedy
        do_sample=True,         # sample instead of greedy decoding
        top_p=0.9,              # nucleus sampling to avoid loops
        repetition_penalty=1.2, # 🚀 penalize repeats
        eos_token_id=tokenizer.eos_token_id,  # stop at EOS
        return_full_text=False, # return only the completion, not the prompt
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # 4. Retriever (top-3 most similar chunks)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    # 5. Prompt
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Use the following context to answer the question.
- Answer from the docs.
- Answer in plain natural language.
- Do not include code, imports, functions, or explanations of how to implement code.
- If you don't know, just say "I don't know."

Context:
{context}

Question: {question}

Answer (one short sentence):
""",
    )
    # 6. Helper functions
    def format_docs(docs):
        """Concatenate retrieved document chunks into a single context string."""
        return "\n".join(doc.page_content for doc in docs)
    def hf_to_str(x):
        """Convert Hugging Face pipeline output to clean plain text."""
        if isinstance(x, list) and x and isinstance(x[0], dict) and "generated_text" in x[0]:
            text = x[0]["generated_text"]
        else:
            text = str(x)
        # Remove code-like patterns (imports, defs, classes, etc.)
        text = re.sub(r"(from\s+\w+\s+import\s+.*|import\s+\w+.*)", "", text)
        text = re.sub(r"def\s+\w+\(.*?\):.*", "", text, flags=re.DOTALL)
        text = re.sub(r"class\s+\w+.*?:.*", "", text, flags=re.DOTALL)
        text = re.sub(r"text\s*\+=.*", "", text)
        # Remove markdown/code fences & quotes
        text = text.replace("```", "").replace("'''", "").replace('"""', "").replace("\\n", " ")
        # Normalize whitespace
        text = re.sub(r"\s+", " ", text)
        # Deduplicate repeated sentences
        sentences = []
        for s in re.split(r"(?<=[.!?])\s+", text):
            if s and s not in sentences:
                sentences.append(s)
        text = " ".join(sentences)
        return text.strip()
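    # Illustrative example (hypothetical pipeline output):
    #   hf_to_str([{"generated_text": "Harry is a wizard.\\nHarry is a wizard."}])
    #   -> "Harry is a wizard."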
    # 7. RAG chain: the dict runs the retriever and the raw question in
    # parallel, the prompt formats both into a PromptValue (which the LLM
    # accepts directly), and hf_to_str cleans the generated text.
    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | hf_to_str  # clean HF output
        | StrOutputParser()
    )
print("βœ… QA pipeline ready.")
return rag_chain
# Build once at import time so the model-loading cost is paid only once
try:
    qa_pipeline = build_qa()
    print("✅ qa_pipeline built successfully:", type(qa_pipeline))
except Exception as e:
    qa_pipeline = None
    print("❌ Failed to build QA pipeline")
    print("Error message:", str(e))
    traceback.print_exc()
def get_answer(query: str) -> str:
    """Run a query against the QA pipeline and return the answer text."""
    if qa_pipeline is None:
        return "⚠️ QA pipeline not initialized."
    try:
        result = qa_pipeline.invoke(query)  # LCEL chains are invoked, not called
        return result
    except Exception as e:
        return f"❌ QA run failed: {e}"