# NOTE(review): removed extraction artifacts that were not part of the source
# (a file-size banner, git commit hashes, and a flattened line-number gutter).
import os
# --- Redirect all Hugging Face caches to /tmp BEFORE any HF library is imported ---
# These must be set before transformers / sentence-transformers load, otherwise
# the default ~/.cache location is used (often read-only on serverless hosts).
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/hf_cache"
from typing import List
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.schema import HumanMessage, AIMessage
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
# ---------------------------
# Load environment variables
# ---------------------------
load_dotenv()
# May be None here; validated (with a clear error) in RAGChatBot.__init__.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# ---------------------------
# Settings / Tuning
# ---------------------------
DB_FAISS_PATH = "vectorStore"  # local directory holding the prebuilt FAISS index
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # must match the model used to build the index
K = 5 # how many candidates to check for pre-filter
MAX_DISTANCE = 1.0 # FAISS distance threshold (lower = better).
MAX_CHAT_HISTORY = 50 # cap chat history to avoid unbounded growth
# ---------------------------
# Load FAISS VectorStore
# ---------------------------
embeddings = HuggingFaceEmbeddings(
model_name=EMBED_MODEL,
cache_folder="/tmp/hf_cache"  # keep model weights under the writable /tmp cache set above
)
# allow_dangerous_deserialization: FAISS persists a pickled docstore. Acceptable
# only because this index is built locally — never load an untrusted index.
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
# ---------------------------
# ChatBot Class
# ---------------------------
class RAGChatBot:
    """Retrieval-augmented chatbot over Odisha disaster-management documents.

    Per-turn pipeline:
      1. ``rewrite_query``           — normalize the query into formal report language.
      2. ``_prefilter_by_distance``  — cheap FAISS-distance gate that rejects
         out-of-domain questions before spending an LLM call on them.
      3. ``qa_chain``                — RetrievalQA over the FAISS store with a custom prompt.
      4. Bounded chat history (recorded only; not currently fed back into the chain).
    """

    def __init__(self):
        """Build the LLM client, retriever, prompt and RetrievalQA chain.

        Raises:
            ValueError: if GROQ_API_KEY is not set in the environment.
        """
        # Fail fast — ChatGroq would otherwise fail later with a murkier error.
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set in environment")
        self.llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.1-8b-instant",
            temperature=0,  # deterministic output for factual QA
        )
        # Transcript of HumanMessage/AIMessage pairs; trimmed in chat().
        self.chat_history: List = []
        # Retriever used by RetrievalQA (kept, but we pre-filter before calling the chain).
        self.retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        # Custom Prompt (dynamic fallback included).
        custom_prompt = """
Use the following context to answer the user’s question.
If the answer cannot be found in the context, reply exactly with:
"I'm trained only on Odisha disaster management reports (i.e,OSDMA, NDMA, IMD, Research papers). I don't have any information about: '{question}'"
Context:
{context}
Question:
{question}
Answer:
"""
        self.prompt = PromptTemplate(template=custom_prompt, input_variables=["context", "question"])
        # Retrieval QA Chain (keeps structured QA behavior).
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt},
        )

    # ---------------------------
    # Rewrite function
    # ---------------------------
    def rewrite_query(self, user_input: str) -> str:
        """Rewrite query into formal disaster-management style language using LLM.

        Falls back to the original text on any LLM error or empty response,
        so this step can never break or blank out the main chat flow.
        """
        rewrite_prompt = f"""
Rewrite the following user query into clear, formal disaster management language
as used in government reports (OSDMA, NDMA, IMD).
If it is not disaster-related, just return it unchanged.
Query: {user_input}
"""
        try:
            response = self.llm.invoke([HumanMessage(content=rewrite_prompt)])
            rewritten = response.content.strip()
            # Guard: an empty rewrite would silently blank the query downstream.
            return rewritten or user_input
        except Exception as e:
            print("⚠ Rewrite error:", e)
            return user_input  # fallback to original

    def _prefilter_by_distance(self, query: str, k: int = K, max_distance: float = MAX_DISTANCE) -> bool:
        """Return True if *query* looks in-domain based on FAISS distance.

        Takes the minimum distance over the top-*k* candidates instead of
        trusting result ordering (lower distance = more similar).
        """
        results = db.similarity_search_with_score(query, k=k)
        if not results:
            return False
        # results are (Document, score) pairs; pick the best score defensively.
        best_score = min(score for _doc, score in results)
        return best_score <= max_distance

    def chat(self, user_input: str) -> str:
        """Answer one user turn and return the bot reply as plain text."""
        # 1) Rewrite user query into formal report language.
        rewritten_query = self.rewrite_query(user_input)
        # 2) Quick in-domain prefilter; on error, err on the side of answering.
        try:
            in_domain = self._prefilter_by_distance(rewritten_query)
        except Exception as e:
            print("⚠ prefilter error:", e)
            in_domain = True
        if not in_domain:
            return (
                f"I’m trained only on Odisha disaster management reports "
                f"(OSDMA, NDMA, IMD, research). I don’t have any information about: '{user_input}'."
            )
        # 3) Retrieval + QA.
        try:
            response = self.qa_chain.invoke({"query": rewritten_query})
            # Default to "" so a missing "result" key never stores None in history.
            answer = response.get("result", "") if isinstance(response, dict) else str(response)
        except Exception as e:
            print("⚠ LLM / chain error:", e)
            answer = "Sorry, I encountered an error while generating the answer."
        # 4) Update memory (bounded: MAX_CHAT_HISTORY exchanges = 2x messages).
        self.chat_history.append(HumanMessage(content=user_input))
        self.chat_history.append(AIMessage(content=answer))
        if len(self.chat_history) > MAX_CHAT_HISTORY * 2:
            self.chat_history = self.chat_history[-MAX_CHAT_HISTORY * 2 :]
        return answer
# ---------------------------
# Run Chatbot (CLI)
# ---------------------------
# ---------------------------
# Run Chatbot (CLI)
# ---------------------------
if __name__ == "__main__":
    bot = RAGChatBot()
    print("🤖 Odisha Disaster Management ChatBot ready! Type 'exit' to quit.")
    while True:
        try:
            query = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break
        if not query:
            # Skip blank lines rather than sending an empty query to the LLM.
            continue
        if query.lower() in ["exit", "quit"]:
            break
        print("Bot:", bot.chat(query))