Spaces:
Sleeping
Sleeping
File size: 4,908 Bytes
fca6e17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import os
from typing import Annotated, TypedDict, List, Dict
from dotenv import load_dotenv
# LangChain and LangGraph imports
from langchain_groq import ChatGroq
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
# --- 1. Pydantic Model for Structured Output ---
class SafetyCheck(BaseModel):
    """Structured verdict produced by the safety-classifier LLM call."""

    # Required field (explicit `...`, no default). The description text ends up
    # in the JSON schema derived from this model, which is what the
    # structured-output call presents to the model, so keep its wording stable.
    is_safe: bool = Field(
        ...,
        description="True if the text is safe, False if it contains harmful content.",
    )
# --- 2. Define the State for the Chatbot ---
class ChatState(TypedDict):
    """State of the chatbot graph, shared by every node.

    Nodes return partial dicts containing only the keys they update; the graph
    merges those updates into this state.
    """
    # Conversation history. The add_messages reducer merges the messages a node
    # returns into the existing list (appending new ones) instead of overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Verdict of the most recent safety check (written by safety_check_node,
    # read by route_after_safety_check).
    is_safe: bool
    # Number of chat_node attempts made so far. Readers use .get(..., 0), so it
    # may be absent from the initial input.
    retry_count: int
# --- 3. Define the Chatbot Nodes ---
def chat_node(state: ChatState, llm):
    """
    Generate the assistant's next reply from the conversation history.

    The history in ``state['messages']`` is rendered after a fixed system prompt
    through a ChatPromptTemplate, sent to ``llm``, and the text output is wrapped
    in a new AIMessage.

    Args:
        state: Current graph state. Reads ``messages`` and the optional
            ``retry_count`` (treated as 0 when absent).
        llm: The chat model to invoke.

    Returns:
        A partial state update: the new AIMessage (appended to the history by
        the ``add_messages`` reducer) and ``retry_count`` incremented by one.

    NOTE(review): replies rejected by the safety check are not removed from
    ``state['messages']``; they stay in the final history next to the accepted
    (or fallback) reply.
    NOTE(review): nothing in this graph resets ``retry_count``. If a caller feeds
    a finished run's output state back in as the next input, the next turn starts
    with ``retry_count > 0``. That turn then gets the retry warning on its first
    attempt and has fewer retries. Confirm that callers start each turn at 0.
    """
    retry_count = state.get('retry_count', 0)
    # Work on a copy. The list in state is the graph's own message history, so
    # appending to it in place could leak the retry warning below into the stored
    # conversation and stack another warning on every retry.
    messages = list(state['messages'])
    # Create a prompt template with the system message
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a kind and empathetic AI assistant for mental well-being. "
                   "Provide supportive, safe, and general advice. Do not provide medical diagnoses. "
                   "Keep your responses concise and helpful."),
        ("placeholder", "{chat_history}"),
    ])
    # On a retry (reached via the safety check's "retry" edge), warn the model
    # for this call only. The warning is never written back to state.
    if retry_count > 0:
        messages.append(SystemMessage(
            content="Your previous response was not suitable. Please try again, "
                    "focusing on being supportive and safe."
        ))
    chain = prompt | llm | StrOutputParser()
    # Invoke the LLM with the conversation history
    response_text = chain.invoke({"chat_history": messages})
    return {"messages": [AIMessage(content=response_text)], "retry_count": retry_count + 1}
def safety_check_node(state: ChatState, llm) -> Dict:
    """
    Classify the most recent message as safe or unsafe using a structured LLM call.

    Args:
        state: Current graph state. Its last message is the one evaluated; in
            this graph that is the reply chat_node just produced.
        llm: Chat model supporting ``with_structured_output``.

    Returns:
        ``{"is_safe": bool}``. Fails closed: if the structured call yields no
        SafetyCheck object, the reply is treated as unsafe, so the router retries
        or falls back instead of the node crashing.
    """
    # Get the most recent AI message to check
    ai_message = state['messages'][-1]
    # Create a dedicated structured LLM for this task
    structured_llm = llm.with_structured_output(SafetyCheck)
    prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context.
Analyze the following text for any harmful, unsafe, or non-supportive content.
Then, provide a JSON object with a single boolean field 'is_safe'.
Text to evaluate:
'{text_to_evaluate}'"""
    )
    safety_chain = prompt | structured_llm
    result = safety_chain.invoke({"text_to_evaluate": ai_message.content})
    # The tool-calling structured-output parser returns None when the model makes
    # no tool call. Reading .is_safe on None would raise AttributeError, so treat
    # a missing verdict as unsafe.
    is_safe = result is not None and result.is_safe
    return {"is_safe": is_safe}
def handle_fallback_node(state: ChatState) -> Dict:
    """
    Terminal node used once the retry budget is exhausted.

    Instead of attempting another model call, it emits a fixed, generic
    supportive reply. ``state`` is accepted for the node signature but not read.

    Returns:
        A partial state update holding the canned AIMessage, which the
        ``add_messages`` reducer appends to the history.
    """
    canned_text = (
        "I am having a little trouble formulating a response right now. "
        "Remember that taking a moment to focus on your breath can be a helpful step."
    )
    return {"messages": [AIMessage(content=canned_text)]}
# --- 4. Define the Conditional Router ---
def route_after_safety_check(state: ChatState) -> str:
    """
    Choose the edge to follow after safety_check_node, forming the retry loop.

    Returns:
        "end" when the latest reply passed the safety check. Otherwise returns
        "retry" while ``retry_count`` (default 0, incremented once per chat_node
        attempt) is below 4, and "fallback" once that budget is used up.
    """
    if not state.get("is_safe"):
        attempts = state.get("retry_count", 0)
        return "retry" if attempts < 4 else "fallback"
    return "end"
# --- 5. Global Setup ---
# Export variables from a .env file into os.environ (already-set variables win).
load_dotenv()
# HF_TOKEN needs no handling here: load_dotenv() already places a .env value in
# os.environ. Re-assigning os.getenv("HF_TOKEN") to os.environ raised TypeError
# at import time whenever the variable was unset, because os.environ only
# accepts str values.
groq_api_key = os.getenv("GROQ_API_KEY")
# --- Model Configuration ---
LLM_MODEL_NAME = "openai/gpt-oss-20b"
# Single shared model instance, used by both the chat and safety-check nodes.
llm = ChatGroq(model_name=LLM_MODEL_NAME, groq_api_key=groq_api_key, max_tokens=4096)
# --- 6. Build the Graph ---
# Flow: chat_node -> safety_check_node -> END, or back to chat_node ("retry"),
# or handle_fallback_node -> END ("fallback").
graph = StateGraph(ChatState)
# The lambdas bind the shared module-level `llm` into the two-argument nodes.
graph.add_node("chat_node", lambda state: chat_node(state, llm))
graph.add_node("safety_check_node", lambda state: safety_check_node(state, llm))
graph.add_node("handle_fallback_node", handle_fallback_node)
graph.set_entry_point("chat_node")
graph.add_edge("chat_node", "safety_check_node")
graph.add_edge("handle_fallback_node", END)
# Add the conditional edge for the retry loop: the router's return value is
# looked up in this mapping to pick the next node.
graph.add_conditional_edges(
    "safety_check_node",
    route_after_safety_check,
    {
        "retry": "chat_node",
        "fallback": "handle_fallback_node",
        "end": END,
    },
)
# Compile the graph into the runnable app exposed by this module.
quick_help_app = graph.compile()