File size: 5,801 Bytes
4787e22
b3c9dfb
 
 
 
 
 
4787e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3c9dfb
 
 
 
4787e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os

# --- add these 3 lines before anything Hugging Face runs ---
# IMPORTANT: these env vars must be set BEFORE transformers /
# sentence-transformers are imported below, otherwise the default
# cache location has already been chosen at import time.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/hf_cache"

from typing import List
from dotenv import load_dotenv

from langchain_groq import ChatGroq
from langchain.schema import HumanMessage, AIMessage
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

# ---------------------------
# Load environment variables
# ---------------------------
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # may be None; validated in RAGChatBot.__init__

# ---------------------------
# Settings / Tuning
# ---------------------------
DB_FAISS_PATH = "vectorStore"            # path of the persisted FAISS index (relative to CWD)
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
K = 5                    # how many candidates to check for pre-filter
MAX_DISTANCE = 1.0       # FAISS distance threshold (lower = better). 
MAX_CHAT_HISTORY = 50    # cap chat history to avoid unbounded growth

# ---------------------------
# Load FAISS VectorStore
# ---------------------------
# Module-level load: the embeddings model and index are created once at import.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    cache_folder="/tmp/hf_cache"   # <--- new
)
# allow_dangerous_deserialization is required for pickled FAISS stores;
# only safe because the index is produced locally, not untrusted input.
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)

# ---------------------------
# ChatBot Class
# ---------------------------
class RAGChatBot:
    """RAG chatbot over an Odisha disaster-management FAISS index.

    Pipeline per user turn:
      1. rewrite the query into formal disaster-management language,
      2. pre-filter by FAISS distance to reject out-of-domain questions cheaply,
      3. run a RetrievalQA chain against the vector store,
      4. append the exchange to a bounded chat history.
    """

    def __init__(self):
        # Fail fast: a missing key would otherwise only surface mid-chat.
        if not GROQ_API_KEY:
            raise ValueError("GROQ_API_KEY not set in environment")
        self.llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.1-8b-instant",
            temperature=0,  # deterministic output for factual QA
        )
        # Alternating HumanMessage/AIMessage pairs; bounded in chat().
        self.chat_history: List = []

        # Retriever used by RetrievalQA (kept, but we pre-filter before calling the chain)
        self.retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

        # Custom prompt with an exact-match fallback sentence so out-of-context
        # answers are easy to detect. NOTE: chat() emits the same sentence for
        # queries rejected by the distance pre-filter, keeping both paths consistent.
        custom_prompt = """
Use the following context to answer the user’s question.

If the answer cannot be found in the context, reply exactly with:

"I'm trained only on Odisha disaster management reports (i.e., OSDMA, NDMA, IMD, Research papers). I don't have any information about: '{question}'"

Context:

{context}

Question:

{question}

Answer:
"""
        self.prompt = PromptTemplate(template=custom_prompt, input_variables=["context", "question"])

        # Retrieval QA chain (keeps structured QA behavior)
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.prompt},
        )

    def rewrite_query(self, user_input: str) -> str:
        """Rewrite a query into formal disaster-management language via the LLM.

        Falls back to the original query on any LLM error. The prompt is built
        without source-code indentation so the model is not fed stray whitespace.
        """
        rewrite_prompt = (
            "Rewrite the following user query into clear, formal disaster "
            "management language as used in government reports (OSDMA, NDMA, IMD).\n"
            "If it is not disaster-related, just return it unchanged.\n\n"
            f"Query: {user_input}"
        )
        try:
            response = self.llm.invoke([HumanMessage(content=rewrite_prompt)])
            return response.content.strip()
        except Exception as e:
            print("⚠ Rewrite error:", e)
            return user_input  # fallback to original

    def _prefilter_by_distance(self, query: str, k: int = K, max_distance: float = MAX_DISTANCE) -> bool:
        """Return True if the query looks in-domain based on FAISS distance.

        Scores are distances (lower = closer). The best score is taken with
        min() rather than assuming the result list is sorted.
        """
        results = db.similarity_search_with_score(query, k=k)
        if not results:
            return False
        best_score = min(score for _doc, score in results)
        return best_score <= max_distance

    def chat(self, user_input: str) -> str:
        """Answer one user turn; returns the answer string (never raises)."""
        # 0) Guard trivial input before spending an LLM call on it.
        if not user_input or not user_input.strip():
            return "Please enter a question."

        # 1) Rewrite user query into retrieval-friendly language.
        rewritten_query = self.rewrite_query(user_input)

        # 2) Quick in-domain prefilter; on error, fail open and let the chain try.
        try:
            in_domain = self._prefilter_by_distance(rewritten_query)
        except Exception as e:
            print("⚠ prefilter error:", e)
            in_domain = True

        if not in_domain:
            # Same sentence the prompt instructs the LLM to emit, so callers
            # can rely on a single out-of-domain message format.
            return (
                "I'm trained only on Odisha disaster management reports "
                "(i.e., OSDMA, NDMA, IMD, Research papers). "
                f"I don't have any information about: '{user_input}'"
            )

        # 3) Retrieval + QA.
        try:
            response = self.qa_chain.invoke({"query": rewritten_query})
            answer = response.get("result") if isinstance(response, dict) else str(response)
        except Exception as e:
            print("⚠ LLM / chain error:", e)
            answer = "Sorry, I encountered an error while generating the answer."

        # 4) Update memory, keeping at most MAX_CHAT_HISTORY exchanges
        #    (2 messages per exchange) to avoid unbounded growth.
        self.chat_history.append(HumanMessage(content=user_input))
        self.chat_history.append(AIMessage(content=answer))
        if len(self.chat_history) > MAX_CHAT_HISTORY * 2:
            self.chat_history = self.chat_history[-MAX_CHAT_HISTORY * 2 :]

        return answer


# ---------------------------
# Run Chatbot (CLI)
# ---------------------------
if __name__ == "__main__":
    bot = RAGChatBot()
    print("🤖 Odisha Disaster Management ChatBot ready! Type 'exit' to quit.")
    while True:
        try:
            query = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break
        if not query:
            continue  # skip blank lines instead of querying the LLM
        if query.lower() in ("exit", "quit"):
            break
        print("Bot:", bot.chat(query))