Subhakanta
updated chatbot.py
b3c9dfb
import os
# --- add these 3 lines before anything Hugging Face runs ---
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/hf_cache"
from typing import List
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.schema import HumanMessage, AIMessage
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
# ---------------------------
# Load environment variables
# ---------------------------
load_dotenv()
# May be None when the variable is absent; validated in RAGChatBot.__init__.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# ---------------------------
# Settings / Tuning
# ---------------------------
DB_FAISS_PATH = "vectorStore"  # directory holding the persisted FAISS index
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
K = 5 # how many candidates to check for pre-filter
MAX_DISTANCE = 1.0 # FAISS distance threshold (lower = better).
MAX_CHAT_HISTORY = 50 # cap chat history to avoid unbounded growth (counted in exchanges; stored as 2x messages)
# ---------------------------
# Load FAISS VectorStore
# ---------------------------
# NOTE: this runs at import time; the embedding model is downloaded/loaded
# from the /tmp cache configured at the top of the file.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    cache_folder="/tmp/hf_cache" # <--- new
)
# allow_dangerous_deserialization is required by LangChain to unpickle a
# locally-built index; only safe because the index is our own artifact.
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
# ---------------------------
# ChatBot Class
# ---------------------------
class RAGChatBot:
    """Retrieval-augmented chatbot over Odisha disaster-management documents.

    Per user turn: rewrite the query with the LLM, pre-filter it against the
    FAISS index by distance, then answer through a RetrievalQA chain. Chat
    history is kept (bounded) but not currently fed back into the chain.
    """

    def __init__(self):
        # Fail fast when the Groq credential is missing.
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set in environment")
        self.llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.1-8b-instant",
            temperature=0,
        )
        # Bounded transcript of HumanMessage/AIMessage pairs.
        self.chat_history: List = []
        # Retriever used by RetrievalQA (we still pre-filter before invoking the chain).
        self.retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        # Custom prompt with a dynamic out-of-domain fallback baked in.
        custom_prompt = """
Use the following context to answer the user’s question.
If the answer cannot be found in the context, reply exactly with:
"I'm trained only on Odisha disaster management reports (i.e,OSDMA, NDMA, IMD, Research papers). I don't have any information about: '{question}'"
Context:
{context}
Question:
{question}
Answer:
"""
        self.prompt = PromptTemplate(template=custom_prompt, input_variables=["context", "question"])
        # Retrieval QA chain (keeps the structured question-answering behavior).
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt},
        )

    def rewrite_query(self, user_input: str) -> str:
        """Rewrite query into formal disaster-management style language using LLM."""
        rewrite_prompt = f"""
Rewrite the following user query into clear, formal disaster management language
as used in government reports (OSDMA, NDMA, IMD).
If it is not disaster-related, just return it unchanged.
Query: {user_input}
"""
        try:
            reply = self.llm.invoke([HumanMessage(content=rewrite_prompt)])
            return reply.content.strip()
        except Exception as err:
            # Best-effort step: on any LLM failure fall back to the raw query.
            print("⚠ Rewrite error:", err)
            return user_input

    def _prefilter_by_distance(self, query: str, k: int = K, max_distance: float = MAX_DISTANCE) -> bool:
        """Check if query is in-domain using FAISS distance."""
        hits = db.similarity_search_with_score(query, k=k)
        # hits is a list of (Document, score); score is presumably a distance
        # where lower means closer — TODO confirm against the index's metric.
        if not hits:
            return False
        return hits[0][1] <= max_distance

    def chat(self, user_input: str) -> str:
        """Handle one user turn and return the bot's reply string."""
        # 1) Normalize the query into report-style language.
        rewritten = self.rewrite_query(user_input)

        # 2) Cheap in-domain gate; fail open so a broken index doesn't block chat.
        try:
            in_domain = self._prefilter_by_distance(rewritten)
        except Exception as err:
            print("⚠ prefilter error:", err)
            in_domain = True
        if not in_domain:
            return (
                f"I’m trained only on Odisha disaster management reports "
                f"(OSDMA, NDMA, IMD, research). I don’t have any information about: '{user_input}'."
            )

        # 3) Retrieval + QA over the FAISS store.
        try:
            outcome = self.qa_chain.invoke({"query": rewritten})
            answer = outcome.get("result") if isinstance(outcome, dict) else str(outcome)
        except Exception as err:
            print("⚠ LLM / chain error:", err)
            answer = "Sorry, I encountered an error while generating the answer."

        # 4) Record the exchange, trimming to the configured bound.
        self.chat_history.append(HumanMessage(content=user_input))
        self.chat_history.append(AIMessage(content=answer))
        if len(self.chat_history) > MAX_CHAT_HISTORY * 2:
            self.chat_history = self.chat_history[-MAX_CHAT_HISTORY * 2 :]
        return answer
# ---------------------------
# Run Chatbot (CLI)
# ---------------------------
if __name__ == "__main__":
    bot = RAGChatBot()
    print("🤖 Odisha Disaster Management ChatBot ready! Type 'exit' to quit.")
    while True:
        try:
            query = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break
        if not query:
            # Skip blank lines — don't spend an LLM round-trip on empty input.
            continue
        if query.lower() in ("exit", "quit"):
            break
        print("Bot:", bot.chat(query))