import logging
import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Hugging Face Hub identifier of the fine-tuned classifier served by this app.
MODEL_NAME = "ourafla/mental-health-bert-finetuned"

# Optional Hub access token taken from the environment; None when unset.
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")

print("⏳ Loading model... this may take a few seconds.")

# The model is loaded once, at import time. Any failure is reported and both
# globals are set to None so the module still imports; analyze_text() checks
# `model is None` before running inference.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
    # Attentions are requested because analyze_text() reads
    # `outputs.attentions` to pick the keywords it reports.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, token=hf_token, output_attentions=True
    )
    model.eval()
    print("✅ Model loaded successfully.")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    tokenizer = None
    model = None

# Class index -> display label.
# NOTE(review): hard-coded; presumably mirrors the checkpoint's label order —
# confirm against the model's config.
id2label = {0: "Anxiety", 1: "Depression", 2: "Normal", 3: "Suicidal"}


# JavaScript run in the browser by the theme button: flips the `dark` class on
# <body>, which the `body.dark` rule in CUSTOM_CSS uses to swap the colour
# variables.
js_func = """
function toggleTheme() {
    const body = document.querySelector('body');
    if (body.classList.contains('dark')) {
        body.classList.remove('dark');
    } else {
        body.classList.add('dark');
    }
}
"""


# Stylesheet passed to gr.Blocks(css=...). Colours are defined as CSS variables
# on :root (light mode) and overridden under `body.dark` (dark mode); every
# other rule reads the variables, so toggling the `dark` class re-themes the
# whole page.
CUSTOM_CSS = """
/* === COLOR VARIABLES === */
:root {
    /* Light Mode (Default) */
    --bg-app: #F3F4F6;
    --bg-card: #FFFFFF;
    --bg-input: #F9FAFB;
    --border-color: #E5E7EB;
    --text-main: #111827;
    --text-sub: #4B5563;
    --primary: #2563EB;
    --primary-hover: #1D4ED8;
    --accent-bg: #EFF6FF;

    /* API Modal Specifics - Light */
    --modal-bg: #FFFFFF;
    --code-bg: #F3F4F6;
    --code-text: #1F2937;
}

body.dark {
    /* Dark Mode */
    --bg-app: #0B0F19;       /* Very dark blue/black */
    --bg-card: #111827;      /* Dark Slate */
    --bg-input: #1F2937;     /* Darker Slate */
    --border-color: #374151; /* Medium Gray */
    --text-main: #F9FAFB;    /* Almost White */
    --text-sub: #D1D5DB;     /* Light Gray */
    --primary: #3B82F6;      /* Bright Blue */
    --primary-hover: #60A5FA;
    --accent-bg: #1E293B;

    /* API Modal Specifics - Dark */
    --modal-bg: #111827;
    --code-bg: #020617;
    --code-text: #E2E8F0;
}

/* Global Background Application */
body, .gradio-container {
    background-color: var(--bg-app) !important;
    color: var(--text-main) !important;
    transition: background-color 0.3s, color 0.3s;
    font-family: 'Inter', system-ui, sans-serif !important;
}

/* === MAIN COMPONENT STYLES === */

/* Card Container */
.main-card {
    background-color: var(--bg-card) !important;
    border: 1px solid var(--border-color) !important;
    border-radius: 16px !important;
    padding: 32px !important;
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1) !important;
    max-width: 900px !important;
    margin: 20px auto !important;
}

/* Header & Typography */
h1 {
    color: var(--text-main) !important;
    font-weight: 800 !important;
    text-align: center;
}

.subtitle {
    color: var(--text-sub) !important;
    text-align: center;
    font-size: 1.1rem;
    margin-bottom: 2rem;
}

/* Theme Toggle Button */
.theme-btn {
    position: absolute !important;
    top: 20px !important;
    right: 20px !important;
    background: var(--bg-card) !important;
    border: 1px solid var(--border-color) !important;
    color: var(--text-main) !important;
    border-radius: 50% !important;
    width: 40px !important;
    height: 40px !important;
    padding: 0 !important;
    display: flex !important;
    align-items: center !important;
    justify-content: center !important;
    font-size: 1.2rem !important;
    cursor: pointer !important;
    z-index: 100 !important;
}

/* Tabs */
.tabs {
    border-bottom: 2px solid var(--border-color) !important;
    background: transparent !important;
    margin-bottom: 24px !important;
}
.tab-nav button {
    color: var(--text-sub) !important;
    font-weight: 600 !important;
}
.tab-nav button.selected {
    color: var(--primary) !important;
    border-bottom-color: var(--primary) !important;
}

/* Input Fields */
textarea {
    background-color: var(--bg-input) !important;
    border: 1px solid var(--border-color) !important;
    color: var(--text-main) !important;
    border-radius: 12px !important;
}
textarea::placeholder {
    color: var(--text-sub) !important;
    opacity: 0.7;
}

/* Read-only Textbox (Explanation) */
textarea:read-only {
    background-color: var(--accent-bg) !important;
    border-color: var(--border-color) !important;
}

/* Action Button */
.analyze-btn {
    background: linear-gradient(135deg, var(--primary) 0%, var(--primary-hover) 100%) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 16px !important;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1) !important;
}

/* Example Chips */
.example-btn {
    background-color: var(--bg-card) !important;
    border: 1px solid var(--border-color) !important;
    color: var(--text-main) !important;
    border-radius: 99px !important;
}
.example-btn:hover {
    border-color: var(--primary) !important;
    color: var(--primary) !important;
}

/* About Text */
.about-text p, .about-text li {
    color: var(--text-sub) !important;
}
.about-text h3, .about-text b {
    color: var(--text-main) !important;
}

/* === API DOCUMENTATION MODAL FIXES === */
/* Force the API Modal to respect our theme variables */
.gradio-container div[role="dialog"],
div[data-testid="settings-dialog"],
.modal {
    background-color: var(--modal-bg) !important;
    border: 1px solid var(--border-color) !important;
}

/* Text inside modal */
div[role="dialog"] h2, div[role="dialog"] h3, div[role="dialog"] p,
div[role="dialog"] label, div[role="dialog"] span, div[role="dialog"] td {
    color: var(--text-main) !important;
}

/* Code Blocks in API Modal */
div[role="dialog"] pre, div[role="dialog"] code {
    background-color: var(--code-bg) !important;
    color: var(--code-text) !important;
    border: 1px solid var(--border-color) !important;
}

/* Close button in modal */
div[role="dialog"] button.close {
    color: var(--text-main) !important;
}
"""


