avtak committed on
Commit
fd5498f
·
verified ·
1 Parent(s): 926e080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -134
app.py CHANGED
@@ -9,23 +9,19 @@ hf_token = os.getenv("HF_TOKEN")
9
  if hf_token:
10
  login(token=hf_token)
11
 
 
12
  print("Loading Mental-Longformer...")
13
  model_name = "avtak/erisk-longformer-depression-v1"
14
- classifier = pipeline(
15
- "text-classification",
16
- model=model_name,
17
- truncation=True,
18
- max_length=4096,
19
- top_k=None,
20
- )
21
 
22
- # --- 2. MCP TOOLS ---
23
 
24
  def get_crisis_resources(location: str = "Global") -> str:
 
25
  resources = {
26
- "US": "๐Ÿ‡บ๐Ÿ‡ธ US: Crisis Text Line: 741741 | Suicide Lifeline: 988",
27
- "Malaysia": "๐Ÿ‡ฒ๐Ÿ‡พ Malaysia: Befrienders KL: 03-76272929 | Talian Kasih: 15999",
28
- "Global": "๐ŸŒ International: befrienders.org",
29
  }
30
  for key in resources:
31
  if location and key.lower() in location.lower():
@@ -33,91 +29,91 @@ def get_crisis_resources(location: str = "Global") -> str:
33
  return resources["Global"]
34
 
35
  def detect_depression_risk(text: str) -> dict:
36
- processed_text = text.replace("\n", "\n\n")
 
 
 
 
 
 
37
  results = classifier(processed_text)[0]
38
- prob = next((r["score"] for r in results if r["label"] == "LABEL_1"), 0.0)
39
-
40
- if prob < 0.40:
 
41
  level = "Low Risk"
42
  biomarker = "Healthy External Focus"
43
  desc = "Matches 'Isolated Control' group. High lexical diversity, focus on hobbies/events."
44
- color = "#10b981"
45
- elif 0.40 <= prob < 0.60:
46
  level = "Moderate Risk"
47
  biomarker = "Echo Chamber Interaction"
48
- desc = "Matches 'Interactive Non-Depressed' group. Engaging in support forums but likely not clinically depressed."
49
- color = "#f59e0b"
50
- else:
51
  level = "High Risk"
52
  biomarker = "Nocturnal & High-Effort"
53
- desc = "Matches 'Depressed' cohort. Nocturnal posting, highโ€‘effort/lowโ€‘frequency posts."
54
- color = "#ef4444"
55
-
56
  return {
57
- "probability": prob,
58
- "risk_level": level,
59
  "biomarker": biomarker,
60
  "description": desc,
61
  "color": color,
62
- "word_count": len(processed_text.split()),
63
  }
64
 
65
- # --- 3. AGENT LOGIC (messages format only) ---
66
 
67
- def generate_response(history, risk_context, provider):
 
68
  if not risk_context:
69
- risk_context = {
70
- "risk_level": "Unknown",
71
- "probability": 0.0,
72
- "description": "No analysis run yet.",
73
- }
74
 
 
 
 
75
  if provider == "SambaNova":
76
- client = OpenAI(
77
- base_url="https://api.sambanova.ai/v1",
78
- api_key=os.getenv("SAMBANOVA_API_KEY"),
79
- )
80
  model_id = "Meta-Llama-3.3-70B-Instruct"
81
- else:
82
- client = OpenAI(
83
- base_url="https://api.tokenfactory.nebius.com/v1/",
84
- api_key=os.getenv("NEBIUS_API_KEY"),
85
- )
86
  model_id = "moonshotai/Kimi-K2-Thinking"
87
 
88
  system_prompt = f"""
89
- You are 'Dr. Longformer', a specialized Clinical AI Assistant based on Hassan's 2025 Thesis.
90
-
91
- CURRENT USER CONTEXT:
92
- - Analyzed Risk: {risk_context['risk_level']} ({risk_context['probability']:.1%})
93
- - Detected Pattern: {risk_context['description']}
94
-
95
- YOUR GOAL: Provide supportive, scientifically-grounded chat. Max 2 sentences. Do not diagnose.
96
- """.strip()
97
 
98
  messages = [{"role": "system", "content": system_prompt}]
99
- messages.extend(history) # history is already [{"role": ..., "content": ...}]
100
 
101
  try:
102
- resp = client.chat.completions.create(
103
- model=model_id,
104
- messages=messages,
105
- temperature=0.7,
106
- max_tokens=300,
107
  )
108
- return resp.choices[0].message.content
109
  except Exception as e:
