File size: 12,228 Bytes
edabb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
import logging
import os
import re
from collections import defaultdict

from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

load_dotenv()

# --------------------------
# GLOBALS (cached)
# --------------------------

LOGGER = logging.getLogger(__name__)

# Sentence-transformer embedding model, constructed once when this module is
# imported and reused by the vector store built in get_vectorstore().
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Lazily-initialised singletons, filled on first use by get_vectorstore() / get_llm().
VECTORSTORE = None
LLM_ROUTER = None

# Module-level chat log shared by every caller in this process.
# Entries look like {"role": "user" | "assistant", "content": "..."}.
CONVERSATION_HISTORY: list[dict[str, str]] = []
# Number of most recent *messages* (not user/assistant pairs) fed back into
# prompts: 6 messages = 3 user + 3 assistant.
MAX_HISTORY_TURNS = 6


# --------------------------
# QDRANT / LLM SETUP
# --------------------------

def get_client():
    """Create a new Qdrant client from the QDRANT_URL and QDRANT_API_KEY env vars.

    A fresh client is built on every call. get_vectorstore() caches the store
    that wraps it, so in practice this runs once per process.
    """
    qdrant_url = os.getenv("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    return QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

def get_vectorstore():
    """Return the shared QdrantVectorStore for the "psychology_books" collection.

    The store is built on the first call (using get_client() and EMBEDDINGS)
    and cached in the module-level VECTORSTORE for every later call.
    """
    global VECTORSTORE
    if VECTORSTORE is not None:
        return VECTORSTORE

    VECTORSTORE = QdrantVectorStore(
        client=get_client(),
        collection_name="psychology_books",
        embedding=EMBEDDINGS,
    )
    return VECTORSTORE


def _collect_groq_api_keys() -> list[str]:
    """
    Gather every configured Groq API key, de-duplicated, in priority order.

    Sources, highest priority first:
    1) GROQ_API_KEYS     - a comma-separated list
    2) GROQ_API_KEY      - a single key
    3) GROQ_API_KEY_<n>  - numbered keys, sorted numerically by n (so _10 follows _2)

    Each value is whitespace-stripped, blank values are skipped, and a key that
    appears more than once keeps only its first (highest-priority) position.
    """
    prefix = "GROQ_API_KEY_"

    def numeric_suffix(env_name: str) -> str:
        return env_name[len(prefix):]

    candidates = os.getenv("GROQ_API_KEYS", "").split(",")
    candidates.append(os.getenv("GROQ_API_KEY", ""))

    numbered_vars = sorted(
        (
            env_name
            for env_name in os.environ
            if env_name.startswith(prefix) and numeric_suffix(env_name).isdigit()
        ),
        key=lambda env_name: int(numeric_suffix(env_name)),
    )
    candidates.extend(os.getenv(env_name, "") for env_name in numbered_vars)

    stripped = (value.strip() for value in candidates)
    # dict keeps insertion order, so only later duplicates are dropped.
    return list(dict.fromkeys(value for value in stripped if value))


class GroqFailoverLLM:
    """Spread ``invoke`` calls over several Groq API keys, failing over on rate limits.

    One ``ChatGroq`` client is created per key, all with the same model and
    temperature. Each call starts with the client that last succeeded. If that
    client raises an error classified as a rate-limit/quota error and more than
    one key is configured, the same prompt is retried on the next client in
    order, wrapping around. Any other error is re-raised immediately.
    """

    # "429" must appear as a standalone number. A bare substring test would also
    # fire on unrelated numbers in an error message, e.g. a token count of 14290.
    _HTTP_429_RE = re.compile(r"\b429\b")
    # Lower-case phrases indicating the key is throttled or out of quota.
    _RATE_LIMIT_MARKERS = (
        "rate limit",
        "too many requests",
        "quota",
        "resource exhausted",
    )

    def __init__(self, api_keys: list[str], model: str, temperature: float):
        """Build one ChatGroq client per entry of ``api_keys``, preserving order."""
        self.api_keys = api_keys
        self.model = model
        self.temperature = temperature
        self.clients = [
            ChatGroq(model=self.model, temperature=self.temperature, api_key=key)
            for key in self.api_keys
        ]
        # Index of the client tried first on the next invoke().
        self.active_idx = 0

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Return True if ``error`` looks like an HTTP 429 / quota failure.

        A ``status_code`` attribute equal to 429 on the exception is trusted
        first. Otherwise the lower-cased message is searched for a standalone
        "429" or one of the known rate-limit phrases.
        """
        if getattr(error, "status_code", None) == 429:
            return True
        text = str(error).lower()
        if GroqFailoverLLM._HTTP_429_RE.search(text):
            return True
        return any(marker in text for marker in GroqFailoverLLM._RATE_LIMIT_MARKERS)

    def invoke(self, prompt):
        """Send ``prompt`` to the active client, rotating keys on rate-limit errors.

        Args:
            prompt: Passed unchanged to ``ChatGroq.invoke``.

        Returns:
            Whatever the first successful client returns. That client becomes
            the active one for subsequent calls.

        Raises:
            The client's exception if it is not a rate-limit error, or if only
            one key is configured. The last rate-limit error if every key is
            throttled. RuntimeError if no clients were configured.
        """
        total = len(self.clients)
        last_error = None

        for offset in range(total):
            idx = (self.active_idx + offset) % total
            client = self.clients[idx]

            try:
                result = client.invoke(prompt)
                if idx != self.active_idx:
                    LOGGER.warning("Groq fallback active: switched to key #%s", idx + 1)
                self.active_idx = idx
                return result
            except Exception as exc:
                last_error = exc
                if total > 1 and self._is_rate_limit_error(exc):
                    LOGGER.warning(
                        "Groq key #%s hit limit/quota. Trying next configured key.",
                        idx + 1,
                    )
                    continue
                raise

        # Reached only when every key was rate-limited (or there are no keys).
        if last_error is not None:
            raise last_error
        raise RuntimeError("No Groq clients available")

def get_llm():
    """Return the shared GroqFailoverLLM, building it on first use.

    Keys come from _collect_groq_api_keys(). The router is cached in the
    module-level LLM_ROUTER.

    Raises:
        ValueError: if no Groq API key is configured.
    """
    global LLM_ROUTER

    if LLM_ROUTER is not None:
        return LLM_ROUTER

    configured_keys = _collect_groq_api_keys()
    if not configured_keys:
        raise ValueError(
            "No Groq keys configured. Set GROQ_API_KEY, GROQ_API_KEYS, or GROQ_API_KEY_2+ in .env"
        )

    LLM_ROUTER = GroqFailoverLLM(
        api_keys=configured_keys,
        model="llama-3.1-8b-instant",
        temperature=0,
    )
    return LLM_ROUTER


# --------------------------
# LAYER 1: INTENT ROUTER
# --------------------------

# Zero-shot intent-classification prompt used by classify_intent().
# The user's message is substituted with str.format(query=...). The model is
# asked for a single category word, but classify_intent() only looks for the
# label as a substring, so extra text in the reply is tolerated.
INTENT_PROMPT = """Classify this message into EXACTLY one category.

Categories:
- CHITCHAT     : greetings, thanks, small talk, "how are you", jokes
- IDENTITY     : asking who/what you are, your name, your purpose
- CRISIS       : suicidal thoughts, self-harm, abuse, urgent danger
- PSYCHOLOGY   : anything about mental health, behavior, emotions, habits, thinking, relationships

Message: "{query}"

Reply with only ONE word: CHITCHAT, IDENTITY, CRISIS, or PSYCHOLOGY"""

def classify_intent(query: str, llm) -> str:
    """Ask the LLM which intent ``query`` expresses.

    Returns one of "CRISIS", "CHITCHAT", "IDENTITY" or "PSYCHOLOGY". The reply
    is upper-cased and scanned for each label in that order. CRISIS is tested
    first, so it wins whenever the reply mentions it alongside another label.
    Anything unrecognised falls back to "PSYCHOLOGY".
    """
    reply = llm.invoke(INTENT_PROMPT.format(query=query))
    label_text = getattr(reply, "content", str(reply)).strip().upper()

    return next(
        (
            label
            for label in ("CRISIS", "CHITCHAT", "IDENTITY", "PSYCHOLOGY")
            if label in label_text
        ),
        "PSYCHOLOGY",  # safe default
    )


# --------------------------
# LAYER 2: QUERY REWRITER
# --------------------------

def rewrite_query(query: str, history_context: str, llm) -> str:
    """
    Ask the LLM to turn an emotional or vague message into a short search query.

    E.g. "I feel so empty" → "emotional numbness causes and psychological explanation"

    The recent conversation (or the word "None" when empty) is shown to the model
    for context. Surrounding whitespace and double quotes are stripped from the
    reply. If nothing is left, the original ``query`` is returned unchanged.
    """
    history_text = history_context or "None"
    prompt = f"""You are a psychology search assistant.

Conversation so far:
{history_text}

User query: "{query}"

Rewrite this query to be more specific and retrieval-friendly for a psychology book search.
Keep it under 15 words. Return ONLY the rewritten query, nothing else."""

    reply = llm.invoke(prompt)
    cleaned = getattr(reply, "content", str(reply)).strip().strip('"')
    if not cleaned:
        return query
    return cleaned


# --------------------------
# LAYER 3: CRAG RELEVANCE GRADER
# --------------------------

def grade_chunks(query: str, docs: list, llm, min_score: int = 3) -> tuple[list, bool]:
    """
    Keep only the retrieved chunks that an LLM judge rates as relevant to ``query``.

    Each chunk (first 400 characters of ``page_content``) is scored 1-5 in its
    own ``llm.invoke`` call. The score is the first digit 1-5 found anywhere in
    the reply, so answers such as "Score: 4", "**2**" or "4/5" parse correctly.
    A reply with no such digit counts as a neutral 3.

    Args:
        query: The (rewritten) search query the chunks should answer.
        docs: Retrieved documents exposing ``page_content``.
        llm: Object with an ``invoke(prompt)`` method.
        min_score: Minimum score (inclusive) for a chunk to be kept. Defaults to 3.

    Returns:
        ``(relevant_docs, any_relevant)``: the kept documents in their original
        order, and whether at least one chunk was kept.
    """
    relevant_docs = []

    for doc in docs:
        prompt = f"""Rate this text chunk's relevance to the query on a scale of 1 to 5.

Query: "{query}"

Chunk:
{doc.page_content[:400]}

Reply with ONLY a number: 1, 2, 3, 4, or 5
(1=completely irrelevant, 5=directly answers the query)"""

        result = llm.invoke(prompt)
        reply = getattr(result, "content", str(result)).strip()

        # Look for the first in-range digit rather than reading reply[0]: models
        # often prefix the number ("Score: 1"). The old first-character parse
        # silently turned such replies into the neutral default and kept
        # irrelevant chunks.
        score = next((int(ch) for ch in reply if ch in "12345"), 3)

        if score >= min_score:
            relevant_docs.append(doc)

    return relevant_docs, bool(relevant_docs)


# --------------------------
# CONVERSATION MEMORY HELPERS
# --------------------------

def get_history_context() -> str:
    """Render recent conversation messages as "User: ..." / "Assistant: ..." lines.

    Only the last MAX_HISTORY_TURNS entries of CONVERSATION_HISTORY are used,
    joined with newlines. Any role other than "user" is labelled "Assistant".
    Returns "" when there is no history.
    """
    recent_entries = CONVERSATION_HISTORY[-MAX_HISTORY_TURNS:] if CONVERSATION_HISTORY else []
    return "\n".join(
        f"{'User' if entry['role'] == 'user' else 'Assistant'}: {entry['content']}"
        for entry in recent_entries
    )

def save_to_history(user_msg: str, assistant_msg: str):
    """Append one user message and its assistant reply to CONVERSATION_HISTORY.

    NOTE(review): the list is never trimmed, so it grows for the life of the
    process even though get_history_context() only reads the last
    MAX_HISTORY_TURNS entries. Consider capping it if nothing else needs the
    full log.
    """
    CONVERSATION_HISTORY.extend(
        (
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg},
        )
    )

def clear_history():
    """Empty CONVERSATION_HISTORY in place (the list object itself is kept)."""
    del CONVERSATION_HISTORY[:]


# --------------------------
# INTENT HANDLERS
# --------------------------

def handle_chitchat(query: str, history_context: str, llm) -> str:
    """Produce a brief, friendly small-talk reply in the MindBot persona.

    When ``history_context`` is non-empty it is included on a
    "Conversation so far:" line; otherwise that line is left blank. Returns the
    model's reply with surrounding whitespace stripped.
    """
    history_line = f"Conversation so far: {history_context}" if history_context else ""
    prompt = (
        "You are a warm, supportive psychology assistant named MindBot.\n"
        "Have a natural, friendly conversation. Keep it brief (1-2 sentences).\n"
        "\n"
        f"{history_line}\n"
        "\n"
        f"User: {query}\n"
        "MindBot:"
    )
    reply = llm.invoke(prompt)
    return getattr(reply, "content", str(reply)).strip()

def handle_identity(query: str, llm) -> str:
    """Have the LLM answer a "who/what are you" question as MindBot in 2-3 sentences.

    Returns the model's reply with surrounding whitespace stripped.
    """
    prompt = (
        f'The user asked: "{query}"\n'
        "\n"
        "You are MindBot, a psychology assistant trained on books covering therapy, behavior, \n"
        "cognitive psychology, habits, and emotional well-being. Answer in 2-3 sentences."
    )
    response = llm.invoke(prompt)
    return getattr(response, "content", str(response)).strip()

def handle_crisis(query: str) -> str:
    """Return a fixed supportive message that lists crisis helplines.

    The text is static: ``query`` is accepted but not used, and no LLM call is
    made.
    NOTE(review): the helplines listed are India-specific; confirm the numbers
    are current and consider locale-appropriate resources for other users.
    """
    message_lines = (
        "I can hear that you're going through something really difficult right now. "
        "Please know you're not alone.",
        "",
        "**iCall (India):** 9152987821",
        "**Vandrevala Foundation (24/7):** 1860-2662-345",
        "**iCall email:** icall@tiss.edu",
        "",
        "Would you like to talk about what's on your mind? I'm here to listen.",
    )
    return "\n".join(message_lines)


# --------------------------
# CORE RAG PIPELINE (with CRAG)
# --------------------------

def run_rag_with_crag(query: str, history_context: str, llm, max_retries: int = 2) -> tuple:
    """
    Answer a psychology question from the book collection using corrective RAG (CRAG).

    Pipeline:
      1. Rewrite the query for retrieval.
      2. Fetch 5 chunks with MMR search.
      3. Have the LLM grade each chunk. If none pass, rewrite the query more
         aggressively and search again, for up to ``max_retries`` attempts in
         total.
      4. Group the surviving chunks by their ``book`` metadata. If no chunk
         passed, use the top two retrieved chunks instead.
      5. Ask the LLM for an answer grounded in that context and the
         conversation history.

    Args:
        query: The user's original question.
        history_context: Pre-formatted recent conversation, or "" if none.
        llm: Object with an ``invoke(prompt)`` method.
        max_retries: Total number of retrieval attempts. Values below 1 are
            treated as 1 so at least one search is always made.

    Returns:
        ``(answer, docs)``: the generated answer text and the documents used as
        context. If a search returns no documents at all, a fixed apology
        message and ``[]`` are returned instead.
    """
    vectorstore = get_vectorstore()

    # Step 1: rewrite query for better retrieval
    search_query = rewrite_query(query, history_context, llm)

    # The retriever settings never change between attempts, so build it once.
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 5},
    )

    # With max_retries <= 0 the loop below would never run, leaving `docs` and
    # `relevant_docs` unbound (UnboundLocalError). Always make one attempt.
    attempts = max(1, max_retries)
    docs = []
    relevant_docs = []

    for attempt in range(attempts):
        # Step 2: retrieve
        docs = retriever.invoke(search_query)

        if not docs:
            return "I couldn't find relevant information in the books.", []

        # Step 3: CRAG grade
        relevant_docs, has_relevant = grade_chunks(search_query, docs, llm)

        if has_relevant:
            break

        # Nothing relevant: rewrite more aggressively and retry, unless this
        # was the final attempt.
        if attempt < attempts - 1:
            search_query = rewrite_query(
                f"explain in detail: {query}",
                history_context,
                llm,
            )

    # Step 4: build grouped context
    if not relevant_docs:
        relevant_docs = docs[:2]  # fallback: use the top 2 retrieved chunks anyway

    grouped = defaultdict(list)
    for doc in relevant_docs:
        grouped[doc.metadata.get("book", "Unknown")].append(doc.page_content)

    # At most two chunks per book, each book under its own "[Source: ...]" header.
    context = "\n\n---\n\n".join(
        f"[Source: {book}]\n" + "\n".join(texts[:2])
        for book, texts in grouped.items()
    )

    # Step 5: generate with memory + empathy constraint
    prompt = f"""You are MindBot, a compassionate psychology assistant.

Conversation history:
{history_context if history_context else "This is the start of the conversation."}

Knowledge from books:
{context}

User's question: {query}

Instructions:
- Be warm and empathetic first, then informative
- Answer using ONLY the provided book context
- If multiple perspectives exist, present them clearly
- If the question isn't covered in the context, say so honestly
- Keep response focused and under 200 words unless depth is needed"""

    result = llm.invoke(prompt)
    answer = getattr(result, "content", str(result)).strip()

    return answer, relevant_docs


# --------------------------
# MAIN ENTRY POINT
# --------------------------

def ask_question(query: str) -> tuple:
    """Answer one user message and record the exchange in the conversation history.

    The message is classified by the LLM as CRISIS, CHITCHAT, IDENTITY or
    PSYCHOLOGY (classify_intent's fallback) and dispatched to the matching
    handler. History is read before the handler runs, so the current message
    is not part of its own context.

    Returns:
        ``(response, docs)``: ``docs`` are the source documents for PSYCHOLOGY
        answers and ``[]`` for every other intent.
    """
    llm = get_llm()
    history_context = get_history_context()
    intent = classify_intent(query, llm)

    source_docs = []
    if intent == "CRISIS":
        response = handle_crisis(query)
    elif intent == "CHITCHAT":
        response = handle_chitchat(query, history_context, llm)
    elif intent == "IDENTITY":
        response = handle_identity(query, llm)
    else:  # PSYCHOLOGY
        response, source_docs = run_rag_with_crag(query, history_context, llm)

    save_to_history(query, response)
    return response, source_docs