mentalpulse / app.py
Samdutse's picture
Upload app.py with huggingface_hub
c830d31 verified
Raw
History Blame Contribute Delete
4.55 kB
import gradio as gr
import torch
import joblib
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import re
import contractions
# Load the fine-tuned DistilBERT classifier and its tokenizer from the
# "distilbert" path (presumably a directory shipped alongside this app —
# confirm), plus the fitted label encoder used to turn class indices back
# into label names.
model = DistilBertForSequenceClassification.from_pretrained("distilbert")
model.eval()  # inference only: switches off dropout
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert")

# NOTE: joblib.load unpickles the file — only ever load a trusted artifact.
le = joblib.load("label_encoder.pkl")
# Index -> class-name lookup. Assumes the model's output index i corresponds
# to le.classes_[i] (i.e. training used this same encoder) — TODO confirm.
label_map = dict(enumerate(le.classes_))
# Per-class display settings for the UI.
# predict looks entries up by the predicted label name, so the keys must match
# the label encoder's class names exactly. "emoji" is prefixed to the label in
# the output text; "color" is a hex colour for the class (not read by predict).
class_config = {
    "Normal": {"emoji": "🟢", "color": "#2ecc71"},
    "Anxiety": {"emoji": "🟡", "color": "#f39c12"},
    "Stress": {"emoji": "🟠", "color": "#e67e22"},
    "Depression": {"emoji": "🔵", "color": "#3498db"},
    "Bipolar": {"emoji": "🟣", "color": "#9b59b6"},
    "Personality disorder": {"emoji": "⚪", "color": "#95a5a6"},
    "Suicidal": {"emoji": "🔴", "color": "#e74c3c"},
}
# Crisis resources returned as the third output of predict whenever the top
# prediction is "Suicidal"; otherwise that output is left empty.
CRISIS_MESSAGE = """
⚠️ CRISIS SUPPORT — You are not alone.
If you or someone you know is in crisis, please reach out immediately:
- International Association for Suicide Prevention: https://www.iasp.info/resources/Crisis_Centres/
- Crisis Text Line (US): Text HOME to 741741
- Nigeria: NEEM Foundation — +234 806 210 6493
- Befrienders Worldwide: https://www.befrienders.org
"""
def clean_text(text):
    """Normalise raw user text into the form fed to the classifier.

    Steps, in order: expand contractions with ``contractions.fix``, lowercase,
    strip URLs (``http...`` / ``www...`` tokens), keep only ASCII letters a-z
    and whitespace, collapse whitespace runs to one space, and trim the ends.

    Args:
        text: Raw input. Anything that is not a ``str`` is rejected.

    Returns:
        The normalised string, or ``""`` for non-string input.
    """
    if isinstance(text, str):
        normalised = contractions.fix(text).lower()
        # Applied in sequence; URL removal must precede the letters-only
        # filter, otherwise "https://x.com" would leave "httpsxcom" behind.
        for pattern, replacement in (
            (r"http\S+|www\S+", ""),
            (r"[^a-z\s]", ""),
            (r"\s+", " "),
        ):
            normalised = re.sub(pattern, replacement, normalised)
        return normalised.strip()
    return ""
def _emoji_for(label):
    """Return the display emoji for *label*, or a neutral marker if it has no config entry."""
    return class_config.get(label, {}).get("emoji", "❔")


def predict(statement):
    """Classify *statement* and format the result for the three output boxes.

    Args:
        statement: Raw user text. ``None`` or any non-string value is treated
            like empty input.

    Returns:
        A 3-tuple of strings ``(result, breakdown, crisis)``:

        - result: predicted label with its emoji and confidence percentage, or
          a prompt asking for usable input.
        - breakdown: the top-3 classes and their probabilities, one per line.
        - crisis: ``CRISIS_MESSAGE`` when the top prediction is "Suicidal",
          otherwise ``""``.
    """
    if not isinstance(statement, str) or not statement.strip():
        return "Please enter a statement.", "", ""

    clean = clean_text(statement)
    if not clean:
        # clean_text keeps only a-z letters and whitespace, so input made up
        # solely of digits, punctuation, emoji or non-Latin script becomes
        # empty; classifying an empty sequence would yield an arbitrary label.
        return "Please enter a statement containing letters (a-z).", "", ""

    inputs = tokenizer(
        clean,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,  # longer statements are truncated to 128 tokens
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single statement is a batch of one: take row 0 explicitly instead of
    # squeeze(), which would also collapse the class axis if it had size 1.
    probs = torch.softmax(logits, dim=1)[0]

    pred_idx = int(torch.argmax(probs).item())
    pred_label = label_map[pred_idx]
    confidence = round(probs[pred_idx].item() * 100, 1)
    result = f"{_emoji_for(pred_label)} {pred_label} ({confidence}% confidence)"

    # Top-3 classes by probability (fewer if the model has fewer classes).
    top_indices = torch.argsort(probs, descending=True)[:3].tolist()
    lines = ["Probability breakdown:"]
    for idx in top_indices:
        lbl = label_map[idx]
        pct = round(probs[idx].item() * 100, 1)
        lines.append(f" {_emoji_for(lbl)} {lbl}: {pct}%")
    breakdown = "\n".join(lines) + "\n"

    crisis = CRISIS_MESSAGE if pred_label == "Suicidal" else ""
    return result, breakdown, crisis
# UI: input column on the left, results on the right, example statements
# below; the Analyze button runs predict and fills the three output boxes.
with gr.Blocks(title="MentalPulse") as demo:
    # Markdown content is kept at column 0: indented lines would be rendered
    # as a code block.
    gr.Markdown("""
# 🧠 MentalPulse — Mental Health Signal Classifier
### Detect mental health signals from text using fine-tuned DistilBERT
*Built by Samuel Yaula Dutse | Mental Health Awareness Month — June 2026*
> ⚠️ This tool is for research and educational purposes only.
> It is not a clinical diagnostic tool. If you are in crisis, please seek professional help.
""")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Enter a statement",
                placeholder="How are you feeling? Describe your thoughts or emotions...",
                lines=5,
            )
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            pred_output = gr.Textbox(label="Detected Signal", interactive=False)
            breakdown_output = gr.Textbox(label="Confidence Breakdown", lines=5, interactive=False)
            crisis_output = gr.Textbox(label="Crisis Resources", lines=8, interactive=False, visible=True)
    gr.Examples(
        examples=[
            ["I have been feeling so hopeless lately, nothing makes sense and I just want it to stop"],
            ["My heart races constantly and I cannot stop worrying about everything"],
            ["Today was actually a good day. I felt present and enjoyed time with my family"],
            ["My mood swings are so extreme, one day I feel invincible and the next I cannot get out of bed"],
            ["Work has been overwhelming and I feel like I am stretched too thin"],
        ],
        inputs=[text_input],
    )
    # predict returns (result, breakdown, crisis) in this order.
    submit_btn.click(
        fn=predict,
        inputs=[text_input],
        outputs=[pred_output, breakdown_output, crisis_output],
    )

# Launch only when run as a script, so importing this module (e.g. for tests or
# tooling that just needs `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()