harsh122e2wr commited on
Commit
fca6e17
·
verified ·
1 Parent(s): ec54f60

Upload 3 files

Browse files
Files changed (3) hide show
  1. src/main_back.py +220 -0
  2. src/quick_help.py +134 -0
  3. src/streamlit_app.py +166 -0
src/main_back.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, TypedDict, Annotated
3
+ from dotenv import load_dotenv
4
+
5
+ # LangChain and LangGraph imports
6
+ from langchain_chroma import Chroma
7
+ from langchain_community.document_loaders import PyPDFLoader
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ from langchain_groq import ChatGroq
10
+ from langchain_huggingface import HuggingFaceEmbeddings
11
+ from langchain_core.prompts import ChatPromptTemplate
12
+ from langchain_core.output_parsers import StrOutputParser
13
+ from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
14
+
15
+ from langgraph.graph import StateGraph, END, START
16
+ from langgraph.graph.message import add_messages
17
+ import sqlite3
18
+ from langgraph.checkpoint.sqlite import SqliteSaver
19
+ from pydantic import BaseModel, Field
20
+
21
# --- Graph State Definition ---
class GraphState(TypedDict):
    """Shared state flowing through the LangGraph pipeline."""
    # Raw questionnaire answers keyed by question id ("1".."23"), values 0-4.
    questionnaire_responses: Dict[str, int]
    # Per-domain severity (max of that domain's question scores), built by `questionnaire`.
    domain_scores: Dict[str, int]
    # NOTE(review): never written by any node in this file — possibly vestigial; confirm.
    primary_concern: str
    # Conversation history; add_messages APPENDS returned messages rather than replacing.
    messages: Annotated[list[AnyMessage], add_messages]
    # Verdict of the most recent safety classification.
    is_safe: bool
    # Number of generations attempted for the current turn (reset on finalize/fallback).
    retry_count: int
29
+
30
+
31
# --- RAG Retriever Helper Function ---
def create_persistent_rag_retriever(pdf_paths: List[str], db_name: str, embedding_model):
    """Creates or loads a persistent RAG retriever from one or more PDF documents.

    Args:
        pdf_paths: PDF files to index; missing paths are skipped with a warning.
        db_name: Sub-directory under ./chroma_db holding the persisted Chroma store.
        embedding_model: Embedding function handed to Chroma.

    Returns:
        A retriever (top-3 results) over the new or existing vector store, or
        None when no documents could be processed.
    """
    persist_directory = f"./chroma_db/{db_name}"
    if os.path.exists(persist_directory):
        print(f"--- Loading existing persistent DB: {db_name} ---")
        # CONSISTENCY FIX: use the same k=3 search depth as the creation path
        # below (previously the load path fell back to Chroma's default k).
        return Chroma(persist_directory=persist_directory, embedding_function=embedding_model).as_retriever(search_kwargs={'k': 3})

    print(f"--- Creating new persistent DB: {db_name} ---")
    vector_store = None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    for pdf_path in pdf_paths:
        if not os.path.exists(pdf_path):
            print(f"Warning: PDF not found at '{pdf_path}'. Skipping.")
            continue

        print(f"--- Processing PDF: {pdf_path} ---")
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        splits = text_splitter.split_documents(documents)

        if not splits: continue

        # The first batch creates (and persists) the store; later batches append.
        if vector_store is None:
            vector_store = Chroma.from_documents(documents=splits, embedding=embedding_model, persist_directory=persist_directory)
        else:
            vector_store.add_documents(splits)

    if vector_store:
        print(f"--- DB creation complete for {db_name} ---")
        return vector_store.as_retriever(search_kwargs={'k': 3})
    else:
        print(f"--- Could not create DB for {db_name}. No documents processed. ---")
        return None
