import logging
import os
from collections import defaultdict

from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

load_dotenv()

# --------------------------
# GLOBALS (cached)
# --------------------------
# The embedding model is constructed once, at import time, and shared.
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
VECTORSTORE = None  # created on first call to get_vectorstore()
LLM_ROUTER = None  # created on first call to get_llm()
LOGGER = logging.getLogger(__name__)

# Simple in-memory session history: list of {"role": "user"/"assistant", "content": "..."}
# NOTE: nothing in this module trims the list on write; only get_history_context()
# windows it on read, so it grows until clear_history() is called.
CONVERSATION_HISTORY = []
MAX_HISTORY_TURNS = 6  # keep last 6 turns (3 user + 3 assistant)

# Intent labels in match order; CRISIS is checked first.
_INTENTS = ("CRISIS", "CHITCHAT", "IDENTITY", "PSYCHOLOGY")


# --------------------------
# QDRANT / LLM SETUP
# --------------------------
def get_client():
    """Build a Qdrant client from the QDRANT_URL and QDRANT_API_KEY env vars."""
    return QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )


def get_vectorstore():
    """Return the shared vector store for "psychology_books", creating it on first use."""
    global VECTORSTORE
    if VECTORSTORE is None:
        VECTORSTORE = QdrantVectorStore(
            client=get_client(),
            collection_name="psychology_books",
            embedding=EMBEDDINGS,
        )
    return VECTORSTORE


def _collect_groq_api_keys() -> list[str]:
    """
    Collect Groq keys from supported env vars in priority order:
    1) GROQ_API_KEYS (comma-separated)
    2) GROQ_API_KEY
    3) GROQ_API_KEY_2, GROQ_API_KEY_3, ...

    Blank values are dropped and duplicates are removed, keeping the first
    occurrence so the priority order above is preserved.
    """
    keys = []

    combined = os.getenv("GROQ_API_KEYS", "").strip()
    if combined:
        keys.extend(part.strip() for part in combined.split(",") if part.strip())

    primary = os.getenv("GROQ_API_KEY", "").strip()
    if primary:
        keys.append(primary)

    prefix = "GROQ_API_KEY_"
    # isdecimal() (not isdigit()) so every accepted suffix is something int()
    # can parse; isdigit() also accepts characters such as superscripts.
    numbered_names = sorted(
        (
            name
            for name in os.environ
            if name.startswith(prefix) and name[len(prefix):].isdecimal()
        ),
        key=lambda name: int(name[len(prefix):]),
    )
    for env_name in numbered_names:
        value = os.getenv(env_name, "").strip()
        if value:
            keys.append(value)

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(keys))


class GroqFailoverLLM:
    """Wrap one ChatGroq client per API key and fail over between them.

    ``invoke`` starts with the last key that worked and moves to the next key
    only when the raised error looks like a rate-limit/quota error.
    """

    # Lower-case substrings that mark an error message as a rate-limit/quota error.
    _RATE_LIMIT_MARKERS = (
        "rate limit",
        "too many requests",
        "429",
        "quota",
        "resource exhausted",
    )

    def __init__(self, api_keys: list[str], model: str, temperature: float):
        """Create one ChatGroq client per key in ``api_keys``."""
        self.api_keys = api_keys
        self.model = model
        self.temperature = temperature
        self.clients = [
            ChatGroq(model=self.model, temperature=self.temperature, api_key=key)
            for key in self.api_keys
        ]
        self.active_idx = 0  # index of the client tried first by invoke()

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Return True when the error text contains a rate-limit/quota marker."""
        text = str(error).lower()
        return any(marker in text for marker in GroqFailoverLLM._RATE_LIMIT_MARKERS)

    def invoke(self, prompt):
        """Send ``prompt`` to the active client, failing over on rate-limit errors.

        Returns whatever the underlying client's ``invoke`` returns.

        Raises:
            RuntimeError: if no clients are configured.
            Exception: any non-rate-limit error is re-raised immediately; if
                every key is rate-limited, the last such error is raised.
        """
        total = len(self.clients)
        last_error = None

        for offset in range(total):
            idx = (self.active_idx + offset) % total
            client = self.clients[idx]
            try:
                result = client.invoke(prompt)
            except Exception as exc:
                last_error = exc
                # Only a rate-limit error with another key available justifies a retry.
                if total > 1 and self._is_rate_limit_error(exc):
                    LOGGER.warning(
                        "Groq key #%s hit limit/quota. Trying next configured key.",
                        idx + 1,
                    )
                    continue
                raise
            if idx != self.active_idx:
                LOGGER.warning("Groq fallback active: switched to key #%s", idx + 1)
                self.active_idx = idx  # later calls start from the key that worked
            return result

        if last_error is not None:
            raise last_error
        raise RuntimeError("No Groq clients available")


def get_llm():
    """Return the shared GroqFailoverLLM, building it on first use.

    Raises:
        ValueError: if no Groq API key is found in the environment.
    """
    global LLM_ROUTER
    if LLM_ROUTER is None:
        api_keys = _collect_groq_api_keys()
        if not api_keys:
            raise ValueError(
                "No Groq keys configured. Set GROQ_API_KEY, GROQ_API_KEYS, or GROQ_API_KEY_2+ in .env"
            )
        LLM_ROUTER = GroqFailoverLLM(
            api_keys=api_keys,
            model="llama-3.1-8b-instant",
            temperature=0,
        )
    return LLM_ROUTER


def _response_text(result) -> str:
    """Return the stripped text of an LLM result.

    Uses ``result.content`` when that attribute exists and is not None;
    otherwise falls back to ``str(result)``. Non-string content is coerced
    with ``str()`` so ``.strip()`` cannot fail.
    """
    content = getattr(result, "content", None)
    if content is None:
        content = result
    return str(content).strip()


# --------------------------
# LAYER 1: INTENT ROUTER
# --------------------------
INTENT_PROMPT = """Classify this message into EXACTLY one category.

