import os
import sqlite3
from typing import Dict, List, TypedDict, Annotated

from dotenv import load_dotenv

# LangChain and LangGraph imports
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langgraph.checkpoint.sqlite import SqliteSaver
from pydantic import BaseModel, Field


# --- Graph State Definition ---
class GraphState(TypedDict):
    """Shared state threaded through every node of the graph."""

    # Raw questionnaire answers keyed by item number ("1".."23"), each 0-4.
    questionnaire_responses: Dict[str, int]
    # Aggregated per-domain severity scores computed by `questionnaire`.
    domain_scores: Dict[str, int]
    primary_concern: str
    # `add_messages` makes node returns APPEND to history rather than replace it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Outcome of the most recent safety classification.
    is_safe: bool
    # Number of generation attempts made this turn (drives the retry loop).
    retry_count: int


# --- RAG Retriever Helper Function ---
def create_persistent_rag_retriever(pdf_paths: List[str], db_name: str, embedding_model):
    """Create or load a persistent Chroma retriever from one or more PDFs.

    Args:
        pdf_paths: Paths to PDF documents to index. Missing files are
            skipped with a warning rather than raising.
        db_name: Sub-directory name under ./chroma_db for the persisted store.
        embedding_model: Embedding model passed to Chroma.

    Returns:
        A retriever over the vector store, or None if no documents could
        be processed when building a fresh store.
    """
    persist_directory = f"./chroma_db/{db_name}"

    # Reuse an existing on-disk store when present instead of re-embedding.
    if os.path.exists(persist_directory):
        print(f"--- Loading existing persistent DB: {db_name} ---")
        return Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding_model,
        ).as_retriever(search_kwargs={'k': 3})  # k=3 to match the creation branch

    print(f"--- Creating new persistent DB: {db_name} ---")
    vector_store = None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for pdf_path in pdf_paths:
        if not os.path.exists(pdf_path):
            print(f"Warning: PDF not found at '{pdf_path}'. Skipping.")
            continue
        print(f"--- Processing PDF: {pdf_path} ---")
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        splits = text_splitter.split_documents(documents)
        if not splits:
            continue
        # First batch creates the store; later batches are appended to it.
        if vector_store is None:
            vector_store = Chroma.from_documents(
                documents=splits,
                embedding=embedding_model,
                persist_directory=persist_directory,
            )
        else:
            vector_store.add_documents(splits)

    if vector_store:
        print(f"--- DB creation complete for {db_name} ---")
        return vector_store.as_retriever(search_kwargs={'k': 3})
    print(f"--- Could not create DB for {db_name}. No documents processed. ---")
    return None


# --- Node Definitions ---
def questionnaire(state: GraphState) -> GraphState:
    """Aggregate raw questionnaire answers into per-domain severity scores.

    Each domain takes the max (or single value) of its constituent item
    scores. Also seeds the conversation with an initial HumanMessage and
    resets the retry counter for the new turn.
    """
    responses = state.get("questionnaire_responses", {})
    domain_scores = {
        "Depression": max(responses.get("1", 0), responses.get("2", 0)),
        "Anger": responses.get("3", 0),
        "Mania": max(responses.get("4", 0), responses.get("5", 0)),
        "Anxiety": max(responses.get("6", 0), responses.get("7", 0), responses.get("8", 0)),
        "Somatic_Symptoms": max(responses.get("9", 0), responses.get("10", 0)),
        "Suicidal_Ideation": responses.get("11", 0),
        "Psychosis": max(responses.get("12", 0), responses.get("13", 0)),
        "Sleep_Problems": responses.get("14", 0),
        "Memory": responses.get("15", 0),
        "Repetitive_Thoughts_Behaviors": max(responses.get("16", 0), responses.get("17", 0)),
        "Dissociation": responses.get("18", 0),
        "Personality_Functioning": max(responses.get("19", 0), responses.get("20", 0)),
        "Substance_Use": max(responses.get("21", 0), responses.get("22", 0), responses.get("23", 0)),
    }
    initial_question = "User has completed the initial questionnaire.Provide supportive steps and coping mechanisms."
    initial_message = HumanMessage(content=initial_question)
    return {"domain_scores": domain_scores, "retry_count": 0, "messages": [initial_message]}


def route_entry(state: GraphState) -> str:
    """Route the turn: run the questionnaire first, then dispatch by score.

    Returns one of "questionnaire", "depression", "anxiety", or "no_action".
    Depression takes precedence over anxiety when both scores are >= 2.
    """
    if state.get("domain_scores"):
        scores = state.get("domain_scores", {})
        if scores.get("Depression", 0) >= 2:
            return "depression"
        if scores.get("Anxiety", 0) >= 2:
            return "anxiety"
        return "no_action"
    # No scores yet -> the questionnaire has not been processed this session.
    return "questionnaire"


