import os

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# === 1. Model Setup ===
MODEL_NAME = "ourafla/mental-health-bert-finetuned"
# Optional Hub token; when unset (None) huggingface_hub applies its default token lookup.
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")

print("⏳ Loading model... this may take a few seconds.")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        token=hf_token,
        # Attention weights feed the "key keywords" explanation in analyze_text.
        output_attentions=True,
        # Fused attention kernels (e.g. SDPA) may not return attention weights;
        # request the eager implementation so `outputs.attentions` is populated.
        attn_implementation="eager",
    )
    model.eval()  # inference only: disables dropout
    print("✅ Model loaded successfully.")
except Exception as e:
    # Keep the UI running; analyze_text reports "Model not loaded." instead.
    print(f"❌ Error loading model: {e}")
    tokenizer = None
    model = None

# Class index -> label name. Assumes the checkpoint's logits use this order
# (TODO confirm against model.config.id2label).
id2label = {0: "Anxiety", 1: "Depression", 2: "Normal", 3: "Suicidal"}

# === 2. CSS & THEME VARIABLES ===
# This JS toggles the 'dark' class on the body element; CUSTOM_CSS swaps its
# color variables under `body.dark`.
js_func = """
function toggleTheme() {
    document.querySelector('body').classList.toggle('dark');
}
"""

CUSTOM_CSS = """
/* === COLOR VARIABLES === */
:root {
    /* Light Mode (Default) */
    --bg-app: #F3F4F6;
    --bg-card: #FFFFFF;
    --bg-input: #F9FAFB;
    --border-color: #E5E7EB;
    --text-main: #111827;
    --text-sub: #4B5563;
    --primary: #2563EB;
    --primary-hover: #1D4ED8;
    --accent-bg: #EFF6FF;

    /* API Modal Specifics - Light */
    --modal-bg: #FFFFFF;
    --code-bg: #F3F4F6;
    --code-text: #1F2937;
}

body.dark {
    /* Dark Mode */
    --bg-app: #0B0F19;       /* Very dark blue/black */
    --bg-card: #111827;      /* Dark Slate */
    --bg-input: #1F2937;     /* Darker Slate */
    --border-color: #374151; /* Medium Gray */
    --text-main: #F9FAFB;    /* Almost White */
    --text-sub: #D1D5DB;     /* Light Gray */
    --primary: #3B82F6;      /* Bright Blue */
    --primary-hover: #60A5FA;
    --accent-bg: #1E293B;

    /* API Modal Specifics - Dark */
    --modal-bg: #111827;
    --code-bg: #020617;
    --code-text: #E2E8F0;
}

/* Global Background Application */
body, .gradio-container {
    background-color: var(--bg-app) !important;
    color: var(--text-main) !important;
    transition: background-color 0.3s, color 0.3s;
    font-family: 'Inter', system-ui, sans-serif !important;
}

/* === MAIN COMPONENT STYLES === */

/* Card Container */
.main-card {
    background-color: var(--bg-card) !important;
    border: 1px solid var(--border-color) !important;
    border-radius: 16px !important;
    padding: 32px !important;
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1) !important;
    max-width: 900px !important;
    margin: 20px auto !important;
}

/* Header & Typography */
h1 {
    color: var(--text-main) !important;
    font-weight: 800 !important;
    text-align: center;
}

.subtitle {
    color: var(--text-sub) !important;
    text-align: center;
    font-size: 1.1rem;
    margin-bottom: 2rem;
}

/* Theme Toggle Button */
.theme-btn {
    position: absolute !important;
    top: 20px !important;
    right: 20px !important;
    background: var(--bg-card) !important;
    border: 1px solid var(--border-color) !important;
    color: var(--text-main) !important;
    border-radius: 50% !important;
    width: 40px !important;
    height: 40px !important;
    padding: 0 !important;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
    font-size: 1.2rem !important;
    cursor: pointer !important;
    z-index: 100 !important;
}

/* Tabs */
.tabs {
    border-bottom: 2px solid var(--border-color) !important;
    background: transparent !important;
    margin-bottom: 24px !important;
}

.tab-nav button {
    color: var(--text-sub) !important;
    font-weight: 600 !important;
}

.tab-nav button.selected {
    color: var(--primary) !important;
    border-bottom-color: var(--primary) !important;
}

/* Input Fields */
textarea {
    background-color: var(--bg-input) !important;
    border: 1px solid var(--border-color) !important;
    color: var(--text-main) !important;
    border-radius: 12px !important;
}

textarea::placeholder {
    color: var(--text-sub) !important;
    opacity: 0.7;
}

/* Read-only Textbox (Explanation) */
textarea:read-only {
    background-color: var(--accent-bg) !important;
    border-color: var(--border-color) !important;
}

/* Action Button */
.analyze-btn {
    background: linear-gradient(135deg, var(--primary) 0%, var(--primary-hover) 100%) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 16px !important;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1) !important;
}

/* Example Chips */
.example-btn {
    background-color: var(--bg-card) !important;
    border: 1px solid var(--border-color) !important;
    color: var(--text-main) !important;
    border-radius: 99px !important;
}

.example-btn:hover {
    border-color: var(--primary) !important;
    color: var(--primary) !important;
}

/* About Text */
.about-text p, .about-text li {
    color: var(--text-sub) !important;
}

.about-text h3, .about-text b {
    color: var(--text-main) !important;
}

/* === API DOCUMENTATION MODAL FIXES === */

/* Force the API Modal to respect our theme variables */
.gradio-container div[role="dialog"],
div[data-testid="settings-dialog"],
.modal {
    background-color: var(--modal-bg) !important;
    border: 1px solid var(--border-color) !important;
}

/* Text inside modal */
div[role="dialog"] h2,
div[role="dialog"] h3,
div[role="dialog"] p,
div[role="dialog"] label,
div[role="dialog"] span,
div[role="dialog"] td {
    color: var(--text-main) !important;
}

/* Code Blocks in API Modal */
div[role="dialog"] pre,
div[role="dialog"] code {
    background-color: var(--code-bg) !important;
    color: var(--code-text) !important;
    border: 1px solid var(--border-color) !important;
}

/* Close button in modal */
div[role="dialog"] button.close {
    color: var(--text-main) !important;
}
"""


