File size: 10,247 Bytes
fca6e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import os
from typing import Dict, List, TypedDict, Annotated
from dotenv import load_dotenv

# LangChain and LangGraph imports
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage

from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver
from pydantic import BaseModel, Field

# ---  Graph State Definition ---
class GraphState(TypedDict):
    """Shared state passed between LangGraph nodes for one conversation turn."""
    # Raw questionnaire answers keyed by question number ("1".."23"),
    # each value on a 0 (None) to 4 (Severe) scale (see questionnaire()).
    questionnaire_responses: Dict[str, int]
    # Per-domain severity derived from the responses; drives route_entry().
    domain_scores: Dict[str, int]
    # NOTE(review): never written by any node in this file — presumably set
    # by an external caller; confirm before relying on it.
    primary_concern: str
    # Conversation history; add_messages APPENDS on each state update
    # rather than replacing the list.
    messages: Annotated[list[AnyMessage], add_messages]
    # Result of the safety classifier on the latest AI response.
    is_safe: bool
    # Generation attempts this turn; route_after_safety_check caps retries at 2.
    retry_count: int
    

# ---  RAG Retriever Helper Function ---
def create_persistent_rag_retriever(pdf_paths: List[str], db_name: str, embedding_model):
    """Create or load a persistent Chroma retriever built from one or more PDFs.

    Args:
        pdf_paths: Paths to source PDFs; missing paths are skipped with a warning.
        db_name: Name of the on-disk collection under ./chroma_db/.
        embedding_model: Embedding function used both to build and to query the store.

    Returns:
        A top-3 retriever over the vector store, or None if no documents
        could be processed.
    """
    persist_directory = f"./chroma_db/{db_name}"
    if os.path.exists(persist_directory):
        print(f"--- Loading existing persistent DB: {db_name} ---")
        # Fix: pass the same search_kwargs (k=3) as the freshly-built path below,
        # so retrieval behaves identically whether the DB was created or loaded.
        return Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding_model,
        ).as_retriever(search_kwargs={'k': 3})

    print(f"--- Creating new persistent DB: {db_name} ---")
    vector_store = None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for pdf_path in pdf_paths:
        if not os.path.exists(pdf_path):
            print(f"Warning: PDF not found at '{pdf_path}'. Skipping.")
            continue

        print(f"--- Processing PDF: {pdf_path} ---")
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        splits = text_splitter.split_documents(documents)

        if not splits:
            continue

        # First non-empty batch creates the persistent collection;
        # subsequent batches are appended to it.
        if vector_store is None:
            vector_store = Chroma.from_documents(
                documents=splits,
                embedding=embedding_model,
                persist_directory=persist_directory,
            )
        else:
            vector_store.add_documents(splits)

    if vector_store:
        print(f"--- DB creation complete for {db_name} ---")
        return vector_store.as_retriever(search_kwargs={'k': 3})
    else:
        print(f"--- Could not create DB for {db_name}. No documents processed. ---")
        return None

# ---  Node Definitions ---
def questionnaire(state: GraphState) -> GraphState:
    """Score every symptom domain from the raw questionnaire answers.

    Each domain's score is the maximum of its contributing question scores.
    Also seeds the conversation with an initial user message and resets the
    retry counter for the new turn.
    """
    answers = state.get("questionnaire_responses", {})

    # Domain -> the question numbers that feed it; a domain's score is the
    # max over its questions (missing answers count as 0).
    domain_questions = {
        "Depression": ("1", "2"),
        "Anger": ("3",),
        "Mania": ("4", "5"),
        "Anxiety": ("6", "7", "8"),
        "Somatic_Symptoms": ("9", "10"),
        "Suicidal_Ideation": ("11",),
        "Psychosis": ("12", "13"),
        "Sleep_Problems": ("14",),
        "Memory": ("15",),
        "Repetitive_Thoughts_Behaviors": ("16", "17"),
        "Dissociation": ("18",),
        "Personality_Functioning": ("19", "20"),
        "Substance_Use": ("21", "22", "23"),
    }
    domain_scores = {
        domain: max(answers.get(q, 0) for q in questions)
        for domain, questions in domain_questions.items()
    }

    kickoff = HumanMessage(
        content="User has completed the initial questionnaire.Provide supportive steps and coping mechanisms."
    )
    return {"domain_scores": domain_scores, "retry_count": 0, "messages": [kickoff]}

