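# Residue from the Hugging Face file viewer, kept as comments so the module parses:
# author: ourafla ("ourafla's picture") · commit bd8cc3c "Update app.py" (verified)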
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# === 1. Model Setup ===
MODEL_NAME = "ourafla/mental-health-bert-finetuned"
# Read from the environment; may be None when the variable is unset.
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
print("⏳ Loading model... this may take a few seconds.")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
    # output_attentions=True makes the forward pass also return per-layer
    # attention weights; analyze_text() reads the last layer for keywords.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, token=hf_token, output_attentions=True
    )
    model.eval()  # inference only: disables dropout
    print("✅ Model loaded successfully.")
except Exception as e:
    # Keep the app importable even if loading fails; both handles are set to
    # None so analyze_text() can detect the failure and report it in the UI.
    print(f"❌ Error loading model: {e}")
    tokenizer = None
    model = None
# Class index -> display name.
# NOTE(review): assumes this order matches the fine-tuned head's logits;
# confirm against model.config.id2label of the published checkpoint.
id2label = {0: "Anxiety", 1: "Depression", 2: "Normal", 3: "Suicidal"}
# === 2. CSS & THEME VARIABLES ===
# Browser-side handler for the theme toggle button. It adds the 'dark' class
# to <body> when absent and removes it when present, which switches between
# the light (:root) and dark (body.dark) variable sets in CUSTOM_CSS.
js_func = "\n".join(
    (
        "",
        "function toggleTheme() {",
        "const body = document.querySelector('body');",
        "if (body.classList.contains('dark')) {",
        "body.classList.remove('dark');",
        "} else {",
        "body.classList.add('dark');",
        "}",
        "}",
        "",
    )
)
# Stylesheet passed to gr.Blocks(css=...).
# All colors are CSS custom properties: :root holds the light palette and
# body.dark overrides the same names, so toggling the 'dark' class on <body>
# (see js_func) re-themes every rule that reads var(--...). The result HTML
# built in analyze_text() also uses these variables for its text colors.
# The pervasive !important is presumably there to win over Gradio's own
# theme styles.
# NOTE(review): the dialog/modal selectors target Gradio's internal DOM
# (role="dialog", data-testid="settings-dialog", .modal); confirm they still
# match after Gradio upgrades.
CUSTOM_CSS = """
/* === COLOR VARIABLES === */
:root {
/* Light Mode (Default) */
--bg-app: #F3F4F6;
--bg-card: #FFFFFF;
--bg-input: #F9FAFB;
--border-color: #E5E7EB;
--text-main: #111827;
--text-sub: #4B5563;
--primary: #2563EB;
--primary-hover: #1D4ED8;
--accent-bg: #EFF6FF;
/* API Modal Specifics - Light */
--modal-bg: #FFFFFF;
--code-bg: #F3F4F6;
--code-text: #1F2937;
}
body.dark {
/* Dark Mode */
--bg-app: #0B0F19; /* Very dark blue/black */
--bg-card: #111827; /* Dark Slate */
--bg-input: #1F2937; /* Darker Slate */
--border-color: #374151; /* Medium Gray */
--text-main: #F9FAFB; /* Almost White */
--text-sub: #D1D5DB; /* Light Gray */
--primary: #3B82F6; /* Bright Blue */
--primary-hover: #60A5FA;
--accent-bg: #1E293B;
/* API Modal Specifics - Dark */
--modal-bg: #111827;
--code-bg: #020617;
--code-text: #E2E8F0;
}
/* Global Background Application */
body, .gradio-container {
background-color: var(--bg-app) !important;
color: var(--text-main) !important;
transition: background-color 0.3s, color 0.3s;
font-family: 'Inter', system-ui, sans-serif !important;
}
/* === MAIN COMPONENT STYLES === */
/* Card Container */
.main-card {
background-color: var(--bg-card) !important;
border: 1px solid var(--border-color) !important;
border-radius: 16px !important;
padding: 32px !important;
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1) !important;
max-width: 900px !important;
margin: 20px auto !important;
}
/* Header & Typography */
h1 {
color: var(--text-main) !important;
font-weight: 800 !important;
text-align: center;
}
.subtitle {
color: var(--text-sub) !important;
text-align: center;
font-size: 1.1rem;
margin-bottom: 2rem;
}
/* Theme Toggle Button */
.theme-btn {
position: absolute !important;
top: 20px !important;
right: 20px !important;
background: var(--bg-card) !important;
border: 1px solid var(--border-color) !important;
color: var(--text-main) !important;
border-radius: 50% !important;
width: 40px !important;
height: 40px !important;
padding: 0 !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-size: 1.2rem !important;
cursor: pointer !important;
z-index: 100 !important;
}
/* Tabs */
.tabs {
border-bottom: 2px solid var(--border-color) !important;
background: transparent !important;
margin-bottom: 24px !important;
}
.tab-nav button {
color: var(--text-sub) !important;
font-weight: 600 !important;
}
.tab-nav button.selected {
color: var(--primary) !important;
border-bottom-color: var(--primary) !important;
}
/* Input Fields */
textarea {
background-color: var(--bg-input) !important;
border: 1px solid var(--border-color) !important;
color: var(--text-main) !important;
border-radius: 12px !important;
}
textarea::placeholder {
color: var(--text-sub) !important;
opacity: 0.7;
}
/* Read-only Textbox (Explanation) */
textarea:read-only {
background-color: var(--accent-bg) !important;
border-color: var(--border-color) !important;
}
/* Action Button */
.analyze-btn {
background: linear-gradient(135deg, var(--primary) 0%, var(--primary-hover) 100%) !important;
color: white !important;
font-weight: 700 !important;
border: none !important;
border-radius: 12px !important;
padding: 16px !important;
box-shadow: 0 4px 6px rgba(0,0,0,0.1) !important;
}
/* Example Chips */
.example-btn {
background-color: var(--bg-card) !important;
border: 1px solid var(--border-color) !important;
color: var(--text-main) !important;
border-radius: 99px !important;
}
.example-btn:hover {
border-color: var(--primary) !important;
color: var(--primary) !important;
}
/* About Text */
.about-text p, .about-text li {
color: var(--text-sub) !important;
}
.about-text h3, .about-text b {
color: var(--text-main) !important;
}
/* === API DOCUMENTATION MODAL FIXES === */
/* Force the API Modal to respect our theme variables */
.gradio-container div[role="dialog"],
div[data-testid="settings-dialog"],
.modal {
background-color: var(--modal-bg) !important;
border: 1px solid var(--border-color) !important;
}
/* Text inside modal */
div[role="dialog"] h2, div[role="dialog"] h3, div[role="dialog"] p,
div[role="dialog"] label, div[role="dialog"] span, div[role="dialog"] td {
color: var(--text-main) !important;
}
/* Code Blocks in API Modal */
div[role="dialog"] pre, div[role="dialog"] code {
background-color: var(--code-bg) !important;
color: var(--code-text) !important;
border: 1px solid var(--border-color) !important;
}
/* Close button in modal */
div[role="dialog"] button.close {
color: var(--text-main) !important;
}
"""
# === 3. HELPER FUNCTIONS (WITH DOCSTRINGS) ===
# These helpers are registered as named API endpoints (api_name=...), so their
# docstrings are meant to surface in the generated API documentation.
def get_anxiety_example() -> str:
    """Return a sample text that demonstrates anxiety symptoms."""
    sample_text = "I feel incredibly anxious and my heart won't stop racing."
    return sample_text
def get_depression_example() -> str:
    """Return a sample text that demonstrates depression symptoms."""
    sample_text = "I wake up feeling empty and I don't see the point in anything anymore."
    return sample_text
def get_normal_example() -> str:
    """Return a sample text that demonstrates a healthy emotional state."""
    sample_text = "I had a pretty good day at work and looking forward to the weekend."
    return sample_text
# === 4. ANALYSIS LOGIC ===
# Per-label presentation: emoji for the headline and accent/bar color.
_STATUS_CONFIG = {
    "Anxiety": {"emoji": "😰", "color": "#D97706"},
    "Depression": {"emoji": "😔", "color": "#2563EB"},
    "Normal": {"emoji": "😊", "color": "#059669"},
    "Suicidal": {"emoji": "🆘", "color": "#DC2626"},
}
# Used for labels with no entry above (e.g. a model with extra classes).
_FALLBACK_STATUS = {"emoji": "😐", "color": "#4B5563"}
_FALLBACK_BAR_COLOR = "#6B7280"

# Shown instead of a prediction when the input is empty or shorter than 5 chars.
_SHORT_INPUT_HTML = """
<div style='background: #FEF2F2; border: 1px solid #FECACA; border-radius: 12px; padding: 24px; text-align: center;'>
<div style='font-size: 3rem; margin-bottom: 16px;'>⚠️</div>
<div style='color: #991B1B; font-weight: 800; font-size: 1.25rem; margin-bottom: 8px;'>Input too short</div>
<div style='color: #7F1D1D; font-size: 1rem; font-weight: 500;'>Please enter at least 5 characters.</div>
</div>
"""


def _extract_keywords(outputs, input_ids, top_k=3):
    """Return the ``top_k`` input tokens that receive the most attention.

    The score of each token is the last encoder layer's attention averaged
    over heads and over query positions. Special tokens, punctuation-only
    tokens and repeats are skipped; WordPiece ``##`` prefixes are stripped.

    Parameters:
        outputs: model output carrying ``attentions`` (requested at load time
            via output_attentions=True).
        input_ids: token-id tensor of the encoded batch; only row 0 is used.
        top_k (int): maximum number of keywords to return.
    Returns:
        str: comma-separated keywords, "No keywords detected" if every token
        was filtered out, or "Analysis unavailable" if attention weights could
        not be read.
    """
    try:
        # NOTE(review): assumes HF layout (batch, heads, query_len, key_len);
        # averaging dim 1 twice leaves one score per key token per batch item.
        last_layer = outputs.attentions[-1]
        scores = last_layer.mean(dim=1).mean(dim=1)[0].tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        special_tokens = set(tokenizer.all_special_tokens)
    except Exception as exc:  # best-effort: never fail the prediction over the explanation
        print(f"⚠️ Keyword extraction failed: {exc}")
        return "Analysis unavailable"

    # Stable sort: equal scores keep their left-to-right order.
    ranked = sorted(zip(tokens, scores), key=lambda pair: pair[1], reverse=True)
    keywords = []
    for token, _score in ranked:
        if token in special_tokens:
            continue
        word = token.replace("##", "")
        # A bare punctuation mark or a repeated word is not a useful keyword.
        if not any(ch.isalnum() for ch in word) or word in keywords:
            continue
        keywords.append(word)
        if len(keywords) == top_k:
            break
    return ", ".join(keywords) if keywords else "No keywords detected"


def _render_result_html(label, confidence, sorted_probs):
    """Build the result card: predicted label, confidence and per-class bars.

    Parameters:
        label (str): predicted class name (highlighted bar, headline).
        confidence (float): softmax probability of ``label``, in [0, 1].
        sorted_probs (list[tuple[str, float]]): (class name, probability)
            pairs, highest first; one bar is drawn per pair.
    Returns:
        str: HTML fragment. Text colors use the page's CSS variables so they
        follow the light/dark theme; label colors are inline and fixed.
    """
    cfg = _STATUS_CONFIG.get(label, _FALLBACK_STATUS)
    bars = []
    for name, score in sorted_probs:
        pct = int(score * 100)
        bar_color = _STATUS_CONFIG.get(name, {}).get("color", _FALLBACK_BAR_COLOR)
        opacity = "1.0" if name == label else "0.5"  # dim non-predicted classes
        bars.append(f"""
<div style='margin-bottom: 16px;'>
<div style='display: flex; justify-content: space-between; font-size: 0.95rem; margin-bottom: 6px;'>
<span style='font-weight: 700; color: var(--text-main);'>{name}</span>
<span style='font-weight: 700; color: var(--text-sub);'>{pct}%</span>
</div>
<div style='background: var(--border-color); border-radius: 8px; height: 12px; width: 100%; overflow: hidden;'>
<div style='background: {bar_color}; opacity: {opacity}; width: {pct}%; height: 100%; border-radius: 8px;'></div>
</div>
</div>
""")
    bars_html = "".join(bars)
    return f"""
<div class='result-container' style='border-top: 6px solid {cfg['color']}; background: var(--bg-card); padding: 32px; border-radius: 16px; border: 1px solid var(--border-color);'>
<div style='display: flex; align-items: center; justify-content: center; flex-direction: column; padding-bottom: 32px; border-bottom: 2px solid var(--border-color);'>
<div style='font-size: 4rem; margin-bottom: 16px;'>{cfg['emoji']}</div>
<div style='font-size: 2rem; font-weight: 800; color: var(--text-main); margin-bottom: 8px;'>{label}</div>
<div style='color: var(--text-sub); font-size: 1.1rem; font-weight: 600;'>Confidence: {int(confidence*100)}%</div>
</div>
<div style='margin-top: 32px;'>
<div style='font-size: 1.1rem; font-weight: 800; color: var(--text-main); margin-bottom: 24px;'>Detailed Breakdown</div>
{bars_html}
</div>
</div>
"""


def analyze_text(text):
    """
    Analyzes text to detect mental health indicators (Anxiety, Depression, Normal, Suicidal).

    Parameters:
        text (str): Input text (min 5 chars after trimming whitespace).
    Returns:
        (html_output, explanation_text): Visual HTML dashboard and key token extraction.
    """
    if not text or len(text.strip()) < 5:
        return _SHORT_INPUT_HTML, "N/A"
    if model is None or tokenizer is None:
        # Model loading failed at startup; report it instead of crashing.
        return "Model not loaded.", "Error"

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.softmax(outputs.logits, dim=-1)[0]
    label_idx = int(torch.argmax(probs))
    # Name every class the model actually outputs; unknown indices get a
    # generic name instead of raising (and no class is silently dropped).
    names = [id2label.get(i, f"Label {i}") for i in range(probs.shape[-1])]
    label = names[label_idx]
    confidence = float(probs[label_idx])
    sorted_probs = sorted(zip(names, probs.tolist()), key=lambda item: item[1], reverse=True)

    explanation = _extract_keywords(outputs, inputs["input_ids"])
    return _render_result_html(label, confidence, sorted_probs), explanation
# === 5. UI LAYOUT ===
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Default()) as demo:
    # Floating theme toggle; positioned by the .theme-btn rule in CUSTOM_CSS.
    toggle_btn = gr.Button("🌗", elem_classes="theme-btn")
    with gr.Column(elem_classes="main-card"):
        gr.HTML("""
<h1>Mental Health Text Analyzer</h1>
<div class="subtitle">AI-powered analysis for Anxiety, Depression, and Suicidal tendencies</div>
""")
        with gr.Tabs():
            # --- Tab 1: Analyze ---
            with gr.TabItem("Analyze Text"):
                input_text = gr.Textbox(
                    placeholder="Type or paste text here...",
                    lines=5,
                    label="Input Text",
                )
                gr.HTML("<div style='font-size: 0.95rem; color: var(--text-sub); font-weight: 600; margin-bottom: 12px; margin-top: 20px;'>Try an example:</div>")
                with gr.Row():
                    ex_anxiety = gr.Button("😰 Anxiety", elem_classes="example-btn")
                    ex_depression = gr.Button("😔 Depression", elem_classes="example-btn")
                    ex_normal = gr.Button("😊 Normal", elem_classes="example-btn")
                analyze_btn = gr.Button("Analyze Mental Health", elem_classes="analyze-btn")
                html_result = gr.HTML(label="Analysis Result")
                explanation_box = gr.Textbox(
                    label="Key Keywords Detected",
                    interactive=False,
                    visible=True,
                )
            # --- Tab 2: About ---
            with gr.TabItem("About & Privacy"):
                gr.HTML("""
<div class="about-text">
<h3>🤖 How it works</h3>
<p>This tool utilizes a fine-tuned BERT model. It analyzes semantic patterns to classify text.</p>
<h3>📊 Detection Categories</h3>
<ul>
<li><b>Anxiety:</b> Signs of worry or excessive concern.</li>
<li><b>Depression:</b> Signs of sadness or hopelessness.</li>
<li><b>Normal:</b> Neutral or positive state.</li>
<li><b>Suicidal:</b> Warning signs of self-harm.</li>
</ul>
<h3>⚠️ Disclaimer</h3>
<p>This tool is not a medical device and does not provide a diagnosis. If you or someone you know may be in danger, contact local emergency services or a crisis helpline right away.</p>
</div>
""")

    # === 6. EVENT WIRING (Including API Names for Documentation) ===
    # Listeners are attached inside the Blocks context, which Gradio requires.
    # 1. Theme toggle: runs js_func in the browser only (no Python fn).
    toggle_btn.click(None, js=js_func)
    # 2. Example buttons: fill the input box; each is a named API endpoint.
    ex_anxiety.click(
        fn=get_anxiety_example,
        outputs=input_text,
        api_name="get_anxiety_example",
    )
    ex_depression.click(
        fn=get_depression_example,
        outputs=input_text,
        api_name="get_depression_example",
    )
    ex_normal.click(
        fn=get_normal_example,
        outputs=input_text,
        api_name="get_normal_example",
    )
    # 3. Main analysis, exposed as the "predict" endpoint.
    analyze_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[html_result, explanation_box],
        api_name="predict",
    )
if __name__ == "__main__":
    print("🚀 Starting Gradio Server...")
    # Passing explicit values would override Gradio's own GRADIO_SERVER_NAME /
    # GRADIO_SERVER_PORT settings, so read them here; the defaults match the
    # previous hard-coded 0.0.0.0:7860.
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
    )