# MindBot-v0 / query.py
# Chirag20's picture
# added knowledge
# edabb92
import os
import logging
from qdrant_client import QdrantClient
from langchain_qdrant import QdrantVectorStore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from dotenv import load_dotenv
from collections import defaultdict
# Runs at import time, before any of the os.getenv lookups below
# (presumably loads variables from a local .env file — see python-dotenv docs).
load_dotenv()
# --------------------------
# GLOBALS (cached)
# --------------------------
# Embedding model handed to the vector store; instantiated once at import time.
EMBEDDINGS = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Filled in lazily by get_vectorstore() on its first call.
VECTORSTORE = None
# Filled in lazily by get_llm() on its first call.
LLM_ROUTER = None
LOGGER = logging.getLogger(__name__)
# Simple in-memory session history: list of {"role": "user"/"assistant", "content": "..."}
# Shared by every caller in this process; save_to_history() only appends to it.
CONVERSATION_HISTORY = []
# Number of most recent history entries (messages, not user/assistant pairs)
# that get_history_context() injects into prompts: 6 = 3 user + 3 assistant.
MAX_HISTORY_TURNS = 6
# --------------------------
# QDRANT / LLM SETUP
# --------------------------
def get_client():
    """Build a new QdrantClient from the QDRANT_URL / QDRANT_API_KEY env vars.

    A fresh client is created on every call; unset variables are passed
    through as None.
    """
    qdrant_url = os.getenv("QDRANT_URL")
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    return QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
def get_vectorstore():
    """Return the module-wide QdrantVectorStore, creating it on first use.

    The store targets the "psychology_books" collection and uses the shared
    EMBEDDINGS model. Later calls return the cached instance.
    """
    global VECTORSTORE
    if VECTORSTORE is not None:
        return VECTORSTORE

    client = get_client()
    VECTORSTORE = QdrantVectorStore(
        client=client,
        collection_name="psychology_books",
        embedding=EMBEDDINGS,
    )
    return VECTORSTORE
def _collect_groq_api_keys() -> list[str]:
    """
    Collect Groq keys from supported env vars in priority order:
    1) GROQ_API_KEYS (comma-separated)
    2) GROQ_API_KEY
    3) GROQ_API_KEY_<n> for every numeric suffix <n>, in ascending numeric order

    Every value is stripped of surrounding whitespace, blank values are
    skipped, and duplicates are dropped keeping the first occurrence.

    Returns:
        The ordered, de-duplicated list of keys (empty if none are configured).
    """
    prefix = "GROQ_API_KEY_"

    # "".split(",") yields [""], which the blank filter below discards.
    candidates = os.getenv("GROQ_API_KEYS", "").split(",")
    candidates.append(os.getenv("GROQ_API_KEY", ""))

    # str.isdecimal() only accepts characters that int() can parse.
    # str.isdigit() also accepts e.g. superscript digits, which made the
    # int() sort key below raise ValueError on a stray variable name.
    numbered_names = [
        env_name
        for env_name in os.environ
        if env_name.startswith(prefix) and env_name[len(prefix):].isdecimal()
    ]
    numbered_names.sort(key=lambda env_name: int(env_name[len(prefix):]))
    candidates.extend(os.getenv(env_name, "") for env_name in numbered_names)

    stripped = (value.strip() for value in candidates)
    # dict.fromkeys preserves insertion order, so the first occurrence wins.
    return list(dict.fromkeys(value for value in stripped if value))
