Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Annotated, TypedDict, List, Dict | |
| from dotenv import load_dotenv | |
| # LangChain and LangGraph imports | |
| from langchain_groq import ChatGroq | |
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from pydantic import BaseModel, Field | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.graph.message import add_messages | |
| # --- 1. Pydantic Model for Structured Output --- | |
| class SafetyCheck(BaseModel): | |
| """Pydantic model for the safety check response.""" | |
| is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.") | |
| # --- 2. Define the State for the Chatbot --- | |
| class ChatState(TypedDict): | |
| """State of the chatbot, continuously appending messages.""" | |
| messages: Annotated[list[AnyMessage], add_messages] | |
| is_safe: bool | |
| retry_count: int | |
| # --- 3. Define the Chatbot Nodes --- | |
| def chat_node(state: ChatState, llm): | |
| """ | |
| This node invokes the LLM to get a response using a ChatPromptTemplate. | |
| """ | |
| messages = state['messages'] | |
| retry_count = state.get('retry_count', 0) | |
| # Create a prompt template with the system message | |
| prompt = ChatPromptTemplate.from_messages([ | |
| ("system", "You are a kind and empathetic AI assistant for mental well-being. " | |
| "Provide supportive, safe, and general advice. Do not provide medical diagnoses. " | |
| "Keep your responses concise and helpful."), | |
| ("placeholder", "{chat_history}") | |
| ]) | |
| # Add a warning to the LLM if this is a retry attempt | |
| if retry_count > 0: | |
| messages.append(SystemMessage( | |
| content="Your previous response was not suitable. Please try again, " | |
| "focusing on being supportive and safe." | |
| )) | |
| chain = prompt | llm | StrOutputParser() | |
| # Invoke the LLM with the conversation history | |
| response_text = chain.invoke({"chat_history": messages}) | |
| return {"messages": [AIMessage(content=response_text)], "retry_count": retry_count + 1} | |
| def safety_check_node(state: ChatState, llm) -> Dict: | |
| """ | |
| Evaluates the safety of the generated response using a structured LLM. | |
| """ | |
| # Get the most recent AI message to check | |
| ai_message = state['messages'][-1] | |
| # Create a dedicated structured LLM for this task | |
| structured_llm = llm.with_structured_output(SafetyCheck) | |
| prompt = ChatPromptTemplate.from_template( | |
| """You are a safety classification model. The user is in a mental health context. | |
| Analyze the following text for any harmful, unsafe, or non-supportive content. | |
| Then, provide a JSON object with a single boolean field 'is_safe'. | |
| Text to evaluate: | |
| '{text_to_evaluate}'""" | |
| ) | |
| safety_chain = prompt | structured_llm | |
| result = safety_chain.invoke({"text_to_evaluate": ai_message.content}) | |
| return {"is_safe": result.is_safe} | |
| def handle_fallback_node(state: ChatState) -> Dict: | |
| """ | |
| Provides a safe, generic response if the main LLM fails after retries. | |
| """ | |
| fallback_message = AIMessage( | |
| content="I am having a little trouble formulating a response right now. " | |
| "Remember that taking a moment to focus on your breath can be a helpful step." | |
| ) | |
| return {"messages": [fallback_message]} | |
| # --- 4. Define the Conditional Router --- | |
| def route_after_safety_check(state: ChatState) -> str: | |
| """ | |
| This router decides the next step after the safety check, enabling a retry loop. | |
| """ | |
| if state.get("is_safe"): | |
| return "end" | |
| if state.get("retry_count", 0) < 4: | |
| return "retry" | |
| return "fallback" | |
| # --- 5. Global Setup --- | |
| load_dotenv() | |
| os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") | |
| groq_api_key = os.getenv("GROQ_API_KEY") | |
| # --- Model Configuration --- | |
| LLM_MODEL_NAME = "openai/gpt-oss-20b" | |
| llm = ChatGroq(model_name=LLM_MODEL_NAME, groq_api_key=groq_api_key,max_tokens=4096) | |
| # --- 6. Build the Graph --- | |
| graph = StateGraph(ChatState) | |
| graph.add_node("chat_node", lambda state: chat_node(state, llm)) | |
| graph.add_node("safety_check_node", lambda state: safety_check_node(state, llm)) | |
| graph.add_node("handle_fallback_node", handle_fallback_node) | |
| graph.set_entry_point("chat_node") | |
| graph.add_edge("chat_node", "safety_check_node") | |
| graph.add_edge("handle_fallback_node", END) | |
| # Add the conditional edge for the retry loop | |
| graph.add_conditional_edges( | |
| "safety_check_node", | |
| route_after_safety_check, | |
| { | |
| "retry": "chat_node", | |
| "fallback": "handle_fallback_node", | |
| "end": END | |
| } | |
| ) | |
| # Compile the graph | |
| quick_help_app = graph.compile() |