def get_anxiety_example() -> str:
    """Return the sample sentence loaded by the "Anxiety" example button."""
    return "I feel incredibly anxious and my heart won't stop racing."


def get_depression_example() -> str:
    """Return the sample sentence loaded by the "Depression" example button."""
    return "I wake up feeling empty and I don't see the point in anything anymore."


def get_normal_example() -> str:
    """Return the sample sentence loaded by the "Normal" example button."""
    return "I had a pretty good day at work and looking forward to the weekend."


_logger = logging.getLogger(__name__)

# Tokenizer control tokens that must never be reported as keywords.
_SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

# Emoji and accent colour shown for each predicted label.
_STATUS_CONFIG = {
    "Anxiety": {"emoji": "😰", "color": "#D97706"},
    "Depression": {"emoji": "😔", "color": "#2563EB"},
    "Normal": {"emoji": "😊", "color": "#059669"},
    "Suicidal": {"emoji": "🆘", "color": "#DC2626"},
}

# Fallback presentation for a label that has no entry in _STATUS_CONFIG.
_DEFAULT_STATUS = {"emoji": "😐", "color": "#4B5563"}

# Static warning card returned when the input is missing or too short.
_INPUT_TOO_SHORT_HTML = """
<div style='background: #FEF2F2; border: 1px solid #FECACA; border-radius: 12px; padding: 24px; text-align: center;'>
    <div style='font-size: 3rem; margin-bottom: 16px;'>⚠️</div>
    <div style='color: #991B1B; font-weight: 800; font-size: 1.25rem; margin-bottom: 8px;'>Input too short</div>
    <div style='color: #7F1D1D; font-size: 1rem; font-weight: 500;'>Please enter at least 5 characters.</div>
</div>
"""


