|
|
import os

# Route all Hugging Face caches to a writable temp dir. Must be set BEFORE
# transformers / sentence-transformers are imported below, or the defaults
# (~/.cache) are used — useful on read-only / serverless filesystems.
os.environ["HF_HOME"] = "/tmp/hf_cache"

# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME — kept here for older versions; confirm
# against the pinned transformers version.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"

os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/hf_cache"
|
|
|
|
|
|
from typing import List
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
from langchain_groq import ChatGroq
|
|
|
from langchain.schema import HumanMessage, AIMessage
|
|
|
from langchain_community.vectorstores import FAISS
|
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
from langchain.prompts import PromptTemplate
|
|
|
from langchain.chains import RetrievalQA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()

# Groq API key; may be None here — validated in RAGChatBot.__init__,
# which raises ValueError if it is missing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the pre-built FAISS index loaded below.
DB_FAISS_PATH = "vectorStore"

# Embedding model for queries; must match the model used to build the index.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# Number of documents fetched when checking whether a query is in-domain.
K = 5

# Maximum acceptable FAISS score for the best hit to count as in-domain
# (lower score = more similar; presumably an L2-distance index — TODO confirm
# against how vectorStore was built).
MAX_DISTANCE = 1.0

# Maximum retained exchanges; each exchange is one Human + one AI message.
MAX_CHAT_HISTORY = 50
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Query-time embedding function; cache_folder matches the HF env vars above.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    cache_folder="/tmp/hf_cache"
)

# allow_dangerous_deserialization=True: loading a FAISS store unpickles data
# from disk — acceptable only because this index is produced by the project
# itself, never user-supplied.
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RAGChatBot:
    """Retrieval-augmented chatbot over an Odisha disaster-management corpus.

    Per turn: rewrite the user query into formal report language, gate it
    through a FAISS-distance domain prefilter, then answer via a RetrievalQA
    chain backed by the module-level ``db`` index and a Groq-hosted LLM.
    A bounded in-memory chat history of the raw exchanges is kept.
    """

    def __init__(self):
        """Build the LLM client, retriever, prompt, and RetrievalQA chain.

        Raises:
            ValueError: if GROQ_API_KEY is not set in the environment.
        """
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set in environment")

        # temperature=0 for deterministic, fact-oriented answers.
        self.llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.1-8b-instant",
            temperature=0
        )

        # Alternating HumanMessage / AIMessage entries; trimmed in chat().
        self.chat_history: List = []

        # FIX: previously hard-coded k=3 here, disagreeing with the
        # module-level K used by _prefilter_by_distance; use K so the
        # prefilter and the answering retriever see the same depth.
        self.retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": K})

        # Prompt template filled by RetrievalQA with the retrieved context
        # and the (rewritten) question.
        custom_prompt = """
Use the following context to answer the user’s question.
If the answer cannot be found in the context, reply exactly with:
"I'm trained only on Odisha disaster management reports (i.e,OSDMA, NDMA, IMD, Research papers). I don't have any information about: '{question}'"

Context:
{context}

Question:
{question}

Answer:
"""
        self.prompt = PromptTemplate(template=custom_prompt, input_variables=["context", "question"])

        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt}
        )

    def rewrite_query(self, user_input: str) -> str:
        """Rewrite query into formal disaster-management style language using LLM.

        Best-effort: on any LLM error the original input is returned
        unchanged so the chat turn can still proceed.
        """
        rewrite_prompt = f"""
Rewrite the following user query into clear, formal disaster management language
as used in government reports (OSDMA, NDMA, IMD).
If it is not disaster-related, just return it unchanged.

Query: {user_input}
"""
        try:
            response = self.llm.invoke([HumanMessage(content=rewrite_prompt)])
            return response.content.strip()
        except Exception as e:
            # Deliberate fallback rather than failing the whole turn.
            print("⚠ Rewrite error:", e)
            return user_input

    def _prefilter_by_distance(self, query: str, k: int = K, max_distance: float = MAX_DISTANCE) -> bool:
        """Check if query is in-domain using FAISS distance.

        Returns True when the best-scoring hit is within max_distance
        (lower score = more similar); False when nothing is retrieved.
        """
        results = db.similarity_search_with_score(query, k=k)
        if not results:
            # Empty index or no hits at all: treat as out-of-domain.
            return False
        best_score = results[0][1]
        return best_score <= max_distance

    def chat(self, user_input: str) -> str:
        """Answer one user turn and return the answer string.

        Side effects: appends the (question, answer) pair to chat_history,
        bounded to MAX_CHAT_HISTORY exchanges (two messages each).
        """
        rewritten_query = self.rewrite_query(user_input)

        # Domain gate: skip LLM generation entirely for off-topic queries.
        try:
            in_domain = self._prefilter_by_distance(rewritten_query)
        except Exception as e:
            # If the prefilter itself breaks, fail open and let the chain try.
            print("⚠ prefilter error:", e)
            in_domain = True

        if not in_domain:
            return (
                f"I’m trained only on Odisha disaster management reports "
                f"(OSDMA, NDMA, IMD, research). I don’t have any information about: '{user_input}'."
            )

        try:
            response = self.qa_chain.invoke({"query": rewritten_query})
            # RetrievalQA returns a dict with the answer under "result".
            answer = response.get("result") if isinstance(response, dict) else str(response)
        except Exception as e:
            print("⚠ LLM / chain error:", e)
            answer = "Sorry, I encountered an error while generating the answer."

        # Record the raw (untranslated) user input plus the answer, then
        # bound history to MAX_CHAT_HISTORY exchanges.
        self.chat_history.append(HumanMessage(content=user_input))
        self.chat_history.append(AIMessage(content=answer))
        if len(self.chat_history) > MAX_CHAT_HISTORY * 2:
            self.chat_history = self.chat_history[-MAX_CHAT_HISTORY * 2 :]

        return answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Minimal console REPL for manual testing; 'exit' or 'quit' ends it.
    bot = RAGChatBot()
    print("🤖 Odisha Disaster Management ChatBot ready! Type 'exit' to quit.")
    while (query := input("You: ")).lower() not in ("exit", "quit"):
        print("Bot:", bot.chat(query))
|
|
|
|