"""RAG chatbot over Odisha disaster-management reports (OSDMA, NDMA, IMD).

Pipeline per user turn:
  1. Rewrite the query into formal disaster-management language (LLM call).
  2. Pre-filter: reject out-of-domain queries via FAISS distance threshold.
  3. Answer with a RetrievalQA chain over a local FAISS vector store.
"""

import os

# --- must run before anything Hugging Face imports: redirect all HF caches
# to a writable location (e.g. read-only / serverless deployments) ---
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/hf_cache"

from typing import List

from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.schema import HumanMessage, AIMessage
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

# ---------------------------
# Load environment variables
# ---------------------------
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# ---------------------------
# Settings / Tuning
# ---------------------------
DB_FAISS_PATH = "vectorStore"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
K = 5                  # how many candidates to check for pre-filter
MAX_DISTANCE = 1.0     # FAISS distance threshold (lower = better match)
MAX_CHAT_HISTORY = 50  # cap chat history (exchanges) to avoid unbounded growth

# ---------------------------
# Load FAISS VectorStore
# NOTE: module-level side effect — the index is loaded at import time.
# allow_dangerous_deserialization is acceptable only because the index is
# produced locally by a trusted pipeline; never load an untrusted pickle.
# ---------------------------
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    cache_folder="/tmp/hf_cache",
)
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)


# ---------------------------
# ChatBot Class
# ---------------------------
class RAGChatBot:
    """Retrieval-augmented chatbot restricted to Odisha disaster-management docs."""

    def __init__(self):
        """Build the LLM client, retriever, prompt, and RetrievalQA chain.

        Raises:
            ValueError: if GROQ_API_KEY is missing from the environment.
        """
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set in environment")
        self.llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.1-8b-instant",
            temperature=0,  # deterministic answers for factual QA
        )

        # Bounded conversation log of HumanMessage/AIMessage pairs.
        # NOTE(review): the history is recorded but not fed back into the
        # chain — RetrievalQA is stateless; confirm whether memory is needed.
        self.chat_history: List = []

        # Retriever used by RetrievalQA (kept, but we pre-filter before calling the chain).
        self.retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

        # Custom prompt with an exact-text out-of-scope fallback the LLM must emit.
        custom_prompt = """
Use the following context to answer the user’s question.
If the answer cannot be found in the context, reply exactly with:
"I'm trained only on Odisha disaster management reports (i.e,OSDMA, NDMA, IMD, Research papers). I don't have any information about: '{question}'"

Context: {context}

Question: {question}

Answer:
"""
        self.prompt = PromptTemplate(template=custom_prompt, input_variables=["context", "question"])

        # Retrieval QA chain (keeps structured QA behavior).
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt},
        )

    def rewrite_query(self, user_input: str) -> str:
        """Rewrite query into formal disaster-management style language using LLM.

        Falls back to the original input on any LLM error.
        """
        rewrite_prompt = f"""
Rewrite the following user query into clear, formal disaster management language
as used in government reports (OSDMA, NDMA, IMD).
If it is not disaster-related, just return it unchanged.

Query: {user_input}
"""
        try:
            response = self.llm.invoke([HumanMessage(content=rewrite_prompt)])
            return response.content.strip()
        except Exception as e:
            print("⚠ Rewrite error:", e)
            return user_input  # fallback to original

    def _prefilter_by_distance(self, query: str, k: int = K, max_distance: float = MAX_DISTANCE) -> bool:
        """Check if query is in-domain using the best FAISS distance score.

        Returns:
            True when the closest of ``k`` candidates is within ``max_distance``
            (FAISS scores are distances: lower = more similar); False for an
            empty index or when everything is too far away.
        """
        results = db.similarity_search_with_score(query, k=k)
        if not results:
            return False
        best_score = results[0][1]  # results are (Document, score) tuples
        return best_score <= max_distance

    def chat(self, user_input: str) -> str:
        """Answer one user turn: rewrite, domain pre-filter, retrieve + QA.

        Never raises — every failure path degrades to a user-facing message.
        """
        # 1) Rewrite user query
        rewritten_query = self.rewrite_query(user_input)
        # print(f"[debug] rewritten query: {rewritten_query}")

        # 2) Quick in-domain prefilter; on error, fail open and let the
        #    prompt's own fallback handle out-of-scope questions.
        try:
            in_domain = self._prefilter_by_distance(rewritten_query)
        except Exception as e:
            print("⚠ prefilter error:", e)
            in_domain = True

        if not in_domain:
            return (
                f"I’m trained only on Odisha disaster management reports "
                f"(OSDMA, NDMA, IMD, research). "
                f"I don’t have any information about: '{user_input}'."
            )

        # 3) Retrieval + QA
        try:
            response = self.qa_chain.invoke({"query": rewritten_query})
            if isinstance(response, dict):
                # FIX: .get("result") may be None — never hand None back to
                # the caller or store it in an AIMessage.
                answer = response.get("result") or "Sorry, I couldn't generate an answer."
            else:
                answer = str(response)
        except Exception as e:
            print("⚠ LLM / chain error:", e)
            answer = "Sorry, I encountered an error while generating the answer."

        # 4) Update memory (bounded: MAX_CHAT_HISTORY exchanges = 2x messages)
        self.chat_history.append(HumanMessage(content=user_input))
        self.chat_history.append(AIMessage(content=answer))
        if len(self.chat_history) > MAX_CHAT_HISTORY * 2:
            self.chat_history = self.chat_history[-MAX_CHAT_HISTORY * 2:]

        return answer


# ---------------------------
# Run Chatbot (CLI)
# ---------------------------
if __name__ == "__main__":
    bot = RAGChatBot()
    print("🤖 Odisha Disaster Management ChatBot ready! Type 'exit' to quit.")
    while True:
        # FIX: exit cleanly on Ctrl-D / Ctrl-C instead of an unhandled traceback.
        try:
            query = input("You: ")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if query.lower() in ["exit", "quit"]:
            break
        print("Bot:", bot.chat(query))