Spaces:
Running
Running
| import os | |
| import logging | |
| from qdrant_client import QdrantClient | |
| from langchain_qdrant import QdrantVectorStore | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_groq import ChatGroq | |
| from dotenv import load_dotenv | |
| from collections import defaultdict | |
load_dotenv()  # populate os.environ from a local .env before the getters below read QDRANT_* / GROQ_* values

# --------------------------
# GLOBALS (cached)
# --------------------------
# Embedding model, built once at import time and handed to the vector store.
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Lazily-created singletons, filled in by get_vectorstore() and get_llm().
VECTORSTORE = None
LLM_ROUTER = None
LOGGER = logging.getLogger(__name__)
# In-memory chat history: a list of {"role": "user" | "assistant", "content": str}.
# NOTE(review): module-level, so every caller in this process shares the same history.
CONVERSATION_HISTORY: list[dict[str, str]] = []
# How many of the most recent messages (not exchanges) are fed back into prompts: 3 user + 3 assistant.
MAX_HISTORY_TURNS: int = 6
| # -------------------------- | |
| # QDRANT / LLM SETUP | |
| # -------------------------- | |
def get_client():
    """Build a Qdrant client from the QDRANT_URL and QDRANT_API_KEY environment variables.

    A new client is created on every call.
    """
    url = os.getenv("QDRANT_URL")
    api_key = os.getenv("QDRANT_API_KEY")
    return QdrantClient(url=url, api_key=api_key)
def get_vectorstore():
    """Return the shared vector store over the "psychology_books" Qdrant collection.

    The store is created on first use and memoised in the module-level VECTORSTORE.
    """
    global VECTORSTORE
    if VECTORSTORE is not None:
        return VECTORSTORE
    VECTORSTORE = QdrantVectorStore(
        client=get_client(),
        collection_name="psychology_books",
        embedding=EMBEDDINGS,
    )
    return VECTORSTORE