def _percent(probability):
    """Convert a probability to a whole-number percentage, rounding to nearest."""
    # round() rather than int(): truncation turns 0.29 * 100 == 28.999... into 28.
    return round(probability * 100)


def _top_attention_tokens(outputs, inputs, limit=3):
    """Return up to `limit` non-special tokens with the highest attention score."""
    attn = outputs.attentions[-1]
    # Average over dim 1 twice, then take the first batch element.
    # NOTE(review): assumes the attention tensor is (batch, heads, seq, seq),
    # which would make this one score per token — confirm against the
    # Transformers model output documentation.
    scores = attn.mean(dim=1).mean(dim=1)[0]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    scored = [
        (tok, score.item())
        for tok, score in zip(tokens, scores)
        if tok not in _SPECIAL_TOKENS
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    # Strip the "##" prefix from each reported token so it reads as plain text.
    return [tok.replace("##", "") for tok, _ in scored[:limit]]


def _render_bars(sorted_probs, label):
    """Build one HTML probability bar per class; the predicted class is opaque."""
    rows = []
    for name, score in sorted_probs:
        pct = _percent(score)
        bar_color = _STATUS_CONFIG.get(name, {}).get("color", "#6B7280")
        opacity = "1.0" if name == label else "0.5"
        rows.append(f"""
        <div style='margin-bottom: 16px;'>
            <div style='display: flex; justify-content: space-between; font-size: 0.95rem; margin-bottom: 6px;'>
                <span style='font-weight: 700; color: var(--text-main);'>{name}</span>
                <span style='font-weight: 700; color: var(--text-sub);'>{pct}%</span>
            </div>
            <div style='background: var(--border-color); border-radius: 8px; height: 12px; width: 100%; overflow: hidden;'>
                <div style='background: {bar_color}; opacity: {opacity}; width: {pct}%; height: 100%; border-radius: 8px;'></div>
            </div>
        </div>
        """)
    return "".join(rows)


def _render_result(label, confidence, sorted_probs):
    """Build the result card: headline label, confidence and per-class bars."""
    cfg = _STATUS_CONFIG.get(label, _DEFAULT_STATUS)
    color = cfg["color"]
    emoji = cfg["emoji"]
    bars_html = _render_bars(sorted_probs, label)
    return f"""
    <div class='result-container' style='border-top: 6px solid {color}; background: var(--bg-card); padding: 32px; border-radius: 16px; border: 1px solid var(--border-color);'>
        <div style='display: flex; align-items: center; justify-content: center; flex-direction: column; padding-bottom: 32px; border-bottom: 2px solid var(--border-color);'>
            <div style='font-size: 4rem; margin-bottom: 16px;'>{emoji}</div>
            <div style='font-size: 2rem; font-weight: 800; color: var(--text-main); margin-bottom: 8px;'>{label}</div>
            <div style='color: var(--text-sub); font-size: 1.1rem; font-weight: 600;'>Confidence: {_percent(confidence)}%</div>
        </div>
        <div style='margin-top: 32px;'>
            <div style='font-size: 1.1rem; font-weight: 800; color: var(--text-main); margin-bottom: 24px;'>Detailed Breakdown</div>
            {bars_html}
        </div>
    </div>
    """


def analyze_text(text):
    """
    Classify text as Anxiety, Depression, Normal or Suicidal.

    Parameters:
        text (str): Input text; must contain at least 5 characters after
            stripping surrounding whitespace.

    Returns:
        (html_output, explanation_text): an HTML result card and a
        comma-separated list of the tokens with the highest attention score.
        For input that is empty or too short the pair is a warning card and
        "N/A"; when the model failed to load it is
        ("Model not loaded.", "Error").
    """
    if not text or len(text.strip()) < 5:
        return _INPUT_TOO_SHORT_HTML, "N/A"

    # Both globals are set to None together when loading fails; checking the
    # tokenizer too keeps a partial failure from raising below.
    if model is None or tokenizer is None:
        return "Model not loaded.", "Error"

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.softmax(outputs.logits, dim=-1)[0]
    label_idx = torch.argmax(probs).item()
    label = id2label[label_idx]
    confidence = float(probs[label_idx])

    all_probs = {name: float(probs[idx]) for idx, name in id2label.items()}
    sorted_probs = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)

    # Keyword extraction is best-effort: a failure here must not discard the
    # classification, but it is logged instead of being swallowed silently.
    try:
        explanation = ", ".join(_top_attention_tokens(outputs, inputs))
    except Exception:
        _logger.exception("Attention-based keyword extraction failed")
        explanation = "Analysis unavailable"

    return _render_result(label, confidence, sorted_probs), explanation


