# mindease/src/main_back.py
# Uploaded via "Upload 3 files" by harsh122e2wr (commit fca6e17, verified)
import os
from typing import Dict, List, TypedDict, Annotated
from dotenv import load_dotenv
# LangChain and LangGraph imports
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver
from pydantic import BaseModel, Field
# --- Graph State Definition ---
class GraphState(TypedDict):
    """Shared state passed between nodes of the LangGraph workflow."""
    # Raw questionnaire answers keyed by question number ("1".."23").
    questionnaire_responses: Dict[str, int]
    # Per-domain severity scores derived from the answers (prompts describe a 0-4 scale).
    domain_scores: Dict[str, int]
    # NOTE(review): not written by any node visible here — presumably set by the caller; confirm.
    primary_concern: str
    # Conversation history; the add_messages reducer appends on update instead of replacing.
    messages: Annotated[list[AnyMessage], add_messages]
    # Verdict of the safety classifier for the latest AI response.
    is_safe: bool
    # Generation attempts this turn; drives the retry/fallback routing.
    retry_count: int
# --- RAG Retriever Helper Function ---
def create_persistent_rag_retriever(pdf_paths: List[str], db_name: str, embedding_model):
    """Create or load a persistent RAG retriever from one or more PDF documents.

    Args:
        pdf_paths: Paths to the PDF files to ingest; missing files are skipped.
        db_name: Subdirectory name under ./chroma_db for the persisted store.
        embedding_model: Embedding function handed to Chroma.

    Returns:
        A retriever configured for the top 3 matches, or None when no
        documents could be processed.
    """
    persist_directory = f"./chroma_db/{db_name}"
    if os.path.exists(persist_directory):
        print(f"--- Loading existing persistent DB: {db_name} ---")
        # FIX: use the same k=3 as the freshly-built path so retrieval behaves
        # identically on first and subsequent runs.
        return Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding_model,
        ).as_retriever(search_kwargs={'k': 3})
    print(f"--- Creating new persistent DB: {db_name} ---")
    vector_store = None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    for pdf_path in pdf_paths:
        if not os.path.exists(pdf_path):
            print(f"Warning: PDF not found at '{pdf_path}'. Skipping.")
            continue
        print(f"--- Processing PDF: {pdf_path} ---")
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        splits = text_splitter.split_documents(documents)
        if not splits:
            continue
        if vector_store is None:
            # First non-empty batch also creates the on-disk store.
            vector_store = Chroma.from_documents(
                documents=splits,
                embedding=embedding_model,
                persist_directory=persist_directory,
            )
        else:
            vector_store.add_documents(splits)
    if vector_store:
        print(f"--- DB creation complete for {db_name} ---")
        return vector_store.as_retriever(search_kwargs={'k': 3})
    print(f"--- Could not create DB for {db_name}. No documents processed. ---")
    return None
# --- Node Definitions ---
def questionnaire(state: GraphState) -> GraphState:
    """Score every symptom domain from the raw answers and seed the conversation."""
    responses = state.get("questionnaire_responses", {})
    # Each domain's score is the maximum answer across its questionnaire items.
    domain_items = {
        "Depression": ("1", "2"),
        "Anger": ("3",),
        "Mania": ("4", "5"),
        "Anxiety": ("6", "7", "8"),
        "Somatic_Symptoms": ("9", "10"),
        "Suicidal_Ideation": ("11",),
        "Psychosis": ("12", "13"),
        "Sleep_Problems": ("14",),
        "Memory": ("15",),
        "Repetitive_Thoughts_Behaviors": ("16", "17"),
        "Dissociation": ("18",),
        "Personality_Functioning": ("19", "20"),
        "Substance_Use": ("21", "22", "23"),
    }
    domain_scores = {
        domain: max(responses.get(item, 0) for item in items)
        for domain, items in domain_items.items()
    }
    # Kick off the chat turn with a synthetic user request.
    initial_question = "User has completed the initial questionnaire.Provide supportive steps and coping mechanisms."
    initial_message = HumanMessage(content=initial_question)
    return {"domain_scores": domain_scores, "retry_count": 0, "messages": [initial_message]}
def route_entry(state: GraphState) -> str:
    """Route to the questionnaire first, then to the matching RAG handler."""
    scores = state.get("domain_scores")
    if not scores:
        # No scores yet: the questionnaire must run before any routing happens.
        return "questionnaire"
    if scores.get("Depression", 0) >= 2:
        return "depression"
    if scores.get("Anxiety", 0) >= 2:
        return "anxiety"
    return "no_action"
def handle_depression_rag(state: GraphState) -> GraphState:
    """Generate a supportive, depression-focused reply for the current turn."""
    depression_score = state.get("domain_scores", {}).get("Depression", 0)
    history = state.get("messages", [])
    attempts = state.get("retry_count", 0)
    # After a flagged attempt, nudge the model to regenerate.
    if attempts > 0:
        guidance = "Your previous response was flagged. Please try again."
    else:
        guidance = "Please provide a helpful and supportive plan."
    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