def handle_depression_rag(state: GraphState) -> GraphState:
    """Generate a supportive, conversational response for a depression concern."""
    score = state.get("domain_scores", {}).get("Depression", 0)
    messages = state.get("messages", [])
    retry_count = state.get("retry_count", 0)

    retry_guidance = "Please provide a helpful and supportive plan."
    if retry_count > 0:
        # A previous attempt failed the safety check; nudge the model to redo it.
        retry_guidance = "Your previous response was flagged. Please try again."

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone. - **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a depression score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])
    rag_chain = prompt | llm | StrOutputParser()
    # NOTE(review): the full message list is interpolated into {question};
    # presumably intended to give conversational context — confirm.
    response = rag_chain.invoke({"score": score, "question": messages, "retry_guidance": retry_guidance})
    return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}


def handle_anxiety_rag(state: GraphState) -> GraphState:
    """Generate a supportive, conversational response for an anxiety concern."""
    # BUGFIX: previously read the "Depression" score here; use "Anxiety".
    score = state.get("domain_scores", {}).get("Anxiety", 0)
    messages = state.get("messages", [])
    retry_count = state.get("retry_count", 0)
    retry_guidance = "Please provide a helpful and supportive plan for someone feeling anxious or worried."

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone. - **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a anxiety score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])
    rag_chain = prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"score": score, "question": messages, "retry_guidance": retry_guidance})
    return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}


class SafetyCheck(BaseModel):
    """Pydantic model for the structured safety-check response."""

    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")


def safety_check(state: GraphState) -> GraphState:
    """Classify the most recent AI response as safe or unsafe.

    Uses structured output so the LLM must return a boolean verdict.
    Note: `next(...)` scans from the FRONT of the history, so this picks
    the first AIMessage — with add_messages appending, verify this is the
    intended message to evaluate (TODO: consider scanning in reverse).
    """
    ai_response = next(
        (m.content for m in state.get("messages", []) if isinstance(m, AIMessage)),
        ''
    )
    structured_llm = llm.with_structured_output(SafetyCheck)
    prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context. Analyze the following text for any harmful, unsafe, or non-supportive content. Then, provide a JSON object with a single boolean field 'is_safe'. Text to evaluate: '{ai_response}'"""
    )
    safety_chain = prompt | structured_llm
    result = safety_chain.invoke({"ai_response": ai_response})
    return {"is_safe": result.is_safe}


def handle_fallback(state: GraphState) -> GraphState:
    """Provide a safe, generic response when all retries have failed."""
    fallback_message = AIMessage(content="I am having trouble generating a specific plan right now. Please consider seeking support from a qualified professional.")
    return {"messages": [fallback_message], "retry_count": 0}


def finalize_response(state: GraphState) -> GraphState:
    """End the turn after a safe response: reset the retry counter only."""
    return {"retry_count": 0}


def route_after_safety_check(state: "GraphState") -> str:
    """Route after the safety check, enabling a bounded retry loop.

    Safe responses finalize; unsafe ones retry up to 2 attempts, then
    fall back to the generic safe message.
    """
    if state.get("is_safe"):
        return "finalize"
    if state.get("retry_count", 0) < 2:
        return "retry"
    return "fallback"


def entry_point(state: GraphState) -> GraphState:
    """Dedicated entry node that makes no state changes (routing anchor)."""
    return state


# --- LLM and graph construction (module-level side effects) ---
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile", max_tokens=4096)

graph = StateGraph(GraphState)
graph.add_node("entry_point", entry_point)
graph.add_node("questionnaire", questionnaire)
graph.add_node("handle_depression_rag", handle_depression_rag)
graph.add_node("handle_anxiety_rag", handle_anxiety_rag)
# Same safety node registered twice so each branch can loop back to its own handler.
graph.add_node("safety_check_depression", safety_check)
graph.add_node("safety_check_anxiety", safety_check)
graph.add_node("handle_fallback", handle_fallback)
graph.add_node("finalize_response", finalize_response)

graph.add_edge(START, "entry_point")
graph.add_conditional_edges(
    "entry_point",
    route_entry,
    {
        "depression": "handle_depression_rag",
        "anxiety": "handle_anxiety_rag",
        "no_action": END,
        "questionnaire": "questionnaire",
    },
)
# After scoring, loop back through the entry router to dispatch by score.
graph.add_edge("questionnaire", "entry_point")

graph.add_edge("handle_depression_rag", "safety_check_depression")
graph.add_conditional_edges(
    "safety_check_depression",
    route_after_safety_check,
    {"retry": "handle_depression_rag", "fallback": "handle_fallback", "finalize": "finalize_response"},
)
graph.add_edge("handle_anxiety_rag", "safety_check_anxiety")
graph.add_conditional_edges(
    "safety_check_anxiety",
    route_after_safety_check,
    {"retry": "handle_anxiety_rag", "fallback": "handle_fallback", "finalize": "finalize_response"},
)
graph.add_edge("handle_fallback", END)
graph.add_edge("finalize_response", END)

# --- Persistence: SQLite-backed checkpointer so conversations survive restarts ---
DB_PATH = "/app/data/chatbot.sqlite"
# Ensure the data directory exists; sqlite3.connect cannot create parent dirs.
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
# check_same_thread=False: the saver may be used from multiple worker threads.
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
app = graph.compile(checkpointer=checkpointer)