def route_entry(state: "GraphState") -> str:
    """Route the turn based on the current state.

    Returns "questionnaire" when no domain scores exist yet; otherwise
    "depression" or "anxiety" when that domain scores >= 2 (moderate or
    worse, depression taking precedence), else "no_action".
    """
    # Fix: the routing description was a stray no-op string expression inside
    # the if-branch, not a docstring — moved to a proper function docstring.
    # Annotation quoted for consistency with route_after_safety_check.
    scores = state.get("domain_scores")
    if not scores:
        # No scores yet -> run the questionnaire node first.
        return "questionnaire"
    if scores.get("Depression", 0) >= 2:
        return "depression"
    if scores.get("Anxiety", 0) >= 2:
        return "anxiety"
    return "no_action"
    
def handle_depression_rag(state: GraphState) -> GraphState:
    """Generate a supportive, conversational reply for the depression branch.

    Appends an AIMessage with the model's answer and bumps retry_count so the
    safety-check router can cap regeneration attempts.
    """
    depression_score = state.get("domain_scores", {}).get("Depression", 0)
    history = state.get("messages", [])
    attempts = state.get("retry_count", 0)

    # On a retry (a previous reply was flagged unsafe) steer the model away
    # from repeating the flagged output.
    if attempts > 0:
        retry_guidance = "Your previous response was flagged. Please try again."
    else:
        retry_guidance = "Please provide a helpful and supportive plan."

    system_text = """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
- **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""
    human_text = "The user has a depression score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}"
    prompt = ChatPromptTemplate.from_messages([("system", system_text), ("human", human_text)])

    chain = prompt | llm | StrOutputParser()
    answer = chain.invoke({
        "score": depression_score,
        "question": history,
        "retry_guidance": retry_guidance,
    })
    return {"messages": [AIMessage(content=answer)], "retry_count": attempts + 1}

def handle_anxiety_rag(state: GraphState) -> GraphState:
    """Handles the conversational pipeline for anxiety.

    Appends an AIMessage with the model's answer and bumps retry_count so the
    safety-check router can cap regeneration attempts.
    """
    # Fix: read the Anxiety score (was reading "Depression" — a copy/paste
    # error from handle_depression_rag).
    score = state.get("domain_scores", {}).get("Anxiety", 0)
    messages = state.get("messages", [])
    retry_count = state.get("retry_count", 0)

    retry_guidance = "Please provide a helpful and supportive plan for someone feeling anxious or worried."
    # Consistency with handle_depression_rag: on a retry, tell the model the
    # previous answer was flagged — otherwise the retry loop would just
    # regenerate from an identical prompt.
    if retry_count > 0:
        retry_guidance = "Your previous response was flagged. Please try again."

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
- **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a anxiety score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])

    rag_chain = prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"score": score, "question": messages, "retry_guidance": retry_guidance})
    return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}

class SafetyCheck(BaseModel):
    """Pydantic model for the safety check response."""
    # Structured-output schema for llm.with_structured_output(); the Field
    # description is sent to the model to guide its JSON generation.
    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")

def safety_check(state: GraphState) -> GraphState:
    """Classify the most recent AI response as safe/unsafe via structured output.

    Returns:
        Partial state update setting 'is_safe' from the classifier result.
    """
    # Fix: evaluate the NEWEST AIMessage. add_messages accumulates history, so
    # the original next() over the forward list always picked the FIRST AI
    # message — every retry re-checked the same already-flagged response
    # instead of the freshly generated one. Scanning in reverse finds the
    # latest AI response.
    ai_response = next(
        (m.content for m in reversed(state.get("messages", [])) if isinstance(m, AIMessage)), ''
    )
    structured_llm = llm.with_structured_output(SafetyCheck)

    prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context. 
        Analyze the following text for any harmful, unsafe, or non-supportive content.
        Then, provide a JSON object with a single boolean field 'is_safe'.

        Text to evaluate:
        '{ai_response}'"""
    )

    safety_chain = prompt | structured_llm
    result = safety_chain.invoke({"ai_response": ai_response})

    return {"is_safe": result.is_safe}