- **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a depression score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])
    pipeline = chat_prompt | llm | StrOutputParser()
    answer = pipeline.invoke({"score": depression_score, "question": history, "retry_guidance": guidance})
    return {"messages": [AIMessage(content=answer)], "retry_count": attempts + 1}
def handle_anxiety_rag(state: GraphState) -> GraphState:
    """Handle the conversational RAG pipeline for anxiety.

    Reads the user's anxiety domain score and chat history, asks the LLM for a
    supportive reply, and increments the per-turn retry counter.
    """
    # BUG FIX: this previously read the "Depression" score (copy-paste error),
    # so the anxiety prompt was filled with the wrong severity.
    score = state.get("domain_scores", {}).get("Anxiety", 0)
    messages = state.get("messages", [])
    retry_count = state.get("retry_count", 0)
    retry_guidance = "Please provide a helpful and supportive plan for someone feeling anxious or worried."
    # Consistency with handle_depression_rag: after a flagged attempt, ask the
    # model to regenerate instead of repeating the same request verbatim.
    if retry_count > 0:
        retry_guidance = "Your previous response was flagged. Please try again."
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
- **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a anxiety score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])
    rag_chain = prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"score": score, "question": messages, "retry_guidance": retry_guidance})
    return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}
class SafetyCheck(BaseModel):
    """Pydantic model for the safety check response."""
    # Structured-output field the classifier LLM must fill with its verdict.
    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")
def safety_check(state: GraphState) -> GraphState:
    """Classify the most recently generated AI response for safety.

    Returns:
        Partial state with `is_safe` taken from the structured LLM verdict.
    """
    # BUG FIX: take the LAST AIMessage, not the first. The add_messages reducer
    # accumulates history, so the first AIMessage is the oldest response and the
    # retry loop would keep re-checking a stale answer instead of the one just
    # generated.
    ai_response = next(
        (m.content for m in reversed(state.get("messages", [])) if isinstance(m, AIMessage)), ''
    )
    structured_llm = llm.with_structured_output(SafetyCheck)
    prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context.
Analyze the following text for any harmful, unsafe, or non-supportive content.
Then, provide a JSON object with a single boolean field 'is_safe'.
Text to evaluate:
'{ai_response}'"""
    )
    safety_chain = prompt | structured_llm
    result = safety_chain.invoke({"ai_response": ai_response})
    return {"is_safe": result.is_safe}
def handle_fallback(state: GraphState) -> GraphState:
    """Return a canned safe reply once the retry budget is exhausted."""
    safe_text = "I am having trouble generating a specific plan right now. Please consider seeking support from a qualified professional."
    return {"messages": [AIMessage(content=safe_text)], "retry_count": 0}
def finalize_response(state: GraphState) -> GraphState:
    """Close out the turn, resetting the retry counter for the next message."""
    reset_state = {"retry_count": 0}
    return reset_state
def route_after_safety_check(state: "GraphState") -> str:
    """Pick the next node: finalize when safe, retry while budget remains, else fall back."""
    if state.get("is_safe"):
        return "finalize"
    # Two generation attempts are allowed before giving up.
    return "retry" if state.get("retry_count", 0) < 2 else "fallback"
def entry_point(state: GraphState) -> GraphState:
    """No-op node: the graph needs a concrete entry before conditional routing."""
    # Return the state untouched; routing happens in the conditional edge.
    return state
# --- LLM, graph wiring, and checkpointer setup (module-level side effects) ---
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile", max_tokens=4096)

graph = StateGraph(GraphState)
graph.add_node("entry_point", entry_point)
graph.add_node("questionnaire", questionnaire)
graph.add_node("handle_depression_rag", handle_depression_rag)
graph.add_node("handle_anxiety_rag", handle_anxiety_rag)
graph.add_node("safety_check_depression", safety_check)
graph.add_node("safety_check_anxiety", safety_check)
graph.add_node("handle_fallback", handle_fallback)
graph.add_node("finalize_response", finalize_response)

# Entry routes to the questionnaire on the first pass, then loops back to
# route into the concern-specific RAG handler.
graph.add_edge(START, "entry_point")
graph.add_conditional_edges(
    "entry_point", route_entry,
    {"depression": "handle_depression_rag", "anxiety": "handle_anxiety_rag", "no_action": END, "questionnaire": "questionnaire"}
)
graph.add_edge("questionnaire", "entry_point")

# Each handler's output is safety-checked; unsafe output retries (bounded) or falls back.
graph.add_edge("handle_depression_rag", "safety_check_depression")
graph.add_conditional_edges(
    "safety_check_depression", route_after_safety_check,
    {"retry": "handle_depression_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
)
graph.add_edge("handle_anxiety_rag", "safety_check_anxiety")
graph.add_conditional_edges(
    "safety_check_anxiety", route_after_safety_check,
    {"retry": "handle_anxiety_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
)
graph.add_edge("handle_fallback", END)
graph.add_edge("finalize_response", END)

DB_PATH = "/app/data/chatbot.sqlite"
# FIX: sqlite3.connect raises OperationalError if the parent directory is
# missing — create it up front so a fresh container/deployment can boot.
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
app = graph.compile(checkpointer=checkpointer)