| def _collect_groq_api_keys() -> list[str]: | |
| """ | |
| Collect Groq keys from supported env vars in priority order: | |
| 1) GROQ_API_KEYS (comma-separated) | |
| 2) GROQ_API_KEY | |
| 3) GROQ_API_KEY_2, GROQ_API_KEY_3, ... | |
| """ | |
| keys = [] | |
| combined = os.getenv("GROQ_API_KEYS", "").strip() | |
| if combined: | |
| keys.extend([key.strip() for key in combined.split(",") if key.strip()]) | |
| primary = os.getenv("GROQ_API_KEY", "").strip() | |
| if primary: | |
| keys.append(primary) | |
| numbered_names = [] | |
| prefix = "GROQ_API_KEY_" | |
| for env_name in os.environ.keys(): | |
| if not env_name.startswith(prefix): | |
| continue | |
| suffix = env_name[len(prefix):] | |
| if suffix.isdigit(): | |
| numbered_names.append(env_name) | |
| numbered_names.sort(key=lambda name: int(name[len(prefix):])) | |
| for env_name in numbered_names: | |
| value = os.getenv(env_name, "").strip() | |
| if value: | |
| keys.append(value) | |
| deduped = [] | |
| seen = set() | |
| for key in keys: | |
| if key not in seen: | |
| seen.add(key) | |
| deduped.append(key) | |
| return deduped | |
class GroqFailoverLLM:
    """Failover wrapper holding one ChatGroq client per API key.

    ``invoke`` starts with the most recently successful key. When a call fails
    with what looks like a rate-limit / quota error, and more than one key is
    configured, it retries the same prompt on the next key in round-robin order.
    """

    # Lower-case substrings that mark an error message as a rate-limit / quota failure.
    _RATE_LIMIT_MARKERS = (
        "rate limit",
        "too many requests",
        "429",
        "quota",
        "resource exhausted",
    )

    def __init__(self, api_keys: list[str], model: str, temperature: float):
        """Create one ChatGroq client per key, all sharing ``model`` and ``temperature``."""
        self.api_keys = api_keys
        self.model = model
        self.temperature = temperature
        self.clients = [
            ChatGroq(model=self.model, temperature=self.temperature, api_key=key)
            for key in self.api_keys
        ]
        # Index of the client tried first on the next invoke().
        self.active_idx = 0

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Heuristically decide from its message whether ``error`` is a rate-limit/quota failure.

        Must be a staticmethod: it takes no ``self`` but is called as
        ``self._is_rate_limit_error(exc)``.
        """
        text = str(error).lower()
        return any(marker in text for marker in GroqFailoverLLM._RATE_LIMIT_MARKERS)

    def invoke(self, prompt):
        """Send ``prompt`` to Groq, failing over across keys on rate-limit errors.

        Returns whatever the underlying client's ``invoke`` returns.

        Raises:
            The client's exception immediately, unless it is a rate-limit error
            and other keys remain to try; if every key is rate-limited, the last
            such error is raised.
            RuntimeError: if no clients are configured.
        """
        total = len(self.clients)
        last_error = None
        for offset in range(total):
            idx = (self.active_idx + offset) % total
            client = self.clients[idx]
            try:
                result = client.invoke(prompt)
            except Exception as exc:
                last_error = exc
                # With a single key there is nothing to fail over to, so surface the error.
                if total > 1 and self._is_rate_limit_error(exc):
                    LOGGER.warning(
                        "Groq key #%s hit limit/quota. Trying next configured key.",
                        idx + 1,
                    )
                    continue
                raise
            if idx != self.active_idx:
                LOGGER.warning("Groq fallback active: switched to key #%s", idx + 1)
            # Remember the working key so later calls start there.
            self.active_idx = idx
            return result
        if last_error is not None:
            raise last_error
        raise RuntimeError("No Groq clients available")
def get_llm():
    """Return the process-wide GroqFailoverLLM, building it on first use.

    Raises:
        ValueError: if no Groq API key is configured in the environment.
    """
    global LLM_ROUTER
    if LLM_ROUTER is not None:
        return LLM_ROUTER

    keys = _collect_groq_api_keys()
    if not keys:
        raise ValueError(
            "No Groq keys configured. Set GROQ_API_KEY, GROQ_API_KEYS, or GROQ_API_KEY_2+ in .env"
        )
    LLM_ROUTER = GroqFailoverLLM(api_keys=keys, model="llama-3.1-8b-instant", temperature=0)
    return LLM_ROUTER
| # -------------------------- | |
| # LAYER 1: INTENT ROUTER | |
| # -------------------------- | |
INTENT_PROMPT = """Classify this message into EXACTLY one category.
Categories:
- CHITCHAT : greetings, thanks, small talk, "how are you", jokes
- IDENTITY : asking who/what you are, your name, your purpose
- CRISIS : suicidal thoughts, self-harm, abuse, urgent danger
- PSYCHOLOGY : anything about mental health, behavior, emotions, habits, thinking, relationships
Message: "{query}"
Reply with only ONE word: CHITCHAT, IDENTITY, CRISIS, or PSYCHOLOGY"""


def classify_intent(query: str, llm) -> str:
    """Ask the LLM to label ``query`` as CRISIS, CHITCHAT, IDENTITY or PSYCHOLOGY.

    The reply is upper-cased and searched for each label as a substring, in the
    order CRISIS, CHITCHAT, IDENTITY, PSYCHOLOGY, so CRISIS wins whenever it
    appears. An unrecognised reply falls back to "PSYCHOLOGY".
    """
    reply = llm.invoke(INTENT_PROMPT.format(query=query))
    label = getattr(reply, "content", str(reply)).strip().upper()
    return next(
        (intent for intent in ("CRISIS", "CHITCHAT", "IDENTITY", "PSYCHOLOGY") if intent in label),
        "PSYCHOLOGY",  # safe default
    )
| # -------------------------- | |
| # LAYER 2: QUERY REWRITER | |
| # -------------------------- | |
def rewrite_query(query: str, history_context: str, llm) -> str:
    """
    Turn an emotional or vague user message into a search query suited to the book index.

    E.g. "I feel so empty" → "emotional numbness causes and psychological explanation"

    Surrounding whitespace and double quotes are stripped from the model's reply;
    if nothing is left, the original ``query`` is returned unchanged.
    """
    conversation = history_context if history_context else "None"
    prompt = f"""You are a psychology search assistant.