Categories:
- CHITCHAT : greetings, thanks, small talk, "how are you", jokes
- IDENTITY : asking who/what you are, your name, your purpose
- CRISIS : suicidal thoughts, self-harm, abuse, urgent danger
- PSYCHOLOGY : anything about mental health, behavior, emotions, habits, thinking, relationships

Message: "{query}"

Reply with only ONE word: CHITCHAT, IDENTITY, CRISIS, or PSYCHOLOGY"""


def classify_intent(query: str, llm) -> str:
    """Ask the LLM for the intent of ``query`` and return one of the four labels.

    The reply is upper-cased and searched for each label in ``_INTENTS`` order.
    If no label is found, "PSYCHOLOGY" is returned.
    """
    text = _response_text(llm.invoke(INTENT_PROMPT.format(query=query))).upper()
    for intent in _INTENTS:
        if intent in text:
            return intent
    return "PSYCHOLOGY"  # safe default


# --------------------------
# LAYER 2: QUERY REWRITER
# --------------------------
def rewrite_query(query: str, history_context: str, llm) -> str:
    """
    Rephrase emotional/vague queries into retrieval-friendly ones.
    E.g. "I feel so empty" → "emotional numbness causes and psychological explanation"

    Returns the LLM's rewrite with surrounding double quotes removed, or the
    original ``query`` when the rewrite comes back empty.
    """
    prompt = f"""You are a psychology search assistant.

Conversation so far:
{history_context if history_context else "None"}

User query: "{query}"

Rewrite this query to be more specific and retrieval-friendly for a psychology book search.
Keep it under 15 words. Return ONLY the rewritten query, nothing else."""
    rewritten = _response_text(llm.invoke(prompt)).strip('"')
    return rewritten if rewritten else query


# --------------------------
# LAYER 3: CRAG RELEVANCE GRADER
# --------------------------
def grade_chunks(query: str, docs: list, llm, min_score: int = 3) -> tuple[list, bool]:
    """
    Score each retrieved chunk 1-5 for relevance and keep those at or above
    ``min_score``. One LLM call is made per chunk, using the first 400
    characters of its ``page_content``.

    Returns (relevant_docs, has_relevant), where ``has_relevant`` is True when
    at least one chunk was kept.
    """
    relevant_docs = []
    for doc in docs:
        prompt = f"""Rate this text chunk's relevance to the query on a scale of 1 to 5.

Query: "{query}"

Chunk: {doc.page_content[:400]}

Reply with ONLY a number: 1, 2, 3, 4, or 5
(1=completely irrelevant, 5=directly answers the query)"""
        text = _response_text(llm.invoke(prompt))
        try:
            score = int(text[0])
        except (ValueError, IndexError):
            score = 3  # default to neutral when the reply is empty or not a digit
        if score >= min_score:
            relevant_docs.append(doc)

    return relevant_docs, bool(relevant_docs)


# --------------------------
# CONVERSATION MEMORY HELPERS
# --------------------------
def get_history_context() -> str:
    """Format last N turns for injection into prompts."""
    if not CONVERSATION_HISTORY:
        return ""
    lines = []
    for turn in CONVERSATION_HISTORY[-MAX_HISTORY_TURNS:]:
        role = "User" if turn["role"] == "user" else "Assistant"
        lines.append(f"{role}: {turn['content']}")
    return "\n".join(lines)


def save_to_history(user_msg: str, assistant_msg: str):
    """Append one user message and the assistant reply to the session history."""
    CONVERSATION_HISTORY.append({"role": "user", "content": user_msg})
    CONVERSATION_HISTORY.append({"role": "assistant", "content": assistant_msg})


def clear_history():
    """Remove every stored turn from the session history."""
    CONVERSATION_HISTORY.clear()


# --------------------------
# INTENT HANDLERS
# --------------------------
def handle_chitchat(query: str, history_context: str, llm) -> str:
    """Generate a short small-talk reply, including prior turns when present."""
    prompt = f"""You are a warm, supportive psychology assistant named MindBot.