66
+
67
# --- Node Definitions ---
def questionnaire(state: GraphState) -> GraphState:
    """Score every DSM-5 domain from the questionnaire answers and seed the chat.

    A domain's score is the maximum of its question responses (each 0-4).
    Also resets the retry counter and injects the opening HumanMessage.
    """
    responses = state.get("questionnaire_responses", {})
    # Question ids (strings, matching the Streamlit form keys) feeding each domain.
    domain_questions = {
        "Depression": ("1", "2"),
        "Anger": ("3",),
        "Mania": ("4", "5"),
        "Anxiety": ("6", "7", "8"),
        "Somatic_Symptoms": ("9", "10"),
        "Suicidal_Ideation": ("11",),
        "Psychosis": ("12", "13"),
        "Sleep_Problems": ("14",),
        "Memory": ("15",),
        "Repetitive_Thoughts_Behaviors": ("16", "17"),
        "Dissociation": ("18",),
        "Personality_Functioning": ("19", "20"),
        "Substance_Use": ("21", "22", "23"),
    }
    domain_scores = {
        domain: max(responses.get(qid, 0) for qid in question_ids)
        for domain, question_ids in domain_questions.items()
    }
    initial_question = "User has completed the initial questionnaire.Provide supportive steps and coping mechanisms."
    initial_message = HumanMessage(content=initial_question)

    return {"domain_scores": domain_scores, "retry_count": 0, "messages": [initial_message]}
90
+
91
def route_entry(state: GraphState) -> str:
    """Route the entry node.

    With no domain scores yet, go to the questionnaire; otherwise pick the
    first elevated domain (score >= 2, depression checked before anxiety),
    falling back to no action.
    """
    scores = state.get("domain_scores")
    if not scores:
        return "questionnaire"
    if scores.get("Depression", 0) >= 2:
        return "depression"
    if scores.get("Anxiety", 0) >= 2:
        return "anxiety"
    return "no_action"
102
+
103
def handle_depression_rag(state: GraphState) -> GraphState:
    """Handles the conversational RAG pipeline for depression."""
    # Severity 0 (None) .. 4 (Severe), computed by the questionnaire node.
    score = state.get("domain_scores", {}).get("Depression", 0)
    messages = state.get("messages", [])
    retry_count = state.get("retry_count", 0)

    retry_guidance = "Please provide a helpful and supportive plan."
    # On a retry after a failed safety check, tell the model its last answer was rejected.
    if retry_count > 0: retry_guidance = "Your previous response was flagged. Please try again."

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
        - **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a depression score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])

    rag_chain = prompt | llm | StrOutputParser()
    # NOTE(review): the entire message history is interpolated as {question} —
    # presumably intentional so the model sees prior turns; confirm.
    response = rag_chain.invoke({ "score": score, "question": messages,"retry_guidance":retry_guidance})
    # retry_count is incremented on every generation; route_after_safety_check
    # allows at most two generations before falling back.
    return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}
121
+
122
def handle_anxiety_rag(state: GraphState) -> GraphState:
    """Handles the conversational RAG pipeline for anxiety.

    Builds a supportive prompt from the user's Anxiety severity and the
    conversation so far, generates a reply, and bumps the retry counter used
    by the safety-check loop.
    """
    # BUG FIX: this handler previously read the "Depression" score, so the
    # anxiety prompt was filled with the wrong domain's severity.
    score = state.get("domain_scores", {}).get("Anxiety", 0)
    messages = state.get("messages", [])
    retry_count = state.get("retry_count", 0)

    retry_guidance = "Please provide a helpful and supportive plan for someone feeling anxious or worried."
    # CONSISTENCY: mirror handle_depression_rag — tell the model when it is
    # regenerating after a failed safety check.
    if retry_count > 0: retry_guidance = "Your previous response was flagged. Please try again."

    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a kind and empathetic AI assistant. Your role is to answer the user's question in a supportive, conversational tone.
        - **Do not** just summarize the documents. Synthesize the information and answer the user's question directly."""),
        ("human", "The user has a anxiety score of {score} on a scale of 0 (None) to 4 (Severe).My question is: {question}. {retry_guidance}")
    ])

    rag_chain = prompt | llm | StrOutputParser()
    # The whole message history is interpolated as {question} so prior turns are visible.
    response = rag_chain.invoke({"score": score, "question": messages,"retry_guidance":retry_guidance})
    return {"messages": [AIMessage(content=response)], "retry_count": retry_count + 1}
139
+
140
class SafetyCheck(BaseModel):
    """Pydantic model for the safety check response."""
    # Structured-output target for the classifier LLM call in safety_check();
    # the Field description is part of the schema the model is shown.
    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")
