Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- src/main_back.py +220 -0
- src/quick_help.py +134 -0
- src/streamlit_app.py +166 -0
src/main_back.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, List, TypedDict, Annotated
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# LangChain and LangGraph imports
|
| 6 |
+
from langchain_chroma import Chroma
|
| 7 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 8 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 9 |
+
from langchain_groq import ChatGroq
|
| 10 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 11 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 12 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 13 |
+
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
|
| 14 |
+
|
| 15 |
+
from langgraph.graph import StateGraph, END, START
|
| 16 |
+
from langgraph.graph.message import add_messages
|
| 17 |
+
import sqlite3
|
| 18 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
| 19 |
+
from pydantic import BaseModel, Field
|
| 20 |
+
|
| 21 |
+
# --- Graph State Definition ---
|
| 22 |
+
class GraphState(TypedDict):
|
| 23 |
+
questionnaire_responses: Dict[str, int]
|
| 24 |
+
domain_scores: Dict[str, int]
|
| 25 |
+
primary_concern: str
|
| 26 |
+
messages: Annotated[list[AnyMessage], add_messages]
|
| 27 |
+
is_safe: bool
|
| 28 |
+
retry_count: int
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# --- RAG Retriever Helper Function ---
|
| 32 |
+
def create_persistent_rag_retriever(pdf_paths: List[str], db_name: str, embedding_model):
|
| 33 |
+
"""Creates or loads a persistent RAG retriever from one or more PDF documents."""
|
| 34 |
+
persist_directory = f"./chroma_db/{db_name}"
|
| 35 |
+
if os.path.exists(persist_directory):
|
| 36 |
+
print(f"--- Loading existing persistent DB: {db_name} ---")
|
| 37 |
+
return Chroma(persist_directory=persist_directory, embedding_function=embedding_model).as_retriever()
|
| 38 |
+
|
| 39 |
+
print(f"--- Creating new persistent DB: {db_name} ---")
|
| 40 |
+
vector_store = None
|
| 41 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
| 42 |
+
|
| 43 |
+
for pdf_path in pdf_paths:
|
| 44 |
+
if not os.path.exists(pdf_path):
|
| 45 |
+
print(f"Warning: PDF not found at '{pdf_path}'. Skipping.")
|
| 46 |
+
continue
|
| 47 |
+
|
| 48 |
+
print(f"--- Processing PDF: {pdf_path} ---")
|
| 49 |
+
loader = PyPDFLoader(pdf_path)
|
| 50 |
+
documents = loader.load()
|
| 51 |
+
splits = text_splitter.split_documents(documents)
|
| 52 |
+
|
| 53 |
+
if not splits: continue
|
| 54 |
+
|
| 55 |
+
if vector_store is None:
|
| 56 |
+
vector_store = Chroma.from_documents(documents=splits, embedding=embedding_model, persist_directory=persist_directory)
|
| 57 |
+
else:
|
| 58 |
+
vector_store.add_documents(splits)
|
| 59 |
+
|
| 60 |
+
if vector_store:
|
| 61 |
+
print(f"--- DB creation complete for {db_name} ---")
|
| 62 |
+
return vector_store.as_retriever(search_kwargs={'k': 3})
|
| 63 |
+
else:
|
| 64 |
+
print(f"--- Could not create DB for {db_name}. No documents processed. ---")
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
# --- Node Definitions ---
|
| 68 |
+
def questionnaire(state: GraphState) -> GraphState:
|
| 69 |
+
"""Calculates all domain scores and clears the questionnaire responses from the state."""
|
| 70 |
+
responses = state.get("questionnaire_responses", {})
|
| 71 |
+
domain_scores = {
|
| 72 |
+
"Depression": max(responses.get("1", 0), responses.get("2", 0)),
|
| 73 |
+
"Anger": responses.get("3", 0),
|
| 74 |
+
"Mania": max(responses.get("4", 0), responses.get("5", 0)),
|
| 75 |
+
"Anxiety": max(responses.get("6", 0), responses.get("7", 0), responses.get("8", 0)),
|
| 76 |
+
"Somatic_Symptoms": max(responses.get("9", 0), responses.get("10", 0)),
|
| 77 |
+
"Suicidal_Ideation": responses.get("11", 0),
|
| 78 |
+
"Psychosis": max(responses.get("12", 0), responses.get("13", 0)),
|
| 79 |
+
"Sleep_Problems": responses.get("14", 0),
|
| 80 |
+
"Memory": responses.get("15", 0),
|
| 81 |
+
"Repetitive_Thoughts_Behaviors": max(responses.get("16", 0), responses.get("17", 0)),
|
| 82 |
+
"Dissociation": responses.get("18", 0),
|
| 83 |
+
"Personality_Functioning": max(responses.get("19", 0), responses.get("20", 0)),
|
| 84 |
+
"Substance_Use": max(responses.get("21", 0), responses.get("22", 0), responses.get("23", 0)),
|
| 85 |
+
}
|
| 86 |
+
initial_question = "User has completed the initial questionnaire.Provide supportive steps and coping mechanisms."
|
| 87 |
+
initial_message = HumanMessage(content=initial_question)
|
| 88 |
+
|
| 89 |
+
return {"domain_scores": domain_scores, "retry_count": 0, "messages": [initial_message]}
|
| 90 |
+
|
| 91 |
+
def route_entry(state: GraphState)-> str:
|
| 92 |
+
if state.get("domain_scores"):
|
| 93 |
+
"""Routes to the appropriate RAG handler based on scores."""
|
| 94 |
+
scores = state.get("domain_scores", {})
|
| 95 |
+
if scores.get("Depression", 0) >= 2:
|
| 96 |
+
return "depression"
|
| 97 |
+
if scores.get("Anxiety", 0) >= 2:
|
| 98 |
+
return "anxiety"
|
| 99 |
+
return "no_action"
|
| 100 |
+
else:
|
| 101 |
+
return "questionnaire"
|
| 102 |
+
|
| 103 |
+
def handle_depression_rag(state: GraphState) -> GraphState:
|
| 104 |
+
"""Handles the conversational RAG pipeline for depression."""
|
| 105 |
+
score = state.get("domain_scores", {}).get("Depression", 0)
|
| 106 |
+
messages = state.get("messages", [])
|
| 107 |
+
retry_count = state.get("retry_count", 0)
|
| 108 |
+
|
| 109 |
+
retry_guidance = "Please provide a helpful and supportive plan."
|
| 110 |
+
if retry_count > 0: retry_guidance = "Your previous response was flagged. Please try again."
|
| 111 |
+
|
| 112 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 113 |
+
("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
|
| 114 |
+
- **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
|
| 115 |
+
("human", "The user has a depression score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
|
| 116 |
+
])
|
| 117 |
+
|
| 118 |
+
rag_chain = prompt | llm | StrOutputParser()
|
| 119 |
+
response = rag_chain.invoke({ "score": score, "question": messages,"retry_guidance":retry_guidance})
|
| 120 |
+
return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}
|
| 121 |
+
|
| 122 |
+
def handle_anxiety_rag(state: GraphState) -> GraphState:
|
| 123 |
+
"""Handles the conversational RAG pipeline for anxiety."""
|
| 124 |
+
score = state.get("domain_scores", {}).get("Depression", 0)
|
| 125 |
+
messages = state.get("messages", [])
|
| 126 |
+
retry_count = state.get("retry_count", 0)
|
| 127 |
+
|
| 128 |
+
retry_guidance = "Please provide a helpful and supportive plan for someone feeling anxious or worried."
|
| 129 |
+
|
| 130 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 131 |
+
("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
|
| 132 |
+
- **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
|
| 133 |
+
("human", "The user has a anxiety score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
|
| 134 |
+
])
|
| 135 |
+
|
| 136 |
+
rag_chain = prompt | llm | StrOutputParser()
|
| 137 |
+
response = rag_chain.invoke({"score": score, "question": messages,"retry_guidance":retry_guidance})
|
| 138 |
+
return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}
|
| 139 |
+
|
| 140 |
+
class SafetyCheck(BaseModel):
|
| 141 |
+
"""Pydantic model for the safety check response."""
|
| 142 |
+
is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")
|
| 143 |
+
|
| 144 |
+
def safety_check(state: GraphState) -> GraphState:
|
| 145 |
+
"""Checks the generated response for safety."""
|
| 146 |
+
ai_response = next(
|
| 147 |
+
(m.content for m in state.get("messages",[]) if isinstance(m, AIMessage)), ''
|
| 148 |
+
)
|
| 149 |
+
structured_llm = llm.with_structured_output(SafetyCheck)
|
| 150 |
+
|
| 151 |
+
prompt = ChatPromptTemplate.from_template(
|
| 152 |
+
"""You are a safety classification model. The user is in a mental health context.
|
| 153 |
+
Analyze the following text for any harmful, unsafe, or non-supportive content.
|
| 154 |
+
Then, provide a JSON object with a single boolean field 'is_safe'.
|
| 155 |
+
|
| 156 |
+
Text to evaluate:
|
| 157 |
+
'{ai_response}'"""
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
safety_chain = prompt | structured_llm
|
| 161 |
+
result = safety_chain.invoke({"ai_response": ai_response})
|
| 162 |
+
|
| 163 |
+
return {"is_safe": result.is_safe}
|
| 164 |
+
|
| 165 |
+
def handle_fallback(state: GraphState) -> GraphState:
|
| 166 |
+
"""Provides a safe, generic response if retries fail."""
|
| 167 |
+
fallback_message = AIMessage(content="I am having trouble generating a specific plan right now. Please consider seeking support from a qualified professional.")
|
| 168 |
+
return {"messages": [fallback_message],"retry_count":0}
|
| 169 |
+
|
| 170 |
+
def finalize_response(state: GraphState) -> GraphState:
|
| 171 |
+
"""Finalizes the turn by returning the safe AI response as a message."""
|
| 172 |
+
return {"retry_count":0}
|
| 173 |
+
|
| 174 |
+
def route_after_safety_check(state: "GraphState") -> str:
|
| 175 |
+
"""Routes after the safety check, enabling the retry loop."""
|
| 176 |
+
if state.get("is_safe"): return "finalize"
|
| 177 |
+
if state.get("retry_count", 0) < 2: return "retry"
|
| 178 |
+
return "fallback"
|
| 179 |
+
|
| 180 |
+
def entry_point(state: GraphState) -> GraphState:
|
| 181 |
+
"""A dedicated node for the graph's entry point that makes no state changes."""
|
| 182 |
+
return state
|
| 183 |
+
|
| 184 |
+
load_dotenv()
|
| 185 |
+
groq_api_key = os.getenv("GROQ_API_KEY")
|
| 186 |
+
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile",max_tokens=4096)
|
| 187 |
+
|
| 188 |
+
graph = StateGraph(GraphState)
|
| 189 |
+
|
| 190 |
+
graph.add_node("entry_point",entry_point)
|
| 191 |
+
graph.add_node("questionnaire",questionnaire)
|
| 192 |
+
graph.add_node("handle_depression_rag", handle_depression_rag)
|
| 193 |
+
graph.add_node("handle_anxiety_rag", handle_anxiety_rag)
|
| 194 |
+
graph.add_node("safety_check_depression", safety_check)
|
| 195 |
+
graph.add_node("safety_check_anxiety", safety_check)
|
| 196 |
+
graph.add_node("handle_fallback", handle_fallback)
|
| 197 |
+
graph.add_node("finalize_response", finalize_response)
|
| 198 |
+
|
| 199 |
+
graph.add_edge(START,"entry_point")
|
| 200 |
+
graph.add_conditional_edges("entry_point",route_entry,{"depression":"handle_depression_rag","anxiety":"handle_anxiety_rag","no_action":END,"questionnaire":"questionnaire"})
|
| 201 |
+
graph.add_edge("questionnaire","entry_point")
|
| 202 |
+
|
| 203 |
+
graph.add_edge("handle_depression_rag", "safety_check_depression")
|
| 204 |
+
graph.add_conditional_edges(
|
| 205 |
+
"safety_check_depression", route_after_safety_check,
|
| 206 |
+
{"retry": "handle_depression_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
graph.add_edge("handle_anxiety_rag", "safety_check_anxiety")
|
| 210 |
+
graph.add_conditional_edges(
|
| 211 |
+
"safety_check_anxiety", route_after_safety_check,
|
| 212 |
+
{"retry": "handle_anxiety_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
graph.add_edge("handle_fallback", END)
|
| 216 |
+
graph.add_edge("finalize_response", END)
|
| 217 |
+
DB_PATH = "/app/data/chatbot.sqlite"
|
| 218 |
+
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
|
| 219 |
+
checkpointer = SqliteSaver(conn=conn)
|
| 220 |
+
app = graph.compile(checkpointer=checkpointer)
|
src/quick_help.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Annotated, TypedDict, List, Dict
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# LangChain and LangGraph imports
|
| 6 |
+
from langchain_groq import ChatGroq
|
| 7 |
+
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
|
| 8 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
from langgraph.graph import StateGraph, END
|
| 12 |
+
from langgraph.graph.message import add_messages
|
| 13 |
+
|
| 14 |
+
# --- 1. Pydantic Model for Structured Output ---
|
| 15 |
+
class SafetyCheck(BaseModel):
|
| 16 |
+
"""Pydantic model for the safety check response."""
|
| 17 |
+
is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")
|
| 18 |
+
|
| 19 |
+
# --- 2. Define the State for the Chatbot ---
|
| 20 |
+
class ChatState(TypedDict):
|
| 21 |
+
"""State of the chatbot, continuously appending messages."""
|
| 22 |
+
messages: Annotated[list[AnyMessage], add_messages]
|
| 23 |
+
is_safe: bool
|
| 24 |
+
retry_count: int
|
| 25 |
+
|
| 26 |
+
# --- 3. Define the Chatbot Nodes ---
|
| 27 |
+
def chat_node(state: ChatState, llm):
|
| 28 |
+
"""
|
| 29 |
+
This node invokes the LLM to get a response using a ChatPromptTemplate.
|
| 30 |
+
"""
|
| 31 |
+
messages = state['messages']
|
| 32 |
+
retry_count = state.get('retry_count', 0)
|
| 33 |
+
|
| 34 |
+
# Create a prompt template with the system message
|
| 35 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 36 |
+
("system", "You are a kind and empathetic AI assistant for mental well-being. "
|
| 37 |
+
"Provide supportive, safe, and general advice. Do not provide medical diagnoses. "
|
| 38 |
+
"Keep your responses concise and helpful."),
|
| 39 |
+
("placeholder", "{chat_history}")
|
| 40 |
+
])
|
| 41 |
+
|
| 42 |
+
# Add a warning to the LLM if this is a retry attempt
|
| 43 |
+
if retry_count > 0:
|
| 44 |
+
messages.append(SystemMessage(
|
| 45 |
+
content="Your previous response was not suitable. Please try again, "
|
| 46 |
+
"focusing on being supportive and safe."
|
| 47 |
+
))
|
| 48 |
+
|
| 49 |
+
chain = prompt | llm | StrOutputParser()
|
| 50 |
+
|
| 51 |
+
# Invoke the LLM with the conversation history
|
| 52 |
+
response_text = chain.invoke({"chat_history": messages})
|
| 53 |
+
|
| 54 |
+
return {"messages": [AIMessage(content=response_text)], "retry_count": retry_count + 1}
|
| 55 |
+
|
| 56 |
+
def safety_check_node(state: ChatState, llm) -> Dict:
|
| 57 |
+
"""
|
| 58 |
+
Evaluates the safety of the generated response using a structured LLM.
|
| 59 |
+
"""
|
| 60 |
+
# Get the most recent AI message to check
|
| 61 |
+
ai_message = state['messages'][-1]
|
| 62 |
+
|
| 63 |
+
# Create a dedicated structured LLM for this task
|
| 64 |
+
structured_llm = llm.with_structured_output(SafetyCheck)
|
| 65 |
+
|
| 66 |
+
prompt = ChatPromptTemplate.from_template(
|
| 67 |
+
"""You are a safety classification model. The user is in a mental health context.
|
| 68 |
+
Analyze the following text for any harmful, unsafe, or non-supportive content.
|
| 69 |
+
Then, provide a JSON object with a single boolean field 'is_safe'.
|
| 70 |
+
|
| 71 |
+
Text to evaluate:
|
| 72 |
+
'{text_to_evaluate}'"""
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
safety_chain = prompt | structured_llm
|
| 76 |
+
result = safety_chain.invoke({"text_to_evaluate": ai_message.content})
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
return {"is_safe": result.is_safe}
|
| 80 |
+
|
| 81 |
+
def handle_fallback_node(state: ChatState) -> Dict:
|
| 82 |
+
"""
|
| 83 |
+
Provides a safe, generic response if the main LLM fails after retries.
|
| 84 |
+
"""
|
| 85 |
+
fallback_message = AIMessage(
|
| 86 |
+
content="I am having a little trouble formulating a response right now. "
|
| 87 |
+
"Remember that taking a moment to focus on your breath can be a helpful step."
|
| 88 |
+
)
|
| 89 |
+
return {"messages": [fallback_message]}
|
| 90 |
+
|
| 91 |
+
# --- 4. Define the Conditional Router ---
|
| 92 |
+
def route_after_safety_check(state: ChatState) -> str:
|
| 93 |
+
"""
|
| 94 |
+
This router decides the next step after the safety check, enabling a retry loop.
|
| 95 |
+
"""
|
| 96 |
+
if state.get("is_safe"):
|
| 97 |
+
return "end"
|
| 98 |
+
if state.get("retry_count", 0) < 4:
|
| 99 |
+
return "retry"
|
| 100 |
+
return "fallback"
|
| 101 |
+
|
| 102 |
+
# --- 5. Global Setup ---
|
| 103 |
+
load_dotenv()
|
| 104 |
+
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
|
| 105 |
+
groq_api_key = os.getenv("GROQ_API_KEY")
|
| 106 |
+
|
| 107 |
+
# --- Model Configuration ---
|
| 108 |
+
LLM_MODEL_NAME = "openai/gpt-oss-20b"
|
| 109 |
+
|
| 110 |
+
llm = ChatGroq(model_name=LLM_MODEL_NAME, groq_api_key=groq_api_key,max_tokens=4096)
|
| 111 |
+
|
| 112 |
+
# --- 6. Build the Graph ---
|
| 113 |
+
graph = StateGraph(ChatState)
|
| 114 |
+
graph.add_node("chat_node", lambda state: chat_node(state, llm))
|
| 115 |
+
graph.add_node("safety_check_node", lambda state: safety_check_node(state, llm))
|
| 116 |
+
graph.add_node("handle_fallback_node", handle_fallback_node)
|
| 117 |
+
|
| 118 |
+
graph.set_entry_point("chat_node")
|
| 119 |
+
graph.add_edge("chat_node", "safety_check_node")
|
| 120 |
+
graph.add_edge("handle_fallback_node", END)
|
| 121 |
+
|
| 122 |
+
# Add the conditional edge for the retry loop
|
| 123 |
+
graph.add_conditional_edges(
|
| 124 |
+
"safety_check_node",
|
| 125 |
+
route_after_safety_check,
|
| 126 |
+
{
|
| 127 |
+
"retry": "chat_node",
|
| 128 |
+
"fallback": "handle_fallback_node",
|
| 129 |
+
"end": END
|
| 130 |
+
}
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Compile the graph
|
| 134 |
+
quick_help_app = graph.compile()
|
src/streamlit_app.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import uuid
|
| 3 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 4 |
+
|
| 5 |
+
# --- Backend Imports ---
|
| 6 |
+
# Ensure these files are in the same directory as your Streamlit app
|
| 7 |
+
from main_back import app
|
| 8 |
+
from quick_help import quick_help_app
|
| 9 |
+
|
| 10 |
+
# --- Page Configuration ---
|
| 11 |
+
st.set_page_config(page_title="Gen AI Mental Health Assistant", layout="wide")
|
| 12 |
+
|
| 13 |
+
st.title("Mental Health AI Assistant")
|
| 14 |
+
st.markdown("A safe space to find support and track your well-being.")
|
| 15 |
+
|
| 16 |
+
# --- Questionnaire Data ---
|
| 17 |
+
# Based on the DSM-5 PDF provided earlier
|
| 18 |
+
questionnaire_data = [
|
| 19 |
+
{"id": "1", "text": "Little interest or pleasure in doing things?"},
|
| 20 |
+
{"id": "2", "text": "Feeling down, depressed, or hopeless?"},
|
| 21 |
+
{"id": "3", "text": "Feeling more irritated, grouchy, or angry than usual?"},
|
| 22 |
+
{"id": "4", "text": "Sleeping less than usual, but still have a lot of energy?"},
|
| 23 |
+
{"id": "5", "text": "Starting lots more projects than usual or doing more risky things than usual?"},
|
| 24 |
+
{"id": "6", "text": "Feeling nervous, anxious, frightened, worried, or on edge?"},
|
| 25 |
+
{"id": "7", "text": "Feeling panic or being frightened?"},
|
| 26 |
+
{"id": "8", "text": "Avoiding situations that make you anxious?"},
|
| 27 |
+
{"id": "9", "text": "Unexplained aches and pains (e.g., head, back, joints, abdomen, legs)?"},
|
| 28 |
+
{"id": "10", "text": "Feeling that your illnesses are not being taken seriously enough?"},
|
| 29 |
+
{"id": "11", "text": "Thoughts of actually hurting yourself?"},
|
| 30 |
+
{"id": "12", "text": "Hearing things other people couldn't hear, such as voices even when no one was around?"},
|
| 31 |
+
{"id": "13", "text": "Feeling that someone could hear your thoughts, or that you could hear what another person was thinking?"},
|
| 32 |
+
{"id": "14", "text": "Problems with sleep that affected your sleep quality over all?"},
|
| 33 |
+
{"id": "15", "text": "Problems with memory (e.g., learning new information) or with location (e.g., finding your way home)?"},
|
| 34 |
+
{"id": "16", "text": "Unpleasant thoughts, urges, or images that repeatedly enter your mind?"},
|
| 35 |
+
{"id": "17", "text": "Feeling driven to perform certain behaviors or mental acts over and over again?"},
|
| 36 |
+
{"id": "18", "text": "Feeling detached or distant from yourself, your body, your physical surroundings, or your memories?"},
|
| 37 |
+
{"id": "19", "text": "Not knowing who you really are or what you want out of life?"},
|
| 38 |
+
{"id": "20", "text": "Not feeling close to other people or enjoying your relationships with them?"},
|
| 39 |
+
{"id": "21", "text": "Drinking at least 4 drinks of any kind of alcohol in a single day?"},
|
| 40 |
+
{"id": "22", "text": "Smoking any cigarettes, a cigar, or pipe, or using snuff or chewing tobacco?"},
|
| 41 |
+
{"id": "23", "text": "Using any medicines ON YOUR OWN, that is, without a doctor's prescription, in greater amounts or longer than prescribed?"},
|
| 42 |
+
]
|
| 43 |
+
response_options = ["None (0)", "Slight (1)", "Mild (2)", "Moderate (3)", "Severe (4)"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --- App State Initialization ---
|
| 47 |
+
def initialize_session_state():
|
| 48 |
+
# General app mode
|
| 49 |
+
if 'app_mode' not in st.session_state:
|
| 50 |
+
st.session_state.app_mode = "Quick Help"
|
| 51 |
+
|
| 52 |
+
# Quick Help state
|
| 53 |
+
if 'quick_help_history' not in st.session_state:
|
| 54 |
+
st.session_state.quick_help_history = []
|
| 55 |
+
|
| 56 |
+
# Tracking Health state
|
| 57 |
+
if 'tracking_stage' not in st.session_state:
|
| 58 |
+
st.session_state.tracking_stage = "questionnaire" # questionnaire -> chat
|
| 59 |
+
if 'tracking_history' not in st.session_state:
|
| 60 |
+
st.session_state.tracking_history = []
|
| 61 |
+
if 'tracking_thread_id' not in st.session_state:
|
| 62 |
+
st.session_state.tracking_thread_id = None
|
| 63 |
+
|
| 64 |
+
initialize_session_state()
|
| 65 |
+
|
| 66 |
+
# --- Sidebar for Navigation ---
|
| 67 |
+
with st.sidebar:
|
| 68 |
+
st.header("Navigation")
|
| 69 |
+
st.session_state.app_mode = st.radio(
|
| 70 |
+
"Choose a feature:",
|
| 71 |
+
("Quick Help", "Track Your Health"),
|
| 72 |
+
key="app_mode_selector"
|
| 73 |
+
)
|
| 74 |
+
st.info("Your conversations are private. We do not store personally identifiable information.")
|
| 75 |
+
|
| 76 |
+
# --- Main App Logic ---
|
| 77 |
+
|
| 78 |
+
# --- Quick Help Feature ---
|
| 79 |
+
if st.session_state.app_mode == "Quick Help":
|
| 80 |
+
st.header("Quick Help Chat")
|
| 81 |
+
st.markdown("Get immediate, supportive advice. How are you feeling right now?")
|
| 82 |
+
|
| 83 |
+
# Display chat history
|
| 84 |
+
for message in st.session_state.quick_help_history:
|
| 85 |
+
with st.chat_message(message["role"]):
|
| 86 |
+
st.markdown(message["content"])
|
| 87 |
+
|
| 88 |
+
# Handle user input
|
| 89 |
+
if user_input := st.chat_input("Share your thoughts..."):
|
| 90 |
+
st.session_state.quick_help_history.append({"role": "user", "content": user_input})
|
| 91 |
+
with st.chat_message("user"):
|
| 92 |
+
st.markdown(user_input)
|
| 93 |
+
|
| 94 |
+
with st.chat_message("assistant"):
|
| 95 |
+
# The input for the non-persistent app is the full history each time
|
| 96 |
+
history_for_input = [
|
| 97 |
+
HumanMessage(content=msg["content"]) if msg["role"] == "user" else AIMessage(content=msg["content"])
|
| 98 |
+
for msg in st.session_state.quick_help_history
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
ai_response = quick_help_app.invoke({'messages': history_for_input}).get("messages",["Could not generate response"])[-1].content
|
| 102 |
+
st.write(ai_response)
|
| 103 |
+
|
| 104 |
+
st.session_state.quick_help_history.append({"role": "assistant", "content": ai_response})
|
| 105 |
+
|
| 106 |
+
# --- Track Your Health Feature ---
|
| 107 |
+
elif st.session_state.app_mode == "Track Your Health":
|
| 108 |
+
|
| 109 |
+
# --- Stage 1: Questionnaire ---
|
| 110 |
+
if st.session_state.tracking_stage == "questionnaire":
|
| 111 |
+
st.header("Health & Well-being Questionnaire")
|
| 112 |
+
st.markdown("Please answer the following questions based on your feelings over the **last two weeks**.")
|
| 113 |
+
|
| 114 |
+
with st.form("health_questionnaire"):
|
| 115 |
+
responses = {}
|
| 116 |
+
for q in questionnaire_data:
|
| 117 |
+
# Get the integer value from the selection
|
| 118 |
+
response_str = st.radio(q["text"], options=response_options, key=q["id"], horizontal=True)
|
| 119 |
+
responses[q["id"]] = response_options.index(response_str)
|
| 120 |
+
|
| 121 |
+
submitted = st.form_submit_button("Analyze & Create My Plan")
|
| 122 |
+
|
| 123 |
+
if submitted:
|
| 124 |
+
# Generate a unique thread ID for this user's persistent session
|
| 125 |
+
st.session_state.tracking_thread_id = f"user_{uuid.uuid4()}"
|
| 126 |
+
config = {"configurable": {"thread_id": st.session_state.tracking_thread_id}}
|
| 127 |
+
|
| 128 |
+
# Call the main backend app with the questionnaire responses
|
| 129 |
+
initial_input = {"questionnaire_responses": responses}
|
| 130 |
+
|
| 131 |
+
with st.spinner("Analyzing your responses and generating a personalized plan..."):
|
| 132 |
+
# Use .invoke() for the first call as we want the full plan at once
|
| 133 |
+
result = app.invoke(initial_input, config=config)
|
| 134 |
+
initial_plan = result.get('messages', ["Could not generate a plan."])[-1].content
|
| 135 |
+
|
| 136 |
+
# Store the initial plan and switch to chat mode
|
| 137 |
+
st.session_state.tracking_history = [{"role": "assistant", "content": initial_plan}]
|
| 138 |
+
st.session_state.tracking_stage = "chat"
|
| 139 |
+
st.rerun()
|
| 140 |
+
|
| 141 |
+
# --- Stage 2: Chat with the Plan ---
|
| 142 |
+
elif st.session_state.tracking_stage == "chat":
|
| 143 |
+
st.header("Your Personalized Plan & Chat")
|
| 144 |
+
st.markdown("Here is an initial plan based on your responses. You can ask questions about it or request different exercises.")
|
| 145 |
+
|
| 146 |
+
config = {"configurable": {"thread_id": st.session_state.tracking_thread_id}}
|
| 147 |
+
|
| 148 |
+
# Display chat history
|
| 149 |
+
for message in st.session_state.tracking_history:
|
| 150 |
+
with st.chat_message(message["role"]):
|
| 151 |
+
st.markdown(message["content"])
|
| 152 |
+
|
| 153 |
+
# Handle user input for follow-up questions
|
| 154 |
+
if user_input := st.chat_input("Ask a question about your plan..."):
|
| 155 |
+
st.session_state.tracking_history.append({"role": "user", "content": user_input})
|
| 156 |
+
with st.chat_message("user"):
|
| 157 |
+
st.markdown(user_input)
|
| 158 |
+
|
| 159 |
+
with st.chat_message("assistant"):
|
| 160 |
+
# For follow-ups, we only need to send the new message.
|
| 161 |
+
# The checkpointer on the backend handles loading the history.
|
| 162 |
+
follow_up_input = {"messages": [HumanMessage(content=user_input)]}
|
| 163 |
+
result = app.invoke(follow_up_input, config=config)
|
| 164 |
+
ai_response = result.get('messages', ["Could not generate a plan."])[-1].content
|
| 165 |
+
|
| 166 |
+
st.session_state.tracking_history.append({"role": "assistant", "content": ai_response})
|