Have a natural, friendly conversation. Keep it brief (1-2 sentences).

{f'Conversation so far: {history_context}' if history_context else ''}

User: {query}
MindBot:"""
    return _response_text(llm.invoke(prompt))


def handle_identity(query: str, llm) -> str:
    """Generate a short self-description in reply to a who/what-are-you question."""
    prompt = f"""The user asked: "{query}"

You are MindBot, a psychology assistant trained on books covering therapy, behavior,
cognitive psychology, habits, and emotional well-being.
Answer in 2-3 sentences."""
    return _response_text(llm.invoke(prompt))


def handle_crisis(query: str) -> str:
    """Return the fixed crisis-support message; ``query`` is not used and no LLM is called."""
    return (
        "I can hear that you're going through something really difficult right now. "
        "Please know you're not alone.\n\n"
        "**iCall (India):** 9152987821\n"
        "**Vandrevala Foundation (24/7):** 1860-2662-345\n"
        "**iCall email:** icall@tiss.edu\n\n"
        "Would you like to talk about what's on your mind? I'm here to listen."
    )


# --------------------------
# CORE RAG PIPELINE (with CRAG)
# --------------------------
def run_rag_with_crag(query: str, history_context: str, llm, max_retries: int = 2) -> tuple:
    """Answer ``query`` from the book collection using retrieve → grade → retry.

    Args:
        query: the user's original question.
        history_context: formatted prior turns ("" when there are none).
        llm: object exposing ``invoke(prompt)``.
        max_retries: number of retrieve/grade attempts; values below 1 are
            treated as 1 so retrieval always runs at least once.

    Returns:
        (answer, docs): the generated answer and the documents used as context.
        When retrieval returns nothing, a fixed message and an empty list.
    """
    vectorstore = get_vectorstore()

    # Step 1: rewrite query for better retrieval
    search_query = rewrite_query(query, history_context, llm)

    # The retriever does not depend on the query, so build it once.
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 5},
    )

    attempts = max(1, max_retries)
    docs = []
    relevant_docs = []
    for attempt in range(attempts):
        # Step 2: retrieve
        docs = retriever.invoke(search_query)
        if not docs:
            return "I couldn't find relevant information in the books.", []

        # Step 3: CRAG grade
        relevant_docs, has_relevant = grade_chunks(search_query, docs, llm)
        if has_relevant:
            break

        # If nothing relevant: rewrite more aggressively and retry
        if attempt < attempts - 1:
            search_query = rewrite_query(
                f"explain in detail: {query}", history_context, llm
            )

    # Step 4: build grouped context
    if not relevant_docs:
        relevant_docs = docs[:2]  # fallback: use top 2 anyway

    grouped = defaultdict(list)
    for doc in relevant_docs:
        book = doc.metadata.get("book", "Unknown")
        grouped[book].append(doc.page_content)

    # At most two chunks per book go into the prompt.
    context_parts = []
    for book, texts in grouped.items():
        joined = "\n".join(texts[:2])
        context_parts.append(f"[Source: {book}]\n{joined}")
    context = "\n\n---\n\n".join(context_parts)

    # Step 5: generate with memory + empathy constraint
    prompt = f"""You are MindBot, a compassionate psychology assistant.

Conversation history:
{history_context if history_context else "This is the start of the conversation."}

Knowledge from books:
{context}

User's question: {query}

Instructions:
- Be warm and empathetic first, then informative
- Answer using ONLY the provided book context
- If multiple perspectives exist, present them clearly
- If the question isn't covered in the context, say so honestly
- Keep response focused and under 200 words unless depth is needed"""
    answer = _response_text(llm.invoke(prompt))
    return answer, relevant_docs


# --------------------------
# MAIN ENTRY POINT
# --------------------------
def ask_question(query: str) -> tuple:
    """Route ``query`` by intent, produce a reply, and record the exchange.

    Returns:
        (response, docs): ``docs`` is the list of source documents for
        PSYCHOLOGY queries and an empty list for every other intent.
    """
    llm = get_llm()
    # Snapshot the history before this exchange is appended to it.
    history_context = get_history_context()

    # Route intent
    intent = classify_intent(query, llm)

    docs = []
    if intent == "CRISIS":
        response = handle_crisis(query)
    elif intent == "CHITCHAT":
        response = handle_chitchat(query, history_context, llm)
    elif intent == "IDENTITY":
        response = handle_identity(query, llm)
    else:  # PSYCHOLOGY
        response, docs = run_rag_with_crag(query, history_context, llm)

    save_to_history(query, response)
    return response, docs