# UI definition. Components are created inside the Blocks context; the event
# wiring at the bottom connects them to the handler functions.
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Default()) as demo:

    # Theme switch, positioned by the `.theme-btn` rule in CUSTOM_CSS.
    toggle_btn = gr.Button("🌓", elem_classes="theme-btn")

    with gr.Column(elem_classes="main-card"):

        gr.HTML("""
            <h1>Mental Health Text Analyzer</h1>
            <div class="subtitle">AI-powered analysis for Anxiety, Depression, and Suicidal tendencies</div>
        """)

        with gr.Tabs():

            # --- Tab 1: text input, example buttons and results ---
            with gr.TabItem("Analyze Text"):

                input_text = gr.Textbox(
                    placeholder="Type or paste text here...",
                    lines=5,
                    label="Input Text",
                )

                gr.HTML("<div style='font-size: 0.95rem; color: var(--text-sub); font-weight: 600; margin-bottom: 12px; margin-top: 20px;'>Try an example:</div>")

                with gr.Row():
                    ex_anxiety = gr.Button("😰 Anxiety", elem_classes="example-btn")
                    ex_depression = gr.Button("😔 Depression", elem_classes="example-btn")
                    ex_normal = gr.Button("😊 Normal", elem_classes="example-btn")

                analyze_btn = gr.Button("Analyze Mental Health", elem_classes="analyze-btn")

                html_result = gr.HTML(label="Analysis Result")

                explanation_box = gr.Textbox(
                    label="Key Keywords Detected",
                    interactive=False,
                    visible=True
                )

            # --- Tab 2: static description of the tool ---
            with gr.TabItem("About & Privacy"):
                gr.HTML("""
                    <div class="about-text">
                        <h3>🤖 How it works</h3>
                        <p>This tool utilizes a fine-tuned BERT model. It analyzes semantic patterns to classify text.</p>
                        <h3>🔍 Detection Categories</h3>
                        <ul>
                            <li><b>Anxiety:</b> Signs of worry or excessive concern.</li>
                            <li><b>Depression:</b> Signs of sadness or hopelessness.</li>
                            <li><b>Normal:</b> Neutral or positive state.</li>
                            <li><b>Suicidal:</b> Warning signs of self-harm.</li>
                        </ul>
                    </div>
                """)

    # --- Event wiring ---

    # The theme toggle has no Python handler (fn=None); only js_func is passed.
    toggle_btn.click(None, js=js_func)

    # Each example button writes its sample sentence into the input box.
    ex_anxiety.click(
        fn=get_anxiety_example,
        outputs=input_text,
        api_name="get_anxiety_example"
    )
    ex_depression.click(
        fn=get_depression_example,
        outputs=input_text,
        api_name="get_depression_example"
    )
    ex_normal.click(
        fn=get_normal_example,
        outputs=input_text,
        api_name="get_normal_example"
    )

    # Main action: classify the input and fill the result card and keyword box.
    analyze_btn.click(
        fn=analyze_text,
        inputs=input_text,
        outputs=[html_result, explanation_box],
        api_name="predict"
    )


if __name__ == "__main__":
    print("🚀 Starting Gradio Server...")
    # Listen on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)