110
- return f"โš ๏ธ Error with {provider}: {e}"
111
 
112
- # --- 4. UI ORCHESTRATION (Chatbot uses messages format) ---
113
 
114
  def run_analysis(text, location, provider):
115
- if not text.strip():
116
- return None, [], None
117
-
118
  data = detect_depression_risk(text)
119
  resources = get_crisis_resources(location)
120
-
 
121
  html_dashboard = f"""
122
  <div style="padding: 20px; border-radius: 12px; background-color: {data['color']}15; border: 1px solid {data['color']};">
123
  <div style="display: flex; justify-content: space-between; align-items: center;">
@@ -137,108 +133,135 @@ def run_analysis(text, location, provider):
137
  </div>
138
  </div>
139
  """
140
-
141
- # Initial assistant message
142
- history = [
143
- {"role": "assistant", "content": generate_response(
144
- [{"role": "user", "content": "I just ran the analysis. Please explain my results."}],
145
- data,
146
- provider
147
- )}
148
- ]
149
-
150
- # dashboard HTML, chatbot messages (list[dict]), risk_data
151
- return html_dashboard, history, data
152
-
153
- def user_chat(user_message, history, risk_data, provider):
154
- if not user_message:
155
- return history, ""
156
-
157
- history.append({"role": "user", "content": user_message})
158
- ai_msg = generate_response(history, risk_data, provider)
159
- history.append({"role": "assistant", "content": ai_msg})
160
- return history, ""
161
-
162
- # --- 5. EXAMPLES ---
163
-
164
- example_low = """The new update for the Linux kernel fixed some Realtek driver issues I was hitting.
165
- On a different note, the basketball team's defensive stats improved a lot this season.
166
- I also spent time woodworking on a walnut coffee table. The grain is tricky but satisfying to work with."""
167
-
168
- example_mod = """Things have been pretty busy at work with a big project and tight deadlines.
169
- Sleep could be better; I've been staying up too late scrolling my phone.
170
- Went hiking with some friends last weekend which helped me reset a bit."""
171
-
172
- example_high = """I don't know why I even bother getting out of bed anymore.
173
- I've been avoiding my friends for weeks and can't bring myself to reply.
174
- Everything feels grey and heavy, and I was up until 4 AM staring at the ceiling again."""
175
-
176
- # --- 6. UI LAYOUT ---
177
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  with gr.Blocks(title="Depression Risk Agent") as demo:
179
  gr.Markdown("# ๐Ÿง  Early Depression Detection Agent (MCP)")
180
- gr.Markdown("Agentic system using **Mental-Longformer** (Tool) + **SambaNova/Nebius** (Chat).")
181
- gr.Markdown("โšก Powered by: SambaNova (Llama 3.3) & Nebius (Kimi K2)")
182
-
 
183
  risk_state = gr.State(None)
184
-
 
185
  with gr.Row():
 
186
  with gr.Column(scale=1):
187
  input_text = gr.Textbox(
188
- label="User Timeline (Paste posts here)",
189
- lines=8,
190
  placeholder="[Post 1] ...\n\n[Post 2] ...",
191
- value=example_high,
192
  )
193
-
 
194
  gr.Markdown("### ๐Ÿ” Try Thesis Patterns")
195
  with gr.Row():
196
- btn_low = gr.Button("๐ŸŸข Low Risk", size="sm", variant="secondary")
197
- btn_mod = gr.Button("๐ŸŸก Moderate", size="sm", variant="secondary")
198
- btn_high = gr.Button("๐Ÿ”ด High Risk", size="sm", variant="secondary")
199
-
 
200
  gr.Markdown("### โš™๏ธ Settings")
201
  with gr.Row():
202
- loc_drop = gr.Dropdown(
203
- ["Global", "US", "Malaysia"],
204
- value="Malaysia",
205
- label="Crisis Resource Region",
206
- )
207
- prov_drop = gr.Dropdown(
208
- ["SambaNova", "Nebius"],
209
- value="SambaNova",
210
- label="Agent Brain",
211
- )
212
-
213
  analyze_btn = gr.Button("๐Ÿš€ Run Clinical Analysis", variant="primary", size="lg")
 
 
 
214
 
 
215
  with gr.Column(scale=1):
216
  dashboard = gr.HTML(label="Clinical Dashboard")
217
- chatbot = gr.Chatbot(label="Agent Chat", height=400, type="messages")
 
 
 
218
  msg_input = gr.Textbox(label="Chat with Agent", placeholder="Ask about your results or get advice...")
219
  send_btn = gr.Button("Send Message")
220
 