def handle_fallback(state: GraphState) -> GraphState:
    """Emit a canned safe reply after the retry budget is exhausted and reset the counter."""
    fallback_text = "I am having trouble generating a specific plan right now. Please consider seeking support from a qualified professional."
    return {"messages": [AIMessage(content=fallback_text)], "retry_count": 0}

def finalize_response(state: GraphState) -> GraphState:
    """End-of-turn hook: the safe AI reply is already in messages, so just reset the retry counter."""
    return dict(retry_count=0)

def route_after_safety_check(state: "GraphState") -> str:
    """Decide the next hop after the safety classifier runs.

    Returns "finalize" for a safe reply, "retry" while fewer than two
    generation attempts have been made, and "fallback" once the retry
    budget is exhausted.
    """
    if state.get("is_safe"):
        return "finalize"
    attempts = state.get("retry_count", 0)
    return "retry" if attempts < 2 else "fallback"

def entry_point(state: GraphState) -> GraphState:
    """No-op passthrough node; exists only so route_entry can branch from a single entry."""
    return state

# --- LLM setup (module level; the single ChatGroq client is shared by all nodes) ---
load_dotenv()
# NOTE(review): if GROQ_API_KEY is unset this is None and ChatGroq only fails
# at request time — consider failing fast here with a clear error.
groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile",max_tokens=4096)

# --- Graph assembly ---
graph = StateGraph(GraphState)

# safety_check is registered twice (one node per RAG branch) so each branch's
# conditional "retry" edge can loop back to the matching handler.
graph.add_node("entry_point",entry_point)
graph.add_node("questionnaire",questionnaire)
graph.add_node("handle_depression_rag", handle_depression_rag)
graph.add_node("handle_anxiety_rag", handle_anxiety_rag)
graph.add_node("safety_check_depression", safety_check)
graph.add_node("safety_check_anxiety", safety_check)
graph.add_node("handle_fallback", handle_fallback)
graph.add_node("finalize_response", finalize_response)

# Entry: a first visit (no domain_scores) routes to the questionnaire, which
# loops back so the freshly computed scores pick a branch.
graph.add_edge(START,"entry_point")
graph.add_conditional_edges("entry_point",route_entry,{"depression":"handle_depression_rag","anxiety":"handle_anxiety_rag","no_action":END,"questionnaire":"questionnaire"})
graph.add_edge("questionnaire","entry_point")

# Depression branch: generate -> safety check -> retry (<2 attempts) / fallback / finalize.
graph.add_edge("handle_depression_rag", "safety_check_depression")
graph.add_conditional_edges(
    "safety_check_depression", route_after_safety_check,
    {"retry": "handle_depression_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
)

# Anxiety branch: same retry/fallback/finalize pattern.
graph.add_edge("handle_anxiety_rag", "safety_check_anxiety")
graph.add_conditional_edges(
    "safety_check_anxiety", route_after_safety_check,
    {"retry": "handle_anxiety_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
)

graph.add_edge("handle_fallback", END)
graph.add_edge("finalize_response", END)
# --- Persistence: SQLite-backed checkpointer so conversation state survives restarts ---
DB_PATH = "/app/data/chatbot.sqlite"
# Robustness fix: create the data directory first — on a fresh container
# sqlite3.connect fails with "unable to open database file" if it is missing.
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
# check_same_thread=False: the connection is shared across the server's worker threads.
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
app = graph.compile(checkpointer=checkpointer)