143
+
144
def safety_check(state: GraphState) -> GraphState:
    """Classify the latest AI response as safe/unsafe via a structured LLM call.

    Returns an {"is_safe": bool} state update consumed by
    route_after_safety_check to finalize, retry, or fall back.
    """
    # BUG FIX: this previously took the FIRST AIMessage in the accumulated
    # history, so on follow-up turns the oldest reply was re-checked instead
    # of the newly generated one. Scan from the end to get the latest
    # AI response ('' if none exists yet).
    ai_response = next(
        (m.content for m in reversed(state.get("messages", [])) if isinstance(m, AIMessage)), ''
    )
    # Bind the SafetyCheck schema so the model returns a parsed object.
    structured_llm = llm.with_structured_output(SafetyCheck)

    prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context.
        Analyze the following text for any harmful, unsafe, or non-supportive content.
        Then, provide a JSON object with a single boolean field 'is_safe'.

        Text to evaluate:
        '{ai_response}'"""
    )

    safety_chain = prompt | structured_llm
    result = safety_chain.invoke({"ai_response": ai_response})

    return {"is_safe": result.is_safe}
164
+
165
def handle_fallback(state: GraphState) -> GraphState:
    """Emit a canned safe reply after repeated safety-check failures and reset the retry counter."""
    text = "I am having trouble generating a specific plan right now. Please consider seeking support from a qualified professional."
    return {"messages": [AIMessage(content=text)], "retry_count": 0}
169
+
170
def finalize_response(state: GraphState) -> GraphState:
    """Close out the turn: the last AIMessage already holds the safe reply,
    so the only state change needed is resetting the retry counter."""
    return {"retry_count": 0}
173
+
174
def route_after_safety_check(state: "GraphState") -> str:
    """Pick the next node after a safety check.

    Safe responses are finalized; unsafe ones are regenerated for at most two
    total generations, after which the canned fallback is used.
    """
    if state.get("is_safe"):
        return "finalize"
    attempts = state.get("retry_count", 0)
    return "retry" if attempts < 2 else "fallback"
179
+
180
def entry_point(state: GraphState) -> GraphState:
    """Pass-through node: the graph needs a concrete entry vertex for its
    conditional routing, but no state mutation happens here."""
    return state
183
+
184
# --- Model, graph wiring, and persistence (module level) ---
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile",max_tokens=4096)

graph = StateGraph(GraphState)

graph.add_node("entry_point",entry_point)
graph.add_node("questionnaire",questionnaire)
graph.add_node("handle_depression_rag", handle_depression_rag)
graph.add_node("handle_anxiety_rag", handle_anxiety_rag)
# The same safety_check function backs two nodes so each RAG branch loops independently.
graph.add_node("safety_check_depression", safety_check)
graph.add_node("safety_check_anxiety", safety_check)
graph.add_node("handle_fallback", handle_fallback)
graph.add_node("finalize_response", finalize_response)

graph.add_edge(START,"entry_point")
# First pass (no scores yet) routes to "questionnaire", which loops back to
# the entry point and is then routed to a RAG handler or END.
graph.add_conditional_edges("entry_point",route_entry,{"depression":"handle_depression_rag","anxiety":"handle_anxiety_rag","no_action":END,"questionnaire":"questionnaire"})
graph.add_edge("questionnaire","entry_point")

graph.add_edge("handle_depression_rag", "safety_check_depression")
graph.add_conditional_edges(
    "safety_check_depression", route_after_safety_check,
    {"retry": "handle_depression_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
)

graph.add_edge("handle_anxiety_rag", "safety_check_anxiety")
graph.add_conditional_edges(
    "safety_check_anxiety", route_after_safety_check,
    {"retry": "handle_anxiety_rag", "fallback": "handle_fallback", "finalize": "finalize_response"}
)

graph.add_edge("handle_fallback", END)
graph.add_edge("finalize_response", END)

# SQLite-backed checkpointing so follow-up turns on the same thread_id share history.
DB_PATH = "/app/data/chatbot.sqlite"
# ROBUSTNESS FIX: sqlite3.connect cannot create missing parent directories and
# fails with "unable to open database file"; ensure the directory exists first.
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
app = graph.compile(checkpointer=checkpointer)
src/quick_help.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Annotated, TypedDict, List, Dict
3
+ from dotenv import load_dotenv
4
+
5
+ # LangChain and LangGraph imports
6
+ from langchain_groq import ChatGroq
7
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ from pydantic import BaseModel, Field
11
+ from langgraph.graph import StateGraph, END
12
+ from langgraph.graph.message import add_messages
13
+
14
# --- 1. Pydantic Model for Structured Output ---
class SafetyCheck(BaseModel):
    """Pydantic model for the safety check response."""
    # Single boolean verdict produced by the structured-output classifier call
    # in safety_check_node; the Field description is part of the schema shown
    # to the model.
    is_safe: bool = Field(description="True if the text is safe, False if it contains harmful content.")
18
+
19
# --- 2. Define the State for the Chatbot ---
class ChatState(TypedDict):
    """State of the chatbot, continuously appending messages."""
    # Full conversation; add_messages merges node updates by appending.
    messages: Annotated[list[AnyMessage], add_messages]
    # Verdict from the latest safety_check_node run.
    is_safe: bool
    # Generations attempted for the current user turn.
    retry_count: int
25
+
26
# --- 3. Define the Chatbot Nodes ---
def chat_node(state: ChatState, llm):
    """
    Generate the assistant's next reply from the running conversation.

    Args:
        state: Current chat state; `messages` holds the full history.
        llm: Chat model used to produce the response.

    Returns:
        State update appending one AIMessage and incrementing retry_count.
    """
    messages = state['messages']
    retry_count = state.get('retry_count', 0)

    # Create a prompt template with the system message
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a kind and empathetic AI assistant for mental well-being. "
                   "Provide supportive, safe, and general advice. Do not provide medical diagnoses. "
                   "Keep your responses concise and helpful."),
        ("placeholder", "{chat_history}")
    ])

    # Add a warning to the LLM if this is a retry attempt.
    # BUG FIX: the warning used to be append()ed onto state['messages'] in
    # place, mutating shared graph state so each retry stacked yet another
    # SystemMessage into the permanent history. Build a fresh list instead.
    if retry_count > 0:
        messages = messages + [SystemMessage(
            content="Your previous response was not suitable. Please try again, "
                    "focusing on being supportive and safe."
        )]

    chain = prompt | llm | StrOutputParser()

    # Invoke the LLM with the conversation history
    response_text = chain.invoke({"chat_history": messages})

    return {"messages": [AIMessage(content=response_text)], "retry_count": retry_count + 1}
55
+
56
def safety_check_node(state: ChatState, llm) -> Dict:
    """
    Evaluates the safety of the generated response using a structured LLM.

    Args:
        state: Chat state; the last entry of `messages` is assumed to be the
               AI reply just produced by chat_node (IndexError if history is empty).
        llm: Chat model used for the classification call.

    Returns:
        {"is_safe": bool} state update consumed by route_after_safety_check.
    """
    # Get the most recent AI message to check
    ai_message = state['messages'][-1]

    # Create a dedicated structured LLM for this task
    structured_llm = llm.with_structured_output(SafetyCheck)

    prompt = ChatPromptTemplate.from_template(
        """You are a safety classification model. The user is in a mental health context.
        Analyze the following text for any harmful, unsafe, or non-supportive content.
        Then, provide a JSON object with a single boolean field 'is_safe'.

        Text to evaluate:
        '{text_to_evaluate}'"""
    )

    safety_chain = prompt | structured_llm
    result = safety_chain.invoke({"text_to_evaluate": ai_message.content})

    return {"is_safe": result.is_safe}
80
+
81
def handle_fallback_node(state: ChatState) -> Dict:
    """Return a canned, safe assistant reply once all retries are exhausted."""
    text = (
        "I am having a little trouble formulating a response right now. "
        "Remember that taking a moment to focus on your breath can be a helpful step."
    )
    return {"messages": [AIMessage(content=text)]}
90
+
91
# --- 4. Define the Conditional Router ---
def route_after_safety_check(state: ChatState) -> str:
    """Decide the next hop after the safety check: finish when safe, retry up
    to four generations, then fall back to the canned response."""
    if state.get("is_safe"):
        return "end"
    return "retry" if state.get("retry_count", 0) < 4 else "fallback"
101
+
102
# --- 5. Global Setup ---
load_dotenv()
# BUG FIX: os.environ values must be strings; assigning None (when HF_TOKEN is
# unset) raised TypeError at import time. Only export the token when present.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
groq_api_key = os.getenv("GROQ_API_KEY")

# --- Model Configuration ---
LLM_MODEL_NAME = "openai/gpt-oss-20b"

llm = ChatGroq(model_name=LLM_MODEL_NAME, groq_api_key=groq_api_key,max_tokens=4096)

# --- 6. Build the Graph ---
graph = StateGraph(ChatState)
# Nodes take (state, llm); bind the module-level llm via lambdas.
graph.add_node("chat_node", lambda state: chat_node(state, llm))
graph.add_node("safety_check_node", lambda state: safety_check_node(state, llm))
graph.add_node("handle_fallback_node", handle_fallback_node)

graph.set_entry_point("chat_node")
graph.add_edge("chat_node", "safety_check_node")
graph.add_edge("handle_fallback_node", END)

# Add the conditional edge for the retry loop
graph.add_conditional_edges(
    "safety_check_node",
    route_after_safety_check,
    {
        "retry": "chat_node",
        "fallback": "handle_fallback_node",
        "end": END
    }
)

# Compile the graph (no checkpointer: Quick Help is stateless between calls)
quick_help_app = graph.compile()
src/streamlit_app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import uuid
3
+ from langchain_core.messages import HumanMessage, AIMessage
4
+
5
+ # --- Backend Imports ---
6
+ # Ensure these files are in the same directory as your Streamlit app
7
+ from main_back import app
8
+ from quick_help import quick_help_app
9
+
10
# --- Page Configuration ---
st.set_page_config(page_title="Gen AI Mental Health Assistant", layout="wide")

st.title("Mental Health AI Assistant")
st.markdown("A safe space to find support and track your well-being.")

# --- Questionnaire Data ---
# Based on the DSM-5 PDF provided earlier
# Each entry: "id" (string, used as the form widget key and matched against the
# domain mapping in main_back.questionnaire) and the question text shown to the user.
questionnaire_data = [
    {"id": "1", "text": "Little interest or pleasure in doing things?"},
    {"id": "2", "text": "Feeling down, depressed, or hopeless?"},
    {"id": "3", "text": "Feeling more irritated, grouchy, or angry than usual?"},
    {"id": "4", "text": "Sleeping less than usual, but still have a lot of energy?"},
    {"id": "5", "text": "Starting lots more projects than usual or doing more risky things than usual?"},
    {"id": "6", "text": "Feeling nervous, anxious, frightened, worried, or on edge?"},
    {"id": "7", "text": "Feeling panic or being frightened?"},
    {"id": "8", "text": "Avoiding situations that make you anxious?"},
    {"id": "9", "text": "Unexplained aches and pains (e.g., head, back, joints, abdomen, legs)?"},
    {"id": "10", "text": "Feeling that your illnesses are not being taken seriously enough?"},
    {"id": "11", "text": "Thoughts of actually hurting yourself?"},
    {"id": "12", "text": "Hearing things other people couldn't hear, such as voices even when no one was around?"},
    {"id": "13", "text": "Feeling that someone could hear your thoughts, or that you could hear what another person was thinking?"},
    {"id": "14", "text": "Problems with sleep that affected your sleep quality over all?"},
    {"id": "15", "text": "Problems with memory (e.g., learning new information) or with location (e.g., finding your way home)?"},
    {"id": "16", "text": "Unpleasant thoughts, urges, or images that repeatedly enter your mind?"},
    {"id": "17", "text": "Feeling driven to perform certain behaviors or mental acts over and over again?"},
    {"id": "18", "text": "Feeling detached or distant from yourself, your body, your physical surroundings, or your memories?"},
    {"id": "19", "text": "Not knowing who you really are or what you want out of life?"},
    {"id": "20", "text": "Not feeling close to other people or enjoying your relationships with them?"},
    {"id": "21", "text": "Drinking at least 4 drinks of any kind of alcohol in a single day?"},
    {"id": "22", "text": "Smoking any cigarettes, a cigar, or pipe, or using snuff or chewing tobacco?"},
    {"id": "23", "text": "Using any medicines ON YOUR OWN, that is, without a doctor's prescription, in greater amounts or longer than prescribed?"},
]
# Radio labels; the selected label's index (0-4) becomes the numeric severity score.
response_options = ["None (0)", "Slight (1)", "Mild (2)", "Moderate (3)", "Severe (4)"]
44
+
45
+
46
# --- App State Initialization ---
def initialize_session_state():
    """Seed st.session_state with defaults for both features.

    Idempotent: existing keys are never overwritten, so reruns keep state.
    """
    defaults = {
        "app_mode": "Quick Help",           # general app mode
        "quick_help_history": [],           # Quick Help chat transcript
        "tracking_stage": "questionnaire",  # questionnaire -> chat
        "tracking_history": [],             # Track Your Health transcript
        "tracking_thread_id": None,         # backend checkpointer thread id
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

initialize_session_state()
65
+
66
# --- Sidebar for Navigation ---
with st.sidebar:
    st.header("Navigation")
    # The radio widget also stores its value under key "app_mode_selector";
    # the explicit assignment keeps st.session_state.app_mode in sync with it.
    st.session_state.app_mode = st.radio(
        "Choose a feature:",
        ("Quick Help", "Track Your Health"),
        key="app_mode_selector"
    )
    st.info("Your conversations are private. We do not store personally identifiable information.")
75
+
76
# --- Main App Logic ---

# --- Quick Help Feature ---
if st.session_state.app_mode == "Quick Help":
    st.header("Quick Help Chat")
    st.markdown("Get immediate, supportive advice. How are you feeling right now?")

    # Display chat history
    for message in st.session_state.quick_help_history:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Handle user input
    if user_input := st.chat_input("Share your thoughts..."):
        st.session_state.quick_help_history.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            # The input for the non-persistent app is the full history each time
            # (quick_help_app compiles without a checkpointer, so it keeps no memory).
            history_for_input = [
                HumanMessage(content=msg["content"]) if msg["role"] == "user" else AIMessage(content=msg["content"])
                for msg in st.session_state.quick_help_history
            ]

            # NOTE(review): if "messages" were missing from the result, [-1].content
            # would fail on the str default — assumed unreachable because the graph
            # always appends an AIMessage; confirm.
            ai_response = quick_help_app.invoke({'messages': history_for_input}).get("messages",["Could not generate response"])[-1].content
            st.write(ai_response)

        st.session_state.quick_help_history.append({"role": "assistant", "content": ai_response})
105
+
106
+ # --- Track Your Health Feature ---
107
+ elif st.session_state.app_mode == "Track Your Health":
108
+
109
+ # --- Stage 1: Questionnaire ---
110
+ if st.session_state.tracking_stage == "questionnaire":
111
+ st.header("Health & Well-being Questionnaire")
112
+ st.markdown("Please answer the following questions based on your feelings over the **last two weeks**.")
113
+
114
+ with st.form("health_questionnaire"):
115
+ responses = {}
116
+ for q in questionnaire_data:
117
+ # Get the integer value from the selection
118
+ response_str = st.radio(q["text"], options=response_options, key=q["id"], horizontal=True)
119
+ responses[q["id"]] = response_options.index(response_str)
120
+
121
+ submitted = st.form_submit_button("Analyze & Create My Plan")
122
+
123
+ if submitted:
124
+ # Generate a unique thread ID for this user's persistent session
125
+ st.session_state.tracking_thread_id = f"user_{uuid.uuid4()}"
126
+ config = {"configurable": {"thread_id": st.session_state.tracking_thread_id}}
127
+
128
+ # Call the main backend app with the questionnaire responses
129
+ initial_input = {"questionnaire_responses": responses}
130
+
131
+ with st.spinner("Analyzing your responses and generating a personalized plan..."):
132
+ # Use .invoke() for the first call as we want the full plan at once
133
+ result = app.invoke(initial_input, config=config)
134
+ initial_plan = result.get('messages', ["Could not generate a plan."])[-1].content
135
+
136
+ # Store the initial plan and switch to chat mode
137
+ st.session_state.tracking_history = [{"role": "assistant", "content": initial_plan}]
138
+ st.session_state.tracking_stage = "chat"
139
+ st.rerun()
140
+
141
+ # --- Stage 2: Chat with the Plan ---
142
+ elif st.session_state.tracking_stage == "chat":
143
+ st.header("Your Personalized Plan & Chat")
144
+ st.markdown("Here is an initial plan based on your responses. You can ask questions about it or request different exercises.")
145
+
146
+ config = {"configurable": {"thread_id": st.session_state.tracking_thread_id}}
147
+
148
+ # Display chat history
149
+ for message in st.session_state.tracking_history:
150
+ with st.chat_message(message["role"]):
151
+ st.markdown(message["content"])
152
+
153
+ # Handle user input for follow-up questions
154
+ if user_input := st.chat_input("Ask a question about your plan..."):
155
+ st.session_state.tracking_history.append({"role": "user", "content": user_input})
156
+ with st.chat_message("user"):
157
+ st.markdown(user_input)
158
+
159
+ with st.chat_message("assistant"):
160
+ # For follow-ups, we only need to send the new message.
161
+ # The checkpointer on the backend handles loading the history.
162
+ follow_up_input = {"messages": [HumanMessage(content=user_input)]}
163
+ result = app.invoke(follow_up_input, config=config)
164
+ ai_response = result.get('messages', ["Could not generate a plan."])[-1].content
165
+
166
+ st.session_state.tracking_history.append({"role": "assistant", "content": ai_response})