Conversation so far:
{conversation}
User query: "{query}"
Rewrite this query to be more specific and retrieval-friendly for a psychology book search.
Keep it under 15 words. Return ONLY the rewritten query, nothing else."""
    reply = llm.invoke(prompt)
    rewritten = getattr(reply, "content", str(reply)).strip().strip('"')
    return rewritten or query
| # -------------------------- | |
| # LAYER 3: CRAG RELEVANCE GRADER | |
| # -------------------------- | |
def grade_chunks(query: str, docs: list, llm, min_score: int = 3) -> tuple[list, bool]:
    """
    Score each retrieved chunk 1-5 for relevance and drop those below ``min_score``.

    One LLM call is made per chunk, showing the grader only the first 400
    characters of ``page_content``. The first digit anywhere in the reply is
    taken as the score; a reply with no digit counts as a neutral 3.

    Args:
        query: the search query the chunks were retrieved for.
        docs: retrieved documents exposing ``page_content``.
        llm: object whose ``invoke(prompt)`` returns text or an object with ``.content``.
        min_score: lowest score a chunk needs to be kept (default 3).

    Returns:
        (relevant_docs, has_relevant): the kept docs in their original order, and
        whether at least one doc was kept.
    """
    import re

    relevant_docs = []
    for doc in docs:
        prompt = f"""Rate this text chunk's relevance to the query on a scale of 1 to 5.
Query: "{query}"
Chunk:
{doc.page_content[:400]}
Reply with ONLY a number: 1, 2, 3, 4, or 5
(1=completely irrelevant, 5=directly answers the query)"""
        result = llm.invoke(prompt)
        text = getattr(result, "content", str(result)).strip()
        # Models often wrap the number ("Rating: 1", "**4**"), so search for it
        # instead of reading only the first character.
        match = re.search(r"[0-9]", text)
        score = int(match.group()) if match else 3  # default to neutral
        if score >= min_score:
            relevant_docs.append(doc)
    return relevant_docs, bool(relevant_docs)
| # -------------------------- | |
| # CONVERSATION MEMORY HELPERS | |
| # -------------------------- | |
def get_history_context() -> str:
    """Render the last MAX_HISTORY_TURNS stored messages as "User: ..." / "Assistant: ..." lines.

    Any role other than "user" is labelled "Assistant". Returns "" when nothing is stored.
    """
    recent = CONVERSATION_HISTORY[-MAX_HISTORY_TURNS:]
    return "\n".join(
        "{}: {}".format("User" if message["role"] == "user" else "Assistant", message["content"])
        for message in recent
    )
def save_to_history(user_msg: str, assistant_msg: str):
    """Append one user message and the assistant's reply to CONVERSATION_HISTORY.

    NOTE(review): nothing trims the list, so it grows for the life of the process
    even though only the tail is read back when building prompts.
    """
    CONVERSATION_HISTORY.extend(
        (
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg},
        )
    )
def clear_history():
    """Forget every stored message by emptying CONVERSATION_HISTORY in place."""
    del CONVERSATION_HISTORY[:]
| # -------------------------- | |
| # INTENT HANDLERS | |
| # -------------------------- | |
def handle_chitchat(query: str, history_context: str, llm) -> str:
    """Get a brief, friendly MindBot reply to small talk, including history when available."""
    history_line = "Conversation so far: " + history_context if history_context else ""
    prompt = f"""You are a warm, supportive psychology assistant named MindBot.
Have a natural, friendly conversation. Keep it brief (1-2 sentences).
{history_line}
User: {query}
MindBot:"""
    reply = llm.invoke(prompt)
    return getattr(reply, "content", str(reply)).strip()
def handle_identity(query: str, llm) -> str:
    """Have the LLM introduce MindBot in reply to a "who/what are you" question."""
    reply = llm.invoke(
        f"""The user asked: "{query}"