# === 3. HELPER FUNCTIONS (WITH DOCSTRINGS) ===
# Docstrings are crucial: they appear in the API documentation


def get_anxiety_example():
    """Returns a sample text demonstrating anxiety symptoms."""
    return "I feel incredibly anxious and my heart won't stop racing."


def get_depression_example():
    """Returns a sample text demonstrating depression symptoms."""
    return "I wake up feeling empty and I don't see the point in anything anymore."
def get_normal_example():
    """Returns a sample text demonstrating a healthy emotional state."""
    return "I had a pretty good day at work and looking forward to the weekend."


# === 4. ANALYSIS LOGIC ===

# Minimum length (after stripping surrounding whitespace) accepted by analyze_text.
_MIN_INPUT_CHARS = 5

# Per-label presentation: emoji for the result card and accent color for bars.
_STATUS_CONFIG = {
    "Anxiety": {"emoji": "😰", "color": "#D97706"},
    "Depression": {"emoji": "😔", "color": "#2563EB"},
    "Normal": {"emoji": "😊", "color": "#059669"},
    "Suicidal": {"emoji": "🆘", "color": "#DC2626"},
}
_DEFAULT_STATUS = {"emoji": "😐", "color": "#4B5563"}
_DEFAULT_BAR_COLOR = "#6B7280"

