File size: 4,908 Bytes
fca6e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from typing import Annotated, TypedDict, List, Dict
from dotenv import load_dotenv

# LangChain and LangGraph imports
from langchain_groq import ChatGroq
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages

# --- 1. Pydantic Model for Structured Output ---
class SafetyCheck(BaseModel):
    """Schema the safety-classifier LLM must fill in.

    A single boolean verdict: does the evaluated text pass the safety bar?
    """

    # NOTE: the Field description is part of the structured-output schema the
    # model sees at runtime, so its wording is load-bearing — do not edit it
    # casually.
    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")

# --- 2. Define the State for the Chatbot ---
class ChatState(TypedDict):
    """Shared state threaded through every node of the chatbot graph."""

    # Conversation so far; the add_messages reducer appends new messages
    # instead of replacing the list on each node return.
    messages: Annotated[list[AnyMessage], add_messages]
    # Verdict recorded by the most recent safety check.
    is_safe: bool
    # Number of generation attempts made so far (drives the retry loop).
    retry_count: int

# --- 3. Define the Chatbot Nodes ---
def chat_node(state: ChatState, llm):
    """Generate the assistant's reply from the conversation history.

    Builds a system-primed prompt, optionally injects a corrective system
    message when this is a retry, and invokes the LLM.

    Args:
        state: Current graph state; ``messages`` holds the conversation and
            ``retry_count`` how many attempts have been made so far.
        llm: A LangChain chat model used to generate the reply.

    Returns:
        Partial state update: the new ``AIMessage`` (merged via the
        ``add_messages`` reducer) and an incremented ``retry_count``.
    """
    # BUGFIX: copy the history instead of aliasing it — the original code
    # appended the retry SystemMessage directly to state['messages'],
    # mutating shared graph state in place.
    messages = list(state['messages'])
    retry_count = state.get('retry_count', 0)

    # System persona plus the running conversation via a placeholder slot.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a kind and empathetic AI assistant for mental well-being. "
                   "Provide supportive, safe, and general advice. Do not provide medical diagnoses. "
                   "Keep your responses concise and helpful."),
        ("placeholder", "{chat_history}")
    ])

    # On a retry, steer the model away from whatever made the previous
    # answer fail the safety check.
    if retry_count > 0:
        messages.append(SystemMessage(
            content="Your previous response was not suitable. Please try again, "
                    "focusing on being supportive and safe."
        ))

    chain = prompt | llm | StrOutputParser()
    response_text = chain.invoke({"chat_history": messages})

    return {"messages": [AIMessage(content=response_text)], "retry_count": retry_count + 1}

def safety_check_node(state: ChatState, llm) -> Dict:
    """Classify the newest AI reply as safe or unsafe.

    Wraps the LLM in a structured-output layer so the verdict comes back as
    a validated ``SafetyCheck`` object, then records the boolean in state.
    """
    # The reply produced by chat_node is always the most recent message.
    latest_reply = state['messages'][-1]

    # Dedicated classifier: same model, constrained to the SafetyCheck schema.
    classifier = llm.with_structured_output(SafetyCheck)

    classification_prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context. 

        Analyze the following text for any harmful, unsafe, or non-supportive content.

        Then, provide a JSON object with a single boolean field 'is_safe'.



        Text to evaluate:

        '{text_to_evaluate}'"""
    )

    verdict = (classification_prompt | classifier).invoke(
        {"text_to_evaluate": latest_reply.content}
    )
    return {"is_safe": verdict.is_safe}

def handle_fallback_node(state: ChatState) -> Dict:
    """Emit a canned, known-safe reply once the retry budget is exhausted."""
    # Fixed text, so it can never fail the safety check it bypasses.
    return {
        "messages": [
            AIMessage(
                content="I am having a little trouble formulating a response right now. "
                        "Remember that taking a moment to focus on your breath can be a helpful step."
            )
        ]
    }

# --- 4. Define the Conditional Router ---
def route_after_safety_check(state: ChatState) -> str:
    """Pick the next edge after the safety check.

    Returns ``"end"`` for a safe reply, ``"retry"`` while generation
    attempts remain (fewer than 4 so far), and ``"fallback"`` once the
    retry budget is spent.
    """
    passed = state.get("is_safe")
    attempts = state.get("retry_count", 0)

    if passed:
        return "end"
    # Unsafe reply: loop back while attempts remain, otherwise give up safely.
    return "retry" if attempts < 4 else "fallback"

# --- 5. Global Setup ---
load_dotenv()

# BUGFIX: only propagate HF_TOKEN when it is actually set — assigning
# os.environ["HF_TOKEN"] = None raises TypeError (environ values must be str).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

groq_api_key = os.getenv("GROQ_API_KEY")

# --- Model Configuration ---
LLM_MODEL_NAME = "openai/gpt-oss-20b"

llm = ChatGroq(model_name=LLM_MODEL_NAME, groq_api_key=groq_api_key, max_tokens=4096)

# --- 6. Build the Graph ---
graph = StateGraph(ChatState)

# Register the nodes; the lambdas close over the module-level llm so the
# node callables keep the single-argument signature LangGraph expects.
graph.add_node("chat_node", lambda s: chat_node(s, llm))
graph.add_node("safety_check_node", lambda s: safety_check_node(s, llm))
graph.add_node("handle_fallback_node", handle_fallback_node)

# Linear backbone: generate a reply, then vet it.
graph.set_entry_point("chat_node")
graph.add_edge("chat_node", "safety_check_node")

# The safety verdict fans out: finish, loop back for a retry, or fall back.
graph.add_conditional_edges(
    "safety_check_node",
    route_after_safety_check,
    {
        "end": END,
        "retry": "chat_node",
        "fallback": "handle_fallback_node",
    },
)
graph.add_edge("handle_fallback_node", END)

# Compile the wiring into the runnable application.
quick_help_app = graph.compile()