You are MindBot, a psychology assistant trained on books covering therapy, behavior,
cognitive psychology, habits, and emotional well-being. Answer in 2-3 sentences."""
    )
    return getattr(reply, "content", str(reply)).strip()
def handle_crisis(query: str) -> str:
    """Return a fixed supportive message listing crisis helplines in India.

    ``query`` is accepted so all handlers share a similar signature; it is not used.
    NOTE(review): the helpline numbers are hard-coded; confirm they are still current.
    """
    opening = (
        "I can hear that you're going through something really difficult right now. "
        "Please know you're not alone."
    )
    helplines = "\n".join(
        (
            "**iCall (India):** 9152987821",
            "**Vandrevala Foundation (24/7):** 1860-2662-345",
            "**iCall email:** icall@tiss.edu",
        )
    )
    closing = "Would you like to talk about what's on your mind? I'm here to listen."
    return "\n\n".join((opening, helplines, closing))
| # -------------------------- | |
| # CORE RAG PIPELINE (with CRAG) | |
| # -------------------------- | |
def _format_grouped_context(docs: list) -> str:
    """Group chunk texts by their "book" metadata and render at most two chunks per book."""
    grouped = defaultdict(list)
    for doc in docs:
        grouped[doc.metadata.get("book", "Unknown")].append(doc.page_content)
    sections = []
    for book, texts in grouped.items():
        joined = "\n".join(texts[:2])
        sections.append(f"[Source: {book}]\n{joined}")
    return "\n\n---\n\n".join(sections)


def run_rag_with_crag(query: str, history_context: str, llm, max_retries: int = 2) -> tuple:
    """Answer ``query`` from the book index using retrieval plus LLM relevance grading (CRAG).

    Steps:
    1. Rewrite the query.
    2. Retrieve 5 chunks with MMR and grade them.
    3. If none pass, rewrite more aggressively and retrieve again. ``max_retries``
       is the total number of attempts; values below 1 are treated as 1.
    4. Group the surviving chunks (or the top 2 retrieved, if none survive) by book.
    5. Pass the grouped chunks to the LLM together with the conversation history.

    Returns:
        (answer, docs): the model's answer text and the docs used as context.
        If a retrieval returns nothing, a fixed "couldn't find" message and an
        empty list are returned instead.
    """
    vectorstore = get_vectorstore()
    # Step 1: rewrite query for better retrieval
    search_query = rewrite_query(query, history_context, llm)
    # The retriever's settings never change between attempts, so build it once.
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 5},
    )

    # At least one attempt, so `docs` and `relevant_docs` are always bound after the loop.
    attempts = max(1, max_retries)
    docs, relevant_docs = [], []
    for attempt in range(attempts):
        # Step 2: retrieve
        docs = retriever.invoke(search_query)
        if not docs:
            return "I couldn't find relevant information in the books.", []
        # Step 3: CRAG grade
        relevant_docs, has_relevant = grade_chunks(search_query, docs, llm)
        if has_relevant:
            break
        # If nothing relevant: rewrite more aggressively and retry
        if attempt < attempts - 1:
            search_query = rewrite_query(
                f"explain in detail: {query}",
                history_context,
                llm,
            )

    # Step 4: build grouped context (fall back to the top 2 hits if grading kept nothing)
    if not relevant_docs:
        relevant_docs = docs[:2]
    context = _format_grouped_context(relevant_docs)

    # Step 5: generate with memory + empathy constraint
    prompt = f"""You are MindBot, a compassionate psychology assistant.
Conversation history:
{history_context if history_context else "This is the start of the conversation."}
Knowledge from books:
{context}
User's question: {query}
Instructions:
- Be warm and empathetic first, then informative
- Answer using ONLY the provided book context
- If multiple perspectives exist, present them clearly
- If the question isn't covered in the context, say so honestly
- Keep response focused and under 200 words unless depth is needed"""
    result = llm.invoke(prompt)
    answer = getattr(result, "content", str(result)).strip()
    return answer, relevant_docs
| # -------------------------- | |
| # MAIN ENTRY POINT | |
| # -------------------------- | |
def ask_question(query: str) -> tuple:
    """Route ``query`` by intent, produce a reply, record the exchange, and return it.

    Returns:
        (response, docs): ``docs`` holds the book chunks used for PSYCHOLOGY
        answers and is an empty list for CRISIS, CHITCHAT and IDENTITY.
    """
    llm = get_llm()
    history_context = get_history_context()
    intent = classify_intent(query, llm)

    # Non-RAG intents never return source documents.
    handlers = {
        "CRISIS": lambda: (handle_crisis(query), []),
        "CHITCHAT": lambda: (handle_chitchat(query, history_context, llm), []),
        "IDENTITY": lambda: (handle_identity(query, llm), []),
    }
    # Anything else (PSYCHOLOGY) goes through the retrieval pipeline.
    route = handlers.get(intent, lambda: run_rag_with_crag(query, history_context, llm))
    response, docs = route()

    save_to_history(query, response)
    return response, docs