# Shown instead of a prediction when the input is too short to classify.
# Text colors use the theme CSS variables so the card follows the light/dark toggle.
_SHORT_INPUT_HTML = f"""
<div style="background: var(--bg-card); border: 1px solid #D97706; border-radius: 16px; padding: 24px; text-align: center;">
  <div style="font-size: 2rem; line-height: 1.2;">⚠️</div>
  <div style="color: var(--text-main); font-weight: 700; font-size: 1.2rem;">Input too short</div>
  <div style="color: var(--text-sub);">Please enter at least {_MIN_INPUT_CHARS} characters.</div>
</div>
"""


def _render_bar(name, score, predicted_label):
    """Return the HTML row (name, percentage, colored bar) for one class."""
    pct = int(score * 100)
    bar_color = _STATUS_CONFIG.get(name, {}).get("color", _DEFAULT_BAR_COLOR)
    # Highlight the predicted class; dim the others.
    opacity = "1.0" if name == predicted_label else "0.5"
    return f"""
<div style="margin-bottom: 12px;">
  <div style="display: flex; justify-content: space-between; color: var(--text-main); font-weight: 600; margin-bottom: 4px;">
    <span>{name}</span><span>{pct}%</span>
  </div>
  <div style="background: var(--border-color); border-radius: 99px; height: 10px; overflow: hidden;">
    <div style="width: {pct}%; height: 100%; background: {bar_color}; opacity: {opacity};"></div>
  </div>
</div>"""


def _render_result_html(label, confidence, sorted_probs):
    """Build the result card: emoji, predicted label, confidence and one bar per class.

    Accent colors are inlined so they persist in both themes; text colors use
    the theme CSS variables so they follow the light/dark toggle.
    """
    cfg = _STATUS_CONFIG.get(label, _DEFAULT_STATUS)
    bars_html = "".join(_render_bar(name, score, label) for name, score in sorted_probs)
    return f"""
<div style="background: var(--bg-card); border: 1px solid var(--border-color); border-top: 6px solid {cfg['color']}; border-radius: 16px; padding: 24px;">
  <div style="text-align: center; margin-bottom: 20px;">
    <div style="font-size: 3rem; line-height: 1.2;">{cfg['emoji']}</div>
    <div style="font-size: 1.8rem; font-weight: 800; color: {cfg['color']};">{label}</div>
    <div style="color: var(--text-sub);">Confidence: {int(confidence * 100)}%</div>
  </div>
  <div style="color: var(--text-main); font-weight: 700; margin-bottom: 12px;">Detailed Breakdown</div>
  {bars_html}
</div>
"""


def _explain_tokens(inputs, outputs, top_k=3):
    """Return a comma-separated list of the tokens that receive the most attention.

    Each token's score is its last-layer attention averaged over heads and over
    query positions (the mean attention it receives). Special tokens,
    punctuation-only pieces and repeated words are skipped.
    """
    if not outputs.attentions:
        # Model did not return attention weights (e.g. a fused attention kernel).
        return "Analysis unavailable"

    # Per transformers docs: (batch, num_heads, seq_len, seq_len) per layer.
    attn = outputs.attentions[-1]
    scores = attn.mean(dim=1).mean(dim=1)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    special_tokens = set(tokenizer.all_special_tokens)

    ranked = sorted(zip(tokens, scores), key=lambda pair: pair[1], reverse=True)
    keywords = []
    for tok, _score in ranked:
        if tok in special_tokens:
            continue
        word = tok.replace("##", "")  # strip WordPiece continuation marker
        if word in keywords or not any(ch.isalnum() for ch in word):
            continue
        keywords.append(word)
        if len(keywords) == top_k:
            break
    return ", ".join(keywords) if keywords else "Analysis unavailable"


