# mindease/src/quick_help.py
# (Uploaded by harsh122e2wr — "Upload 3 files", commit fca6e17 verified)
import os
from typing import Annotated, TypedDict, List, Dict
from dotenv import load_dotenv
# LangChain and LangGraph imports
from langchain_groq import ChatGroq
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
# --- 1. Pydantic Model for Structured Output ---
class SafetyCheck(BaseModel):
    """Structured verdict produced by the safety-classification LLM call."""
    # NOTE: the description below is part of the structured-output schema sent
    # to the model — do not edit it casually.
    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")
# --- 2. Define the State for the Chatbot ---
class ChatState(TypedDict):
    """State of the chatbot, continuously appending messages."""
    # Conversation history; the add_messages reducer appends new messages to
    # the existing list instead of replacing it on each node update.
    messages: Annotated[list[AnyMessage], add_messages]
    # Verdict written by the safety-check node for the latest AI reply.
    is_safe: bool
    # Number of generation attempts so far; drives the retry/fallback routing.
    retry_count: int
# --- 3. Define the Chatbot Nodes ---
def chat_node(state: ChatState, llm):
    """
    Generate the next AI reply for the conversation.

    Builds a system-prompted chain around *llm* and invokes it with the
    conversation history. On a retry (retry_count > 0) an extra corrective
    system message is added to the prompt input only — it is NOT persisted
    into the graph state.

    Args:
        state: Current chat state; reads 'messages' and 'retry_count'.
        llm: A LangChain chat model (bound at graph-build time).

    Returns:
        Partial state update: the new AIMessage (merged via the add_messages
        reducer) and the incremented retry counter.
    """
    retry_count = state.get('retry_count', 0)
    # Work on a copy of the history: the original code appended the retry
    # warning directly to state['messages'], mutating shared graph state
    # outside the add_messages reducer.
    history = list(state['messages'])

    # Create a prompt template with the system message
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a kind and empathetic AI assistant for mental well-being. "
                   "Provide supportive, safe, and general advice. Do not provide medical diagnoses. "
                   "Keep your responses concise and helpful."),
        ("placeholder", "{chat_history}")
    ])

    # Add a warning to the LLM if this is a retry attempt
    if retry_count > 0:
        history.append(SystemMessage(
            content="Your previous response was not suitable. Please try again, "
                    "focusing on being supportive and safe."
        ))

    chain = prompt | llm | StrOutputParser()
    # Invoke the LLM with the (possibly augmented) conversation history
    response_text = chain.invoke({"chat_history": history})
    return {"messages": [AIMessage(content=response_text)], "retry_count": retry_count + 1}
def safety_check_node(state: ChatState, llm) -> Dict:
    """
    Classify the latest AI reply as safe or unsafe.

    Feeds the most recent message in the state through a structured-output
    variant of *llm* that returns a SafetyCheck, and records the boolean
    verdict in the state for the router to consume.
    """
    # The message to vet is always the one the chat node just appended.
    latest_reply = state['messages'][-1]

    # Bind the Pydantic schema so the model must answer with {"is_safe": ...}.
    safety_llm = llm.with_structured_output(SafetyCheck)
    classification_prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context.
Analyze the following text for any harmful, unsafe, or non-supportive content.
Then, provide a JSON object with a single boolean field 'is_safe'.
Text to evaluate:
'{text_to_evaluate}'"""
    )

    verdict = (classification_prompt | safety_llm).invoke(
        {"text_to_evaluate": latest_reply.content}
    )
    return {"is_safe": verdict.is_safe}
def handle_fallback_node(state: ChatState) -> Dict:
    """
    Emit a canned, known-safe reply once the retry budget is exhausted.

    Used as the terminal branch when the main LLM repeatedly fails the
    safety check; the message is merged into the history via add_messages.
    """
    canned_reply = (
        "I am having a little trouble formulating a response right now. "
        "Remember that taking a moment to focus on your breath can be a helpful step."
    )
    return {"messages": [AIMessage(content=canned_reply)]}
# --- 4. Define the Conditional Router ---
def route_after_safety_check(state: ChatState) -> str:
    """
    Pick the next graph step after the safety check.

    Returns "end" when the reply passed, "retry" while generation attempts
    remain (fewer than 4 so far), and "fallback" once they are used up.
    """
    if state.get("is_safe"):
        return "end"
    attempts = state.get("retry_count", 0)
    return "retry" if attempts < 4 else "fallback"
# --- 5. Global Setup ---
load_dotenv()
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
groq_api_key = os.getenv("GROQ_API_KEY")
# --- Model Configuration ---
LLM_MODEL_NAME = "openai/gpt-oss-20b"
llm = ChatGroq(model_name=LLM_MODEL_NAME, groq_api_key=groq_api_key,max_tokens=4096)
# --- 6. Build the Graph ---
graph = StateGraph(ChatState)

# Register the nodes; the lambdas close over the module-level llm so each
# node receives the same client.
graph.add_node("chat_node", lambda state: chat_node(state, llm))
graph.add_node("safety_check_node", lambda state: safety_check_node(state, llm))
graph.add_node("handle_fallback_node", handle_fallback_node)

# Wiring: generate -> safety check -> (end | retry | fallback -> end).
graph.set_entry_point("chat_node")
graph.add_edge("chat_node", "safety_check_node")
graph.add_conditional_edges(
    "safety_check_node",
    route_after_safety_check,
    {
        "end": END,
        "retry": "chat_node",
        "fallback": "handle_fallback_node",
    },
)
graph.add_edge("handle_fallback_node", END)

# Compiled, runnable app exported by this module.
quick_help_app = graph.compile()