class GroqFailoverLLM:
    """Wrap one ChatGroq client per API key and fail over between them.

    `invoke` starts at the client that last succeeded and, when more than one
    key is configured, moves on to the next key whenever an error looks like a
    rate-limit/quota failure.
    """

    def __init__(self, api_keys: list[str], model: str, temperature: float):
        """Create one ChatGroq client per key, all sharing model and temperature."""
        self.api_keys = api_keys
        self.model = model
        self.temperature = temperature
        clients = []
        for key in self.api_keys:
            clients.append(
                ChatGroq(model=self.model, temperature=self.temperature, api_key=key)
            )
        self.clients = clients
        # Index of the client to try first on the next invoke().
        self.active_idx = 0

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Return True when the error's message contains a rate-limit/quota marker."""
        message = str(error).lower()
        for marker in (
            "rate limit",
            "too many requests",
            "429",
            "quota",
            "resource exhausted",
        ):
            if marker in message:
                return True
        return False

    def invoke(self, prompt):
        """Send `prompt` to the active client, rotating keys on rate-limit errors.

        Returns whatever the successful client's `invoke` returns. Errors that
        are not rate-limit errors (or any error when only one key exists) are
        re-raised immediately. If every key is rate-limited, the last error is
        raised. With no clients at all a RuntimeError is raised.
        """
        client_count = len(self.clients)
        first_idx = self.active_idx
        last_error = None

        for step in range(client_count):
            position = (first_idx + step) % client_count
            candidate = self.clients[position]
            try:
                result = candidate.invoke(prompt)
            except Exception as exc:
                last_error = exc
                can_fail_over = client_count > 1 and self._is_rate_limit_error(exc)
                if not can_fail_over:
                    raise
                LOGGER.warning(
                    "Groq key #%s hit limit/quota. Trying next configured key.",
                    position + 1,
                )
            else:
                if position != first_idx:
                    LOGGER.warning(
                        "Groq fallback active: switched to key #%s", position + 1
                    )
                    # Remember the working key so later calls start from it.
                    self.active_idx = position
                return result

        if last_error is None:
            raise RuntimeError("No Groq clients available")
        raise last_error
def get_llm():
    """Return the shared GroqFailoverLLM, building it on the first call.

    Raises:
        ValueError: if no Groq API key is configured in the environment.
    """
    global LLM_ROUTER
    if LLM_ROUTER is not None:
        return LLM_ROUTER

    configured_keys = _collect_groq_api_keys()
    if not configured_keys:
        raise ValueError(
            "No Groq keys configured. Set GROQ_API_KEY, GROQ_API_KEYS, or GROQ_API_KEY_2+ in .env"
        )

    LLM_ROUTER = GroqFailoverLLM(
        api_keys=configured_keys,
        model="llama-3.1-8b-instant",
        temperature=0,
    )
    return LLM_ROUTER
# --------------------------
# LAYER 1: INTENT ROUTER
# --------------------------
# Template for the routing call; `{query}` is filled in by classify_intent().
INTENT_PROMPT = """Classify this message into EXACTLY one category.
Categories:
- CHITCHAT : greetings, thanks, small talk, "how are you", jokes
- IDENTITY : asking who/what you are, your name, your purpose
- CRISIS : suicidal thoughts, self-harm, abuse, urgent danger
- PSYCHOLOGY : anything about mental health, behavior, emotions, habits, thinking, relationships
Message: "{query}"
Reply with only ONE word: CHITCHAT, IDENTITY, CRISIS, or PSYCHOLOGY"""


def classify_intent(query: str, llm) -> str:
    """Ask the LLM to label `query` and return the matching intent name.

    The reply is upper-cased and searched for each label as a substring.
    CRISIS is checked first, so it wins when several labels appear. Anything
    unrecognised is treated as PSYCHOLOGY.
    """
    reply = llm.invoke(INTENT_PROMPT.format(query=query))
    reply_text = getattr(reply, "content", str(reply)).strip().upper()
    labels_by_priority = ("CRISIS", "CHITCHAT", "IDENTITY", "PSYCHOLOGY")
    matches = (label for label in labels_by_priority if label in reply_text)
    return next(matches, "PSYCHOLOGY")  # safe default
# --------------------------
# LAYER 2: QUERY REWRITER
# --------------------------
def rewrite_query(query: str, history_context: str, llm) -> str:
    """
    Ask the LLM to turn an emotional or vague query into a search-friendly one.
    E.g. "I feel so empty" → "emotional numbness causes and psychological explanation"

    The reply has surrounding whitespace and double quotes removed. If nothing
    is left, the original `query` is returned unchanged.
    """
    history_block = history_context or "None"
    prompt = f"""You are a psychology search assistant.
Conversation so far:
{history_block}
User query: "{query}"
Rewrite this query to be more specific and retrieval-friendly for a psychology book search.
Keep it under 15 words. Return ONLY the rewritten query, nothing else."""
    reply = llm.invoke(prompt)
    reply_text = getattr(reply, "content", str(reply))
    cleaned = reply_text.strip().strip('"')
    return cleaned or query