def analyze_text(text):
    """
    Analyzes text to detect mental health indicators (Anxiety, Depression, Normal, Suicidal).

    Parameters:
        text (str): Input text (min 5 chars).

    Returns:
        (html_output, explanation_text): Visual HTML dashboard and key token extraction.
    """
    if not text or len(text.strip()) < _MIN_INPUT_CHARS:
        return _SHORT_INPUT_HTML, "N/A"

    if model is None or tokenizer is None:
        return "Model not loaded.", "Error"

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.softmax(outputs.logits, dim=-1)[0]
    label_idx = int(torch.argmax(probs).item())
    label = id2label[label_idx]
    confidence = float(probs[label_idx])

    all_probs = {id2label[i]: float(probs[i]) for i in range(len(id2label))}
    sorted_probs = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)

    # The explanation is best-effort: never let it break the prediction itself.
    try:
        explanation = _explain_tokens(inputs, outputs)
    except Exception as err:
        print(f"⚠️ Explanation failed: {err}")
        explanation = "Analysis unavailable"

    return _render_result_html(label, confidence, sorted_probs), explanation


# === 5. UI LAYOUT ===
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Default()) as demo:
    # Theme Toggle Button (Floating); handled client-side by js_func.
    toggle_btn = gr.Button("🌗", elem_classes="theme-btn")

    with gr.Column(elem_classes="main-card"):
        gr.HTML(
            """
<h1>Mental Health Text Analyzer</h1>
<p class="subtitle">AI-powered analysis for Anxiety, Depression, and Suicidal tendencies</p>
"""
        )

        with gr.Tabs():
            # --- Tab 1: Analyze ---
            with gr.TabItem("Analyze Text"):
                input_text = gr.Textbox(
                    placeholder="Type or paste text here...",
                    lines=5,
                    label="Input Text",
                )

                gr.HTML(
                    '<p style="color: var(--text-sub); margin: 8px 0 4px;">Try an example:</p>'
                )
                with gr.Row():
                    ex_anxiety = gr.Button("😰 Anxiety", elem_classes="example-btn")
                    ex_depression = gr.Button("😔 Depression", elem_classes="example-btn")
                    ex_normal = gr.Button("😊 Normal", elem_classes="example-btn")

                analyze_btn = gr.Button("Analyze Mental Health", elem_classes="analyze-btn")

                html_result = gr.HTML(label="Analysis Result")
                explanation_box = gr.Textbox(
                    label="Key Keywords Detected",
                    interactive=False,
                    visible=True,
                )

            # --- Tab 2: About ---
            with gr.TabItem("About & Privacy"):
                categories_html = "".join(
                    f"<li><b>{name}</b></li>" for name in id2label.values()
                )
                gr.HTML(
                    f"""
<div class="about-text">
  <h3>🤖 How it works</h3>
  <p>This tool utilizes a fine-tuned BERT model. It analyzes semantic patterns to classify text.</p>
  <h3>📊 Detection Categories</h3>
  <ul>{categories_html}</ul>
  <p><b>Disclaimer:</b> This tool is for informational purposes only and is not a substitute
  for professional diagnosis or care. If you or someone you know may be in danger, contact
  local emergency services or a crisis line right away.</p>
</div>
"""
                )

    # === 6. EVENT WIRING (Including API Names for Documentation) ===
    # Listeners must be registered inside the Blocks context.

    # 1. Theme Toggle (pure JS, no Python round-trip)
    toggle_btn.click(None, js=js_func)

    # 2. Example Buttons (Named APIs for Doc)
    ex_anxiety.click(
        fn=get_anxiety_example,
        outputs=input_text,
        api_name="get_anxiety_example",
    )
    ex_depression.click(
        fn=get_depression_example,
        outputs=input_text,
        api_name="get_depression_example",
    )
    ex_normal.click(
        fn=get_normal_example,
        outputs=input_text,
        api_name="get_normal_example",
    )

    # 3. Main Analysis (Named API for Doc)
    analyze_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[html_result, explanation_box],
        api_name="predict",
    )


if __name__ == "__main__":
    print("🚀 Starting Gradio Server...")
    demo.launch(server_name="0.0.0.0", server_port=7860)