221
- # Wiring
 
 
222
  analyze_btn.click(
223
  run_analysis,
224
  inputs=[input_text, loc_drop, prov_drop],
225
- outputs=[dashboard, chatbot, risk_state],
226
  )
227
-
 
228
  btn_low.click(lambda: example_low, None, input_text)
229
  btn_mod.click(lambda: example_mod, None, input_text)
230
  btn_high.click(lambda: example_high, None, input_text)
231
-
 
 
232
  send_btn.click(
233
  user_chat,
234
- inputs=[msg_input, chatbot, risk_state, prov_drop],
235
- outputs=[chatbot, msg_input],
236
  )
237
  msg_input.submit(
238
  user_chat,
239
- inputs=[msg_input, chatbot, risk_state, prov_drop],
240
- outputs=[chatbot, msg_input],
241
  )
242
 
243
  if __name__ == "__main__":
244
- demo.launch(mcp_server=True, theme=gr.themes.Soft())
 
 
 
 
9
# Authenticate against the Hugging Face Hub when a token is provided, so the
# (possibly gated) model below can be downloaded.
if hf_token:
    login(token=hf_token)

# --- 2. LOAD TOOL ---
# Text-classification pipeline over the fine-tuned Longformer checkpoint.
# truncation/max_length cap input at the model's 4096-token window;
# top_k=None asks the pipeline for scores of every label (transformers convention).
print("Loading Mental-Longformer...")
model_name = "avtak/erisk-longformer-depression-v1"
classifier = pipeline("text-classification", model=model_name, truncation=True, max_length=4096, top_k=None)

# --- 3. MCP TOOLS ---
18
 
19
def get_crisis_resources(location: str = "Global") -> str:
    """Returns mental health resources based on location.

    Args:
        location: A region name, or free text containing one ("US",
            "Malaysia", "Global"). Matching is case-insensitive.

    Returns:
        A markdown string of crisis hotlines for the first matching region,
        falling back to the international listing when nothing matches.
    """
    import re  # local import: only needed for the word-boundary match below

    resources = {
        "US": "๐Ÿ‡บ๐Ÿ‡ธ **US:** Crisis Text Line: 741741 | Suicide Lifeline: 988",
        "Malaysia": "๐Ÿ‡ฒ๐Ÿ‡พ **Malaysia:** Befrienders KL: 03-76272929 | Talian Kasih: 15999",
        "Global": "๐ŸŒ **International:** [befrienders.org](https://www.befrienders.org)"
    }
    if location:
        for key, value in resources.items():
            # Whole-word match instead of a bare substring test, so e.g.
            # "Australia" no longer false-positives on the "US" key.
            if re.search(rf"\b{re.escape(key)}\b", location, re.IGNORECASE):
                return value
    return resources["Global"]
30
 
31
def detect_depression_risk(text: str) -> dict:
    """Analyzes text using Mental-Longformer (eRisk 2025)."""
    # Thesis aggregation step (critical): single newlines become double
    # newlines so the model sees the timeline as distinct posts.
    spaced = text.replace("\n", "\n\n")

    label_scores = classifier(spaced)[0]

    # Probability of the positive class; 0 when LABEL_1 is absent.
    prob = 0
    for entry in label_scores:
        if entry["label"] == "LABEL_1":
            prob = entry["score"]
            break

    # Thesis thresholds (Figure 4.15) — highest band checked first.
    if prob >= 0.60:
        level = "High Risk"
        biomarker = "Nocturnal & High-Effort"
        desc = "Matches 'Depressed' cohort. Indicators: Nocturnal posting spikes (00-05 UTC), high-effort/low-frequency posting."
        color = "#ef4444"  # Red
    elif prob >= 0.40:
        level = "Moderate Risk"
        biomarker = "Echo Chamber Interaction"
        desc = "Matches 'Interactive Non-Depressed' group. Engaging in support forums but likely not clinically depressed (Supportive Responder)."
        color = "#f59e0b"  # Yellow
    else:
        level = "Low Risk"
        biomarker = "Healthy External Focus"
        desc = "Matches 'Isolated Control' group. High lexical diversity, focus on hobbies/events."
        color = "#10b981"  # Green

    return {
        "probability": prob,
        "risk_level": level,
        "biomarker": biomarker,
        "description": desc,
        "color": color,
        "word_count": len(spaced.split())
    }
67
 
68
+ # --- 4. AGENT LOGIC (Dual State Management) ---
69
 
