DrSyedFaizan commited on
Commit
39487d7
·
verified ·
1 Parent(s): 5b2592e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -123
app.py CHANGED
@@ -5,9 +5,22 @@ import pandas as pd
5
  import numpy as np
6
  import plotly.express as px
7
  import time
8
- from streamlit_lottie import st_lottie
9
- import requests
10
- import json
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Page configuration
13
  st.set_page_config(
@@ -82,26 +95,28 @@ st.markdown("""
82
  text-align: center;
83
  color: #666;
84
  }
 
 
 
 
 
 
85
  </style>
86
  """, unsafe_allow_html=True)
87
 
88
- # Function to load Lottie animation
89
- def load_lottie_url(url: str):
90
- r = requests.get(url)
91
- if r.status_code != 200:
92
- return None
93
- return r.json()
94
-
95
- # Load animations
96
- brain_animation = load_lottie_url("https://assets9.lottiefiles.com/packages/lf20_twdne5i2.json")
97
- analyzing_animation = load_lottie_url("https://assets8.lottiefiles.com/private_files/lf30_p9aibxmu.json")
98
-
99
- # Define model and tokenizer paths from Hugging Face
100
- MODEL_PATH = "DrSyedFaizan/mindBERT"
101
-
102
  # Create sidebar
103
  with st.sidebar:
104
- st_lottie(brain_animation, height=200, key="brain_animation")
 
 
 
 
 
 
 
 
 
 
105
  st.markdown("## About MindBERT")
106
  st.info(
107
  "MindBERT is a fine-tuned BERT model specifically designed to detect "
@@ -145,12 +160,19 @@ with tab1:
145
  # Model loading feedback
146
  @st.cache_resource
147
  def load_model():
148
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
149
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
150
- return tokenizer, model
 
 
 
 
 
 
 
151
 
152
  with st.spinner("Loading model..."):
153
- tokenizer, model = load_model()
154
 
155
  # Analysis button
156
  col1, col2, col3 = st.columns([1, 2, 1])
@@ -159,112 +181,124 @@ with tab1:
159
 
160
  # Prediction logic
161
  if analyze_button:
162
- if user_input.strip():
163
- # Show analyzing animation
 
 
 
 
164
  with st.spinner("Analyzing..."):
165
- st_lottie(analyzing_animation, height=200, key="analyze_animation", speed=1.5)
166
-
167
- # Tokenize input
168
- inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
169
-
170
- # Make prediction
171
- with torch.no_grad():
172
- outputs = model(**inputs)
173
- logits = outputs.logits
174
- probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
175
- predicted_class = torch.argmax(logits, dim=1).item()
176
-
177
- # Mapping predicted class to mental state with descriptions
178
- label_map = {
179
- 0: {"name": "Anxiety", "color": "#FFD54F", "description": "Characterized by excessive worry, fear, or nervousness."},
180
- 1: {"name": "Bipolar", "color": "#FF7043", "description": "Featuring alternating periods of depression and mania or elevated mood."},
181
- 2: {"name": "Depression", "color": "#4FC3F7", "description": "Persistent feelings of sadness, hopelessness, and loss of interest."},
182
- 3: {"name": "Normal", "color": "#81C784", "description": "Balanced emotional state without significant mental health concerns."},
183
- 4: {"name": "Personality Disorder", "color": "#9575CD", "description": "Persistent patterns of thinking and behavior that deviate from social norms."},
184
- 5: {"name": "Stress", "color": "#FF8A65", "description": "Physical or emotional tension due to challenging circumstances."},
185
- 6: {"name": "Suicidal", "color": "#F44336", "description": "Thoughts or intentions of self-harm or taking one's own life."}
186
- }
187
-
188
- mental_state = label_map.get(predicted_class, {"name": "Unknown", "color": "#BDBDBD", "description": "Unable to classify the mental state."})
189
 
190
- # Create data for visualization
191
- all_probs = {label_map[i]["name"]: prob.item() * 100 for i, prob in enumerate(probabilities)}
192
- probs_df = pd.DataFrame(list(all_probs.items()), columns=["Mental State", "Confidence (%)"])
193
- probs_df = probs_df.sort_values("Confidence (%)", ascending=False)
194
 
195
- # Display results
196
- st.markdown("<div class='result-box'>", unsafe_allow_html=True)
197
-
198
- # Primary result
199
- col1, col2 = st.columns([1, 2])
200
- with col1:
201
- st.markdown(f"<div class='metric-value' style='color:{mental_state['color']}'>{mental_state['name']}</div>", unsafe_allow_html=True)
202
- st.markdown("<div class='metric-label'>Primary Detection</div>", unsafe_allow_html=True)
203
-
204
- with col2:
205
- st.markdown(f"<div style='background-color:{mental_state['color']}20; padding:15px; border-radius:10px; border-left:5px solid {mental_state['color']}'>")
206
- st.markdown(f"<b>{mental_state['name']}</b>: {mental_state['description']}")
207
- st.markdown("</div>", unsafe_allow_html=True)
208
-
209
- # Confidence scores visualization
210
- st.markdown("<h3 class='sub-header'>Confidence Analysis</h3>", unsafe_allow_html=True)
211
-
212
- # Create bar chart
213
- fig = px.bar(
214
- probs_df,
215
- x="Confidence (%)",
216
- y="Mental State",
217
- orientation="h",
218
- color="Mental State",
219
- color_discrete_map={
220
- "Anxiety": "#FFD54F",
221
- "Bipolar": "#FF7043",
222
- "Depression": "#4FC3F7",
223
- "Normal": "#81C784",
224
- "Personality Disorder": "#9575CD",
225
- "Stress": "#FF8A65",
226
- "Suicidal": "#F44336",
227
- "Unknown": "#BDBDBD"
228
  }
229
- )
230
- fig.update_layout(
231
- height=350,
232
- margin=dict(l=20, r=20, t=30, b=20),
233
- xaxis_title="Confidence (%)",
234
- yaxis_title="",
235
- yaxis=dict(autorange="reversed"),
236
- xaxis=dict(range=[0, 100])
237
- )
238
- st.plotly_chart(fig, use_container_width=True)
239
-
240
- # Warning for high-risk categories
241
- if mental_state["name"] in ["Suicidal", "Depression"] and all_probs[mental_state["name"]] > 50:
242
- st.warning(
243
- "⚠️ **High-risk mental state detected.** If you or someone you know is experiencing "
244
- "suicidal thoughts, please seek immediate professional help or call the National "
245
- "Suicide Prevention Lifeline at 988 or 1-800-273-8255."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  )
247
-
248
- st.markdown("</div>", unsafe_allow_html=True)
249
-
250
- # Suggestion based on detected mental state
251
- suggestion_map = {
252
- "Anxiety": "Consider breathing exercises, meditation, or consulting with a mental health professional about anxiety management techniques.",
253
- "Bipolar": "Regular sleep schedules and medication management with professional oversight can help stabilize mood swings.",
254
- "Depression": "Regular physical activity, social connection, and professional therapy can be beneficial for managing depression.",
255
- "Normal": "Continue maintaining a healthy lifestyle with regular exercise, good sleep habits, and social connections.",
256
- "Personality Disorder": "Long-term psychotherapy with a specialist in personality disorders is often recommended.",
257
- "Stress": "Stress reduction techniques such as mindfulness, time management, and setting boundaries can be helpful.",
258
- "Suicidal": "Please seek immediate professional help. Call the National Suicide Prevention Lifeline at 988 or 1-800-273-8255."
259
- }
260
-
261
- st.markdown("<div class='result-box'>", unsafe_allow_html=True)
262
- st.markdown("<h3 class='sub-header'>Suggestions</h3>", unsafe_allow_html=True)
263
- st.info(suggestion_map.get(mental_state["name"], "Consider consulting with a mental health professional for personalized guidance."))
264
- st.markdown("</div>", unsafe_allow_html=True)
265
-
266
- else:
267
- st.warning("Please enter some text for analysis.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  with tab2:
270
  st.markdown("<h3 class='sub-header'>Mental Health Categories Explained</h3>", unsafe_allow_html=True)
 
5
  import numpy as np
6
  import plotly.express as px
7
  import time
8
+
9
+ # Try to import streamlit_lottie, but provide fallback if it fails
10
+ try:
11
+ from streamlit_lottie import st_lottie
12
+ import requests
13
+ def load_lottie_url(url: str):
14
+ try:
15
+ r = requests.get(url)
16
+ if r.status_code != 200:
17
+ return None
18
+ return r.json()
19
+ except:
20
+ return None
21
+ LOTTIE_AVAILABLE = True
22
+ except ImportError:
23
+ LOTTIE_AVAILABLE = False
24
 
25
  # Page configuration
26
  st.set_page_config(
 
95
  text-align: center;
96
  color: #666;
97
  }
98
+ .brain-icon {
99
+ font-size: 5rem;
100
+ text-align: center;
101
+ margin-bottom: 1rem;
102
+ color: #5E35B1;
103
+ }
104
  </style>
105
  """, unsafe_allow_html=True)
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  # Create sidebar
108
  with st.sidebar:
109
+ # Use either Lottie or a simple icon
110
+ if LOTTIE_AVAILABLE:
111
+ # Fixed Lottie URLs that are reliable
112
+ brain_animation = load_lottie_url("https://lottie.host/2eb12c32-787a-46f7-ac20-34c166d1a285/UcEEbJlFVH.json")
113
+ if brain_animation:
114
+ st_lottie(brain_animation, height=200, key="brain_animation")
115
+ else:
116
+ st.markdown("<div class='brain-icon'>🧠</div>", unsafe_allow_html=True)
117
+ else:
118
+ st.markdown("<div class='brain-icon'>🧠</div>", unsafe_allow_html=True)
119
+
120
  st.markdown("## About MindBERT")
121
  st.info(
122
  "MindBERT is a fine-tuned BERT model specifically designed to detect "
 
160
  # Model loading feedback
161
  @st.cache_resource
162
  def load_model():
163
+ try:
164
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
165
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
166
+ return tokenizer, model, True
167
+ except Exception as e:
168
+ st.error(f"Error loading model: {str(e)}")
169
+ return None, None, False
170
+
171
+ # Define model and tokenizer paths from Hugging Face
172
+ MODEL_PATH = "DrSyedFaizan/mindBERT"
173
 
174
  with st.spinner("Loading model..."):
175
+ tokenizer, model, model_loaded = load_model()
176
 
177
  # Analysis button
178
  col1, col2, col3 = st.columns([1, 2, 1])
 
181
 
182
  # Prediction logic
183
  if analyze_button:
184
+ if not model_loaded:
185
+ st.error("Model failed to load. Please try again later.")
186
+ elif not user_input.strip():
187
+ st.warning("Please enter some text for analysis.")
188
+ else:
189
+ # Show analyzing animation or spinner
190
  with st.spinner("Analyzing..."):
191
+ if LOTTIE_AVAILABLE:
192
+ analyzing_animation = load_lottie_url("https://lottie.host/16c400ec-7d59-4c0c-a84b-56c9134cd673/20XZXacKUS.json")
193
+ if analyzing_animation:
194
+ st_lottie(analyzing_animation, height=200, key="analyze_animation", speed=1.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
+ # Add a slight delay to show the animation
197
+ time.sleep(1)
 
 
198
 
199
+ try:
200
+ # Tokenize input
201
+ inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
202
+
203
+ # Make prediction
204
+ with torch.no_grad():
205
+ outputs = model(**inputs)
206
+ logits = outputs.logits
207
+ probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
208
+ predicted_class = torch.argmax(logits, dim=1).item()
209
+
210
+ # Mapping predicted class to mental state with descriptions
211
+ label_map = {
212
+ 0: {"name": "Anxiety", "color": "#FFD54F", "description": "Characterized by excessive worry, fear, or nervousness."},
213
+ 1: {"name": "Bipolar", "color": "#FF7043", "description": "Featuring alternating periods of depression and mania or elevated mood."},
214
+ 2: {"name": "Depression", "color": "#4FC3F7", "description": "Persistent feelings of sadness, hopelessness, and loss of interest."},
215
+ 3: {"name": "Normal", "color": "#81C784", "description": "Balanced emotional state without significant mental health concerns."},
216
+ 4: {"name": "Personality Disorder", "color": "#9575CD", "description": "Persistent patterns of thinking and behavior that deviate from social norms."},
217
+ 5: {"name": "Stress", "color": "#FF8A65", "description": "Physical or emotional tension due to challenging circumstances."},
218
+ 6: {"name": "Suicidal", "color": "#F44336", "description": "Thoughts or intentions of self-harm or taking one's own life."}
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  }
220
+
221
+ mental_state = label_map.get(predicted_class, {"name": "Unknown", "color": "#BDBDBD", "description": "Unable to classify the mental state."})
222
+
223
+ # Create data for visualization
224
+ all_probs = {label_map[i]["name"]: prob.item() * 100 for i, prob in enumerate(probabilities)}
225
+ probs_df = pd.DataFrame(list(all_probs.items()), columns=["Mental State", "Confidence (%)"])
226
+ probs_df = probs_df.sort_values("Confidence (%)", ascending=False)
227
+
228
+ # Display results
229
+ st.markdown("<div class='result-box'>", unsafe_allow_html=True)
230
+
231
+ # Primary result
232
+ col1, col2 = st.columns([1, 2])
233
+ with col1:
234
+ st.markdown(f"<div class='metric-value' style='color:{mental_state['color']}'>{mental_state['name']}</div>", unsafe_allow_html=True)
235
+ st.markdown("<div class='metric-label'>Primary Detection</div>", unsafe_allow_html=True)
236
+
237
+ with col2:
238
+ st.markdown(f"<div style='background-color:{mental_state['color']}20; padding:15px; border-radius:10px; border-left:5px solid {mental_state['color']}'>")
239
+ st.markdown(f"<b>{mental_state['name']}</b>: {mental_state['description']}")
240
+ st.markdown("</div>", unsafe_allow_html=True)
241
+
242
+ # Confidence scores visualization
243
+ st.markdown("<h3 class='sub-header'>Confidence Analysis</h3>", unsafe_allow_html=True)
244
+
245
+ # Create bar chart
246
+ fig = px.bar(
247
+ probs_df,
248
+ x="Confidence (%)",
249
+ y="Mental State",
250
+ orientation="h",
251
+ color="Mental State",
252
+ color_discrete_map={
253
+ "Anxiety": "#FFD54F",
254
+ "Bipolar": "#FF7043",
255
+ "Depression": "#4FC3F7",
256
+ "Normal": "#81C784",
257
+ "Personality Disorder": "#9575CD",
258
+ "Stress": "#FF8A65",
259
+ "Suicidal": "#F44336",
260
+ "Unknown": "#BDBDBD"
261
+ }
262
  )
263
+ fig.update_layout(
264
+ height=350,
265
+ margin=dict(l=20, r=20, t=30, b=20),
266
+ xaxis_title="Confidence (%)",
267
+ yaxis_title="",
268
+ yaxis=dict(autorange="reversed"),
269
+ xaxis=dict(range=[0, 100])
270
+ )
271
+ st.plotly_chart(fig, use_container_width=True)
272
+
273
+ # Warning for high-risk categories
274
+ if mental_state["name"] in ["Suicidal", "Depression"] and all_probs[mental_state["name"]] > 50:
275
+ st.warning(
276
+ "⚠️ **High-risk mental state detected.** If you or someone you know is experiencing "
277
+ "suicidal thoughts, please seek immediate professional help or call the National "
278
+ "Suicide Prevention Lifeline at 988 or 1-800-273-8255."
279
+ )
280
+
281
+ st.markdown("</div>", unsafe_allow_html=True)
282
+
283
+ # Suggestion based on detected mental state
284
+ suggestion_map = {
285
+ "Anxiety": "Consider breathing exercises, meditation, or consulting with a mental health professional about anxiety management techniques.",
286
+ "Bipolar": "Regular sleep schedules and medication management with professional oversight can help stabilize mood swings.",
287
+ "Depression": "Regular physical activity, social connection, and professional therapy can be beneficial for managing depression.",
288
+ "Normal": "Continue maintaining a healthy lifestyle with regular exercise, good sleep habits, and social connections.",
289
+ "Personality Disorder": "Long-term psychotherapy with a specialist in personality disorders is often recommended.",
290
+ "Stress": "Stress reduction techniques such as mindfulness, time management, and setting boundaries can be helpful.",
291
+ "Suicidal": "Please seek immediate professional help. Call the National Suicide Prevention Lifeline at 988 or 1-800-273-8255."
292
+ }
293
+
294
+ st.markdown("<div class='result-box'>", unsafe_allow_html=True)
295
+ st.markdown("<h3 class='sub-header'>Suggestions</h3>", unsafe_allow_html=True)
296
+ st.info(suggestion_map.get(mental_state["name"], "Consider consulting with a mental health professional for personalized guidance."))
297
+ st.markdown("</div>", unsafe_allow_html=True)
298
+
299
+ except Exception as e:
300
+ st.error(f"Error during analysis: {str(e)}")
301
+ st.info("Please try again with different text or contact support if the issue persists.")
302
 
303
  with tab2:
304
  st.markdown("<h3 class='sub-header'>Mental Health Categories Explained</h3>", unsafe_allow_html=True)