# --------------------------
# LAYER 3: CRAG RELEVANCE GRADER
# --------------------------
def grade_chunks(query: str, docs: list, llm, min_score: int = 3) -> tuple[list, bool]:
    """
    Score each retrieved chunk 1-5 for relevance and keep the relevant ones.

    One LLM call is made per chunk, using the first 400 characters of its
    `page_content`. The score is read from the first character of the reply.
    An empty or non-numeric reply counts as a neutral 3.

    Args:
        query: The search query the chunks are graded against.
        docs: Retrieved documents; each must expose `page_content`.
        llm: Object with an `invoke(prompt)` method.
        min_score: Lowest score that keeps a chunk (default 3).

    Returns:
        (relevant_docs, has_relevant): the kept chunks in their original
        order, and True when at least one chunk was kept.
    """
    relevant_docs = []
    for doc in docs:
        prompt = f"""Rate this text chunk's relevance to the query on a scale of 1 to 5.
Query: "{query}"
Chunk:
{doc.page_content[:400]}
Reply with ONLY a number: 1, 2, 3, 4, or 5
(1=completely irrelevant, 5=directly answers the query)"""
        result = llm.invoke(prompt)
        text = getattr(result, "content", str(result)).strip()
        try:
            score = int(text[0])
        except (ValueError, IndexError):
            score = 3  # default to neutral when the reply cannot be parsed
        if score >= min_score:
            relevant_docs.append(doc)
    return relevant_docs, bool(relevant_docs)
# --------------------------
# CONVERSATION MEMORY HELPERS
# --------------------------
def get_history_context() -> str:
    """Render the most recent history entries as "User: ..." / "Assistant: ..." lines.

    Only the last MAX_HISTORY_TURNS entries of CONVERSATION_HISTORY are used.
    Returns an empty string when there is no history.
    """
    recent_entries = CONVERSATION_HISTORY[-MAX_HISTORY_TURNS:]
    rendered = []
    for entry in recent_entries:
        # Any role other than "user" is shown as the assistant.
        speaker = "Assistant" if entry["role"] != "user" else "User"
        rendered.append(f"{speaker}: {entry['content']}")
    return "\n".join(rendered)
def save_to_history(user_msg: str, assistant_msg: str):
    """Append one user entry followed by one assistant entry to CONVERSATION_HISTORY."""
    exchange = (("user", user_msg), ("assistant", assistant_msg))
    CONVERSATION_HISTORY.extend(
        {"role": role, "content": content} for role, content in exchange
    )
def clear_history():
    """Empty CONVERSATION_HISTORY in place (the list object itself is kept)."""
    del CONVERSATION_HISTORY[:]
# --------------------------
# INTENT HANDLERS
# --------------------------
def handle_chitchat(query: str, history_context: str, llm) -> str:
    """Generate a brief small-talk reply, including prior history when present.

    Returns the LLM reply with surrounding whitespace removed.
    """
    # When there is no history, this line of the prompt is left blank.
    history_line = f"Conversation so far: {history_context}" if history_context else ""
    prompt = f"""You are a warm, supportive psychology assistant named MindBot.
Have a natural, friendly conversation. Keep it brief (1-2 sentences).
{history_line}
User: {query}
MindBot:"""
    reply = llm.invoke(prompt)
    reply_text = getattr(reply, "content", str(reply))
    return reply_text.strip()
def handle_identity(query: str, llm) -> str:
    """Answer a "who/what are you" question in MindBot's persona.

    Returns the LLM reply with surrounding whitespace removed.
    """
    prompt = (
        f'The user asked: "{query}"\n'
        "You are MindBot, a psychology assistant trained on books covering therapy, behavior,\n"
        "cognitive psychology, habits, and emotional well-being. Answer in 2-3 sentences."
    )
    reply = llm.invoke(prompt)
    reply_text = getattr(reply, "content", str(reply))
    return reply_text.strip()
def handle_crisis(query: str) -> str:
    """Return the fixed crisis-support message with helpline contacts.

    The text is hard-coded: `query` is not used and no LLM call is made.
    """
    opening = (
        "I can hear that you're going through something really difficult right now. "
        "Please know you're not alone."
    )
    helplines = "\n".join(
        [
            "**iCall (India):** 9152987821",
            "**Vandrevala Foundation (24/7):** 1860-2662-345",
            "**iCall email:** icall@tiss.edu",
        ]
    )
    closing = "Would you like to talk about what's on your mind? I'm here to listen."
    # The three sections are separated by blank lines.
    return "\n\n".join((opening, helplines, closing))