70
def generate_response(api_history, risk_context, provider):
    """Generates response using Sponsor API."""
    if not risk_context:
        # No analysis has run yet — fall back to a neutral context.
        risk_context = {
            "risk_level": "Unknown",
            "probability": 0.0,
            "description": "No analysis run yet.",
        }

    # Route to the selected sponsor endpoint.
    if provider == "SambaNova":
        endpoint = "https://api.sambanova.ai/v1"
        key_var = "SAMBANOVA_API_KEY"
        model_id = "Meta-Llama-3.3-70B-Instruct"
    else:
        endpoint = "https://api.tokenfactory.nebius.com/v1/"
        key_var = "NEBIUS_API_KEY"
        model_id = "moonshotai/Kimi-K2-Thinking"
    client = OpenAI(base_url=endpoint, api_key=os.getenv(key_var))

    system_prompt = f"""
    You are 'Dr. Longformer', a specialized Clinical AI Assistant based on Hassan's 2025 Thesis.
    CURRENT USER CONTEXT:
    - Analyzed Risk: {risk_context['risk_level']} ({risk_context['probability']:.1%})
    - Detected Pattern: {risk_context['description']}
    YOUR GOAL: Provide supportive, scientifically-grounded chat. Max 2 sentences.
    """

    # System prompt first, then the running conversation.
    messages = [{"role": "system", "content": system_prompt}] + list(api_history)

    try:
        completion = client.chat.completions.create(
            model=model_id,
            messages=messages,
            temperature=0.7,
            max_tokens=300
        )
        return completion.choices[0].message.content
    except Exception as err:
        # Surface the failure in-chat rather than crashing the UI.
        return f"โš ๏ธ Error with {provider}: {str(err)}"
106
 
107
+ # --- 5. UI ORCHESTRATION ---
108
 
109
  def run_analysis(text, location, provider):
110
+ if not text.strip(): return None, [], [], None
111
+
112
+ # 1. Run Tool
113
  data = detect_depression_risk(text)
114
  resources = get_crisis_resources(location)
115
+
116
+ # 2. Visual Dashboard
117
  html_dashboard = f"""
118
  <div style="padding: 20px; border-radius: 12px; background-color: {data['color']}15; border: 1px solid {data['color']};">
119
  <div style="display: flex; justify-content: space-between; align-items: center;">
 
133
  </div>
134
  </div>
135
  """
136
+
137
+ # 3. Agent Greeting
138
+ # Format for API: List of Dicts
139
+ api_history = [{"role": "user", "content": "I just ran the analysis. Please explain my results."}]
140
+ ai_msg = generate_response(api_history, data, provider)
141
+ api_history.append({"role": "assistant", "content": ai_msg})
142
+
143
+ # Format for UI: List of Tuples [(None, AI_Message)]
144
+ # CRITICAL FIX: This standard format works on ALL Gradio versions
145
+ ui_history = [(None, ai_msg)]
146
+
147
+ return html_dashboard, ui_history, api_history, data
148
+
149
def user_chat(user_message, ui_history, api_history, risk_data, provider):
    """Handle one chat turn: mirror it into both histories and append the agent's reply."""
    if not user_message:
        # Nothing typed — leave both histories untouched, keep the box empty.
        return ui_history, api_history, ""

    # Record the user turn in API (role/content dicts) and UI (pair) formats.
    api_history.append({"role": "user", "content": user_message})
    pending = [user_message, None]
    ui_history.append(pending)

    # Ask the agent with full conversation + risk context.
    reply = generate_response(api_history, risk_data, provider)

    # Record the assistant turn in both histories (fills the pending UI slot).
    api_history.append({"role": "assistant", "content": reply})
    pending[1] = reply

    return ui_history, api_history, ""
168
+
169
# --- 6. EXAMPLES (RESTORED) ---
# Canned user timelines wired to the demo buttons below; each is written to
# exemplify one of the three thesis risk bands.

# Outward-focused, varied-topic writing (low-risk exemplar).
example_low = """The new update for the Linux kernel (6.8) finally addressed the driver issues I was seeing with Realtek cards. I read the changelog on kernel.org and noticed they patched the module specific to the rtl8821ce chipset.

On a different note, the defensive stats for the basketball team have improved significantly. Allowing 15% fewer points per possession is a game changer.

I also spent some time woodworking. I'm building a walnut coffee table and the grain is tricky to plane. I might switch to a polyurethane finish for durability."""

# Mild stress and sleep complaints with protective factors (moderate exemplar).
example_mod = """Things have been pretty busy at work. We're in the middle of a big project and deadlines are tight. Staying a bit later than usual but that's just how it goes sometimes. Team is handling it well overall.

Sleep could be better. Been staying up too late scrolling my phone. Need to work on that. Usually feel okay once I get moving in the morning though.

Went hiking with some friends last weekend which was nice. Good to get outside and move around. We're talking about doing another trip next month. Weather should be better by then."""

