"""MentalPulse: Gradio demo that classifies mental-health signals in free text.

A fine-tuned DistilBERT sequence classifier and its label encoder are loaded
once at import time; ``predict`` is wired to the UI's "Analyze" button.
"""

import re

import contractions
import gradio as gr
import joblib
import numpy as np
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

# ---------------------------------------------------------------------------
# Model, tokenizer and label encoder
# ---------------------------------------------------------------------------
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert")
model = DistilBertForSequenceClassification.from_pretrained("distilbert")
model.eval()  # inference only: disables dropout

# SECURITY NOTE: joblib.load unpickles arbitrary objects. Only ship a
# label_encoder.pkl that was produced by a trusted training run.
le = joblib.load("label_encoder.pkl")
label_map = dict(enumerate(le.classes_))

# ---------------------------------------------------------------------------
# Display configuration
# ---------------------------------------------------------------------------
class_config = {
    "Normal": {"emoji": "🟢", "color": "#2ecc71"},
    "Anxiety": {"emoji": "🟡", "color": "#f39c12"},
    "Stress": {"emoji": "🟠", "color": "#e67e22"},
    "Depression": {"emoji": "🔵", "color": "#3498db"},
    "Bipolar": {"emoji": "🟣", "color": "#9b59b6"},
    "Personality disorder": {"emoji": "⚪", "color": "#95a5a6"},
    "Suicidal": {"emoji": "🔴", "color": "#e74c3c"},
}

# Used when the label encoder contains a class that class_config does not
# know about, so a naming mismatch degrades gracefully instead of raising.
DEFAULT_CLASS_STYLE = {"emoji": "❔", "color": "#7f8c8d"}

CRISIS_MESSAGE = """
⚠️ CRISIS SUPPORT — You are not alone.

If you or someone you know is in crisis, please reach out immediately:
- International Association for Suicide Prevention: https://www.iasp.info/resources/Crisis_Centres/
- Crisis Text Line (US): Text HOME to 741741
- Nigeria: NEEM Foundation — +234 806 210 6493
- Befrienders Worldwide: https://www.befrienders.org
"""

APP_HEADER_MD = """
# 🧠 MentalPulse — Mental Health Signal Classifier
### Detect mental health signals from text using fine-tuned DistilBERT
*Built by Samuel Yaula Dutse | Mental Health Awareness Month — June 2026*

> ⚠️ This tool is for research and educational purposes only.
> It is not a clinical diagnostic tool. If you are in crisis, please seek professional help.
"""

EMPTY_INPUT_MESSAGE = "Please enter a statement."
NO_TEXT_MESSAGE = (
    "No analyzable text found. Please describe your thoughts using English words."
)

MAX_TOKENS = 128  # truncation length passed to the tokenizer
TOP_K = 3  # number of classes listed in the probability breakdown

# Compiled once; clean_text runs on every request.
_URL_RE = re.compile(r"http\S+|www\S+")
_NON_ALPHA_RE = re.compile(r"[^a-z\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_text(text) -> str:
    """Normalize raw user text before tokenization.

    Expands contractions, lowercases, removes URLs, drops every character
    that is not ``a``-``z`` or whitespace, and collapses whitespace runs.

    Args:
        text: The raw statement. Any non-string value is treated as empty.

    Returns:
        The cleaned text, or ``""`` when nothing usable remains.
    """
    if not isinstance(text, str):
        return ""
    text = contractions.fix(text).lower()
    text = _URL_RE.sub("", text)
    text = _NON_ALPHA_RE.sub("", text)
    return _WHITESPACE_RE.sub(" ", text).strip()


def _style_for(label):
    """Return the display style for *label*, falling back to a neutral one."""
    return class_config.get(label, DEFAULT_CLASS_STYLE)


def predict(statement):
    """Classify a statement and format the results for the UI.

    Args:
        statement: Text typed by the user. ``None`` and other non-string
            values are handled like an empty statement.

    Returns:
        A 3-tuple of strings ``(result, breakdown, crisis)``:
        the top prediction with its confidence, the top-``TOP_K`` class
        probabilities, and the crisis resources text (non-empty only when
        the top prediction is ``"Suicidal"``). For unusable input the first
        element is a prompt for the user and the other two are empty.
    """
    # Guard against None / non-string payloads before calling str methods.
    if not isinstance(statement, str) or not statement.strip():
        return EMPTY_INPUT_MESSAGE, "", ""

    cleaned = clean_text(statement)
    # Digits, punctuation, emoji or non-Latin text all clean down to "".
    # Running the model on an empty string would yield an arbitrary label,
    # so refuse to classify instead.
    if not cleaned:
        return NO_TEXT_MESSAGE, "", ""

    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_TOKENS,
    )
    with torch.no_grad():
        logits = model(**encoded).logits

    # Exactly one statement is passed in, so take row 0 explicitly rather
    # than relying on squeeze() to drop the batch dimension.
    probs = torch.softmax(logits, dim=1)[0]

    pred_idx = int(torch.argmax(probs))
    pred_label = label_map[pred_idx]
    confidence = round(probs[pred_idx].item() * 100, 1)
    result = f"{_style_for(pred_label)['emoji']} {pred_label} ({confidence}% confidence)"

    top_indices = torch.argsort(probs, descending=True)[:TOP_K].tolist()
    breakdown_lines = ["Probability breakdown:\n"]
    for idx in top_indices:
        label = label_map[idx]
        pct = round(probs[idx].item() * 100, 1)
        breakdown_lines.append(f" {_style_for(label)['emoji']} {label}: {pct}%\n")
    breakdown = "".join(breakdown_lines)

    crisis = CRISIS_MESSAGE if pred_label == "Suicidal" else ""
    return result, breakdown, crisis


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="MentalPulse") as demo:
    gr.Markdown(APP_HEADER_MD)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Enter a statement",
                placeholder="How are you feeling? Describe your thoughts or emotions...",
                lines=5,
            )
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column():
            pred_output = gr.Textbox(label="Detected Signal", interactive=False)
            breakdown_output = gr.Textbox(
                label="Confidence Breakdown", lines=5, interactive=False
            )
            crisis_output = gr.Textbox(
                label="Crisis Resources", lines=8, interactive=False, visible=True
            )

    gr.Examples(
        examples=[
            ["I have been feeling so hopeless lately, nothing makes sense and I just want it to stop"],
            ["My heart races constantly and I cannot stop worrying about everything"],
            ["Today was actually a good day. I felt present and enjoyed time with my family"],
            ["My mood swings are so extreme, one day I feel invincible and the next I cannot get out of bed"],
            ["Work has been overwhelming and I feel like I am stretched too thin"],
        ],
        inputs=[text_input],
    )

    submit_btn.click(
        fn=predict,
        inputs=[text_input],
        outputs=[pred_output, breakdown_output, crisis_output],
    )

# Launched at module level, as in the original script entry point.
demo.launch()