# --------------------------
# CORE RAG PIPELINE (with CRAG)
# --------------------------
def run_rag_with_crag(query: str, history_context: str, llm, max_retries: int = 2) -> tuple:
    """Answer `query` from the book collection with a retrieve-grade-retry loop.

    Steps: rewrite the query for retrieval, retrieve chunks, grade them with
    `grade_chunks`, and retry with a re-rewritten query if none pass. Then the
    surviving chunks are grouped by their "book" metadata and the LLM generates
    the final answer from that context.

    Args:
        query: The user's original question.
        history_context: Formatted conversation history ("" when there is none).
        llm: Object with an `invoke(prompt)` method.
        max_retries: Maximum number of retrieve/grade attempts. Values below 1
            are treated as 1 so that retrieval always runs at least once.

    Returns:
        (answer, docs): the generated answer and the documents used as
        context. If retrieval returns nothing, a fixed "couldn't find"
        message and an empty list are returned.
    """
    vectorstore = get_vectorstore()
    # Step 1: rewrite query for better retrieval
    search_query = rewrite_query(query, history_context, llm)

    # The retriever does not depend on the attempt, so build it once.
    retriever = vectorstore.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 5},
    )

    # Always make at least one attempt. With zero iterations `docs` and
    # `relevant_docs` would be unbound below.
    attempts = max(1, max_retries)
    docs = []
    relevant_docs = []
    for attempt in range(attempts):
        # Step 2: retrieve
        docs = retriever.invoke(search_query)
        if not docs:
            return "I couldn't find relevant information in the books.", []
        # Step 3: CRAG grade
        relevant_docs, has_relevant = grade_chunks(search_query, docs, llm)
        if has_relevant:
            break
        # If nothing relevant: rewrite more aggressively and retry
        if attempt < attempts - 1:
            search_query = rewrite_query(
                f"explain in detail: {query}",
                history_context,
                llm,
            )

    # Step 4: build grouped context
    if not relevant_docs:
        relevant_docs = docs[:2]  # fallback: use top 2 anyway
    grouped = defaultdict(list)
    for doc in relevant_docs:
        book = doc.metadata.get("book", "Unknown")
        grouped[book].append(doc.page_content)
    # At most two chunks per book go into the prompt.
    context_parts = []
    for book, texts in grouped.items():
        joined = "\n".join(texts[:2])
        context_parts.append(f"[Source: {book}]\n{joined}")
    context = "\n\n---\n\n".join(context_parts)

    # Step 5: generate with memory + empathy constraint
    prompt = f"""You are MindBot, a compassionate psychology assistant.
Conversation history:
{history_context if history_context else "This is the start of the conversation."}
Knowledge from books:
{context}
User's question: {query}
Instructions:
- Be warm and empathetic first, then informative
- Answer using ONLY the provided book context
- If multiple perspectives exist, present them clearly
- If the question isn't covered in the context, say so honestly
- Keep response focused and under 200 words unless depth is needed"""
    result = llm.invoke(prompt)
    answer = getattr(result, "content", str(result)).strip()
    return answer, relevant_docs
# --------------------------
# MAIN ENTRY POINT
# --------------------------
def ask_question(query: str) -> tuple:
    """Route `query` by intent, produce a reply, and record the exchange.

    Returns:
        (response, docs): the reply text and the source documents used.
        `docs` is an empty list for every intent except PSYCHOLOGY.
    """
    llm = get_llm()
    # Snapshot the history before this exchange is appended to it.
    history_context = get_history_context()

    # Route intent
    intent = classify_intent(query, llm)
    docs = []
    if intent == "CRISIS":
        response = handle_crisis(query)
    elif intent == "CHITCHAT":
        response = handle_chitchat(query, history_context, llm)
    elif intent == "IDENTITY":
        response = handle_identity(query, llm)
    else:  # PSYCHOLOGY
        response, docs = run_rag_with_crag(query, history_context, llm)

    save_to_history(query, response)
    return response, docs