# Anhedonia, social withdrawal, nocturnal wakefulness (high-risk exemplar).
example_high = """I don't know why I even bother getting out of bed anymore. I slept for 12 hours yesterday and I'm still exhausted. It feels like my limbs weigh a thousand pounds.

I've been avoiding my friends for weeks. They keep texting me, but I can't bring myself to reply. The thought of socializing is terrifying.

Everything feels like a shade of grey. I can't concentrate on my work. I feel like I'm drowning while everyone else is breathing fine. I was up until 4 AM again last night just staring at the ceiling."""
187
+
188
# --- 7. UI LAYOUT ---
with gr.Blocks(title="Depression Risk Agent") as demo:
    gr.Markdown("# ๐Ÿง  Early Depression Detection Agent (MCP)")
    gr.Markdown("Agentic system using **Mental-Longformer** (Tool) + **SambaNova/Nebius** (Reasoning).")
    gr.Markdown("โšก **Powered by:** [SambaNova](https://sambanova.ai/) (Llama 3.3) & [Nebius](https://nebius.com/) (Kimi K2)")

    # Internal State
    risk_state = gr.State(None)  # latest detect_depression_risk() result dict
    api_state = gr.State([])  # Stores [{"role":...}] for the LLM

    with gr.Row():
        # LEFT: INPUT
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="User Timeline (Paste posts here)",
                lines=8,
                placeholder="[Post 1] ...\n\n[Post 2] ...",
                value=example_high
            )

            # EXAMPLES ROW — buttons that load the canned timelines above.
            gr.Markdown("### ๐Ÿ” Try Thesis Patterns")
            with gr.Row():
                btn_low = gr.Button("๐ŸŸข Low Risk", size="sm")
                btn_mod = gr.Button("๐ŸŸก Moderate", size="sm")
                btn_high = gr.Button("๐Ÿ”ด High Risk", size="sm")

            # SETTINGS ROW
            gr.Markdown("### โš™๏ธ Settings")
            with gr.Row():
                loc_drop = gr.Dropdown(["Global", "US", "Malaysia"], value="Malaysia", label="Crisis Resource Region")
                prov_drop = gr.Dropdown(["SambaNova", "Nebius"], value="SambaNova", label="Agent Brain")

            analyze_btn = gr.Button("๐Ÿš€ Run Clinical Analysis", variant="primary", size="lg")

            with gr.Accordion("๐Ÿ”ง MCP Tools Exposed", open=False):
                gr.Markdown("- `detect_depression_risk`\n- `get_crisis_resources`")

        # RIGHT: DASHBOARD & CHAT
        with gr.Column(scale=1):
            dashboard = gr.HTML(label="Clinical Dashboard")

            # FIXED: REMOVED type="messages" to prevent crash. Uses standard Tuples.
            chatbot = gr.Chatbot(label="Agent Chat", height=400)

            msg_input = gr.Textbox(label="Chat with Agent", placeholder="Ask about your results or get advice...")
            send_btn = gr.Button("Send Message")

    # WIRING

    # 1. Analyze Button -> Updates Dashboard, Chatbot, API history, Risk State
    analyze_btn.click(
        run_analysis,
        inputs=[input_text, loc_drop, prov_drop],
        outputs=[dashboard, chatbot, api_state, risk_state]
    )

    # 2. Example Buttons — each just swaps the textbox contents.
    btn_low.click(lambda: example_low, None, input_text)
    btn_mod.click(lambda: example_mod, None, input_text)
    btn_high.click(lambda: example_high, None, input_text)

    # 3. Chat Interactions (button click and Enter-to-submit share one handler)
    # Note: We pass 'chatbot' as both input and output (it holds the history in 'tuples' format)
    send_btn.click(
        user_chat,
        inputs=[msg_input, chatbot, api_state, risk_state, prov_drop],
        outputs=[chatbot, api_state, msg_input]
    )
    msg_input.submit(
        user_chat,
        inputs=[msg_input, chatbot, api_state, risk_state, prov_drop],
        outputs=[chatbot, api_state, msg_input]
    )

if __name__ == "__main__":
    # NOTE(review): `theme=` is normally a gr.Blocks(...) argument rather than
    # a launch(...) keyword — confirm the installed Gradio version accepts it.
    demo.launch(
        mcp_server=True,
        theme=gr.themes.Soft()
    )