# shihas-36's picture
# Update app.py
# 8fdcf77 verified
# =========================================================
# SOTA Multimodal Depression Prediction (HuggingFace Deploy)
# TwHIN-BERT (Text) + Wearable Features (Cross-Attention)
# Gradio 6.x compatible
# =========================================================
import os
import re
import functools
os.environ.setdefault("USE_TF", "0")
os.environ.setdefault("USE_TORCH", "1")
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
print("Gradio version:", gr.__version__)
# ---------- DEVICE ----------
# Prefer GPU when available; all tensors/models below are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# ---------- PATHS ----------
# Artifacts are expected in the working directory of the Space.
SAVE_PATH = "./"
# Pre-extracted wearable feature matrix used to fit the StandardScaler.
WEARABLE_PATH = os.path.join(SAVE_PATH, "module2_wearable_features.pt")
# Trained multimodal classifier checkpoint (state_dict).
SOTA_MODEL_PATH = os.path.join(SAVE_PATH, "best_model_87.pt")
# Hugging Face model id for the text encoder.
TWHIN_MODEL_NAME = "Twitter/twhin-bert-base"
# ---------- TEXT CLEANING ----------
def clean_text(text):
    """Normalize social-media text for the encoder.

    Lowercases the input and strips, in order: URLs, HTML tags,
    @mentions / #hashtags, and remaining punctuation, then collapses
    whitespace. Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""
    cleaned = text.lower()
    # Successive removal passes; order matters (handles before punctuation).
    for noise in (r'https?://\S+|www\.\S+', r'<.*?>', r'[@#]\w+', r'[^\w\s]'):
        cleaned = re.sub(noise, '', cleaned)
    # Squash runs of whitespace into single spaces and trim the ends.
    return re.sub(r'\s+', ' ', cleaned).strip()
# =========================================================
# REAL WORLD → DATASET NORMALIZATION
# =========================================================
def scale_value(x, real_max, dataset_max):
    """Project a real-world measurement onto the dataset's value range.

    The input is rescaled proportionally (``x / real_max * dataset_max``)
    and capped at ``dataset_max`` so that out-of-range live readings
    cannot drift outside the distribution the model was trained on.
    """
    rescaled = (x / real_max) * dataset_max
    # Cap at the dataset ceiling; values below it pass through unchanged.
    return min(rescaled, dataset_max)
# ---------- MODEL ----------
class SOTA_MultimodalModel(nn.Module):
    """Cross-attention fusion of text and wearable features.

    Both modalities are projected into a shared ``embed_dim`` space,
    each attends over the other (bidirectional cross-attention with a
    residual + LayerNorm), and the two fused vectors are concatenated
    and passed through an MLP head producing 2 class logits.
    """

    def __init__(self, s_dim, w_dim, embed_dim=256):
        super().__init__()
        # Per-modality projections into the shared embedding space.
        self.s_proj = nn.Linear(s_dim, embed_dim)
        self.w_proj = nn.Linear(w_dim, embed_dim)
        # 8-head cross-attention in both directions (batch-first layout).
        self.attn_s2w = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.attn_w2s = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        # Classification head over the concatenated fused representations.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * 2, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )

    def forward(self, s_feat, w_feat):
        # Add a singleton sequence axis so attention sees (batch, 1, embed).
        text_tok = self.s_proj(s_feat).unsqueeze(1)
        wear_tok = self.w_proj(w_feat).unsqueeze(1)
        # Each modality queries the other one.
        text_ctx = self.attn_s2w(text_tok, wear_tok, wear_tok)[0]
        wear_ctx = self.attn_w2s(wear_tok, text_tok, text_tok)[0]
        # Residual connection + LayerNorm, then drop the sequence axis.
        text_fused = self.ln1(text_tok + text_ctx).squeeze(1)
        wear_fused = self.ln2(wear_tok + wear_ctx).squeeze(1)
        return self.classifier(torch.cat([text_fused, wear_fused], dim=1))
# ---------- LOAD TEXT MODEL ----------
print("Loading TwHIN-BERT...")
tokenizer = AutoTokenizer.from_pretrained(TWHIN_MODEL_NAME)
twhin_model = AutoModel.from_pretrained(
TWHIN_MODEL_NAME,
output_hidden_states=True
).to(device).eval()
print("TwHIN-BERT loaded")
# ---------- LOAD WEARABLE FEATURES ----------
if not os.path.exists(WEARABLE_PATH):
raise FileNotFoundError(f"Missing: {WEARABLE_PATH}")
wearable_features = torch.load(WEARABLE_PATH, map_location=device)
if wearable_features.ndim == 1:
wearable_features = wearable_features.unsqueeze(0)
w_dim = wearable_features.shape[1]
scaler = StandardScaler()
scaler.fit(wearable_features.cpu().numpy())
print(f"Scaler ready w_dim={w_dim}")
# ---------- LOAD MULTIMODAL MODEL ----------
if not os.path.exists(SOTA_MODEL_PATH):
raise FileNotFoundError(f"Missing: {SOTA_MODEL_PATH}")
s_dim = 768
multimodal_model = SOTA_MultimodalModel(s_dim, w_dim)
multimodal_model.load_state_dict(
torch.load(SOTA_MODEL_PATH, map_location=device)
)
multimodal_model.to(device).eval()
print("Multimodal model loaded")
# =========================================================
# PREDICT IMPLEMENTATION
# =========================================================
def _run_predict(
    text,
    daily_steps,
    sleep_hours,
    active_minutes,
    daily_calories,
    hr_avg_24h,
    hrv_score_avg,
    stress_level_avg,
    sleep_hrv_score
):
    """Run one end-to-end inference from raw UI inputs.

    Steps: validate/coerce inputs, rescale real-world wearable readings
    into the dataset's range, encode the text with TwHIN-BERT, standardize
    the wearable vector with the pre-fitted scaler, classify with the
    multimodal model, and apply a post-hoc probability adjustment.

    Returns a pipe-delimited string
    ``"RESULT:<label>|HIGH_PROB:x|LOW_PROB:y"`` on success, or an
    ``"ERROR:..."`` string on invalid input.
    """
    print("\n================ NEW INFERENCE =================")
    # Text is mandatory; the wearable-only path is not supported.
    if not text or str(text).strip() == "":
        return "ERROR:Please enter social media text."
    # ---------- RAW INPUT ----------
    print("RAW USER INPUT")
    print({
        "text": text,
        "daily_steps": daily_steps,
        "sleep_hours": sleep_hours,
        "active_minutes": active_minutes,
        "daily_calories": daily_calories,
        "hr_avg_24h": hr_avg_24h,
        "hrv_score_avg": hrv_score_avg,
        "stress_level_avg": stress_level_avg,
        "sleep_hrv_score": sleep_hrv_score
    })
    # Coerce everything numeric; Gradio Number widgets may deliver None.
    try:
        daily_steps = float(daily_steps)
        sleep_hours = float(sleep_hours)
        active_minutes = float(active_minutes)
        daily_calories = float(daily_calories)
        hr_avg_24h = float(hr_avg_24h)
        hrv_score_avg = float(hrv_score_avg)
        stress_level_avg = float(stress_level_avg)
        sleep_hrv_score = float(sleep_hrv_score)
    except (ValueError, TypeError) as e:
        return f"ERROR:Invalid numeric input – {e}"
    # ---------- NORMALIZATION ----------
    # Map real-world magnitudes into the (much smaller) dataset ranges.
    # NOTE(review): the (real_max, dataset_max) pairs look hand-derived
    # from the training distribution — confirm against the dataset stats.
    # HRV/stress/sleep-HRV inputs are deliberately passed through unscaled.
    daily_steps = scale_value(daily_steps, 15000, 33)
    sleep_hours = scale_value(sleep_hours, 10, 5)
    hr_avg_24h = scale_value(hr_avg_24h, 70, 1.3)
    active_minutes = scale_value(active_minutes, 180, 1.8)
    daily_calories = scale_value(daily_calories, 3500, 18)
    print("AFTER DATASET NORMALIZATION")
    print({
        "daily_steps": daily_steps,
        "sleep_hours": sleep_hours,
        "hr_avg_24h": hr_avg_24h,
        "active_minutes": active_minutes,
        "daily_calories": daily_calories
    })
    # Feature order must match the column order the scaler was fitted on.
    nums = [
        daily_steps,
        sleep_hours,
        active_minutes,
        daily_calories,
        hr_avg_24h,
        hrv_score_avg,
        stress_level_avg,
        sleep_hrv_score
    ]
    print("FEATURE VECTOR BEFORE SCALER")
    print(nums)
    # ---------- TEXT FEATURES ----------
    with torch.no_grad():
        enc = tokenizer(
            clean_text(str(text)),
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        out = twhin_model(**enc)
        # Final-layer [CLS] token embedding as the sentence representation.
        s_feat = out.hidden_states[-1][:, 0, :]
    print("TEXT FEATURE SHAPE:", s_feat.shape)
    # ---------- WEARABLE FEATURES ----------
    # Standardize with the scaler fitted on the stored training features.
    scaled = scaler.transform([nums])
    print("AFTER STANDARD SCALER")
    print(scaled)
    w_feat = torch.tensor(
        scaled,
        dtype=torch.float32
    ).to(device)
    print("WEARABLE FEATURE TENSOR:", w_feat)
    # ---------- MODEL PREDICTION ----------
    with torch.no_grad():
        logits = multimodal_model(s_feat, w_feat)
        print("MODEL LOGITS:", logits)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
    print("SOFTMAX PROBABILITIES:", probs)
    # NOTE(review): assumes class index 0 == HIGH risk and 1 == LOW risk —
    # confirm against the label encoding used during training.
    high = probs[0]
    low = probs[1]
    # NOTE(review): hand-tuned post-hoc shift (±0.15 toward HIGH when LOW
    # wins, ±0.30 toward LOW otherwise). The adjusted pair may no longer
    # sum to 1 — this is a deliberate calibration hack, not a probability.
    if low > high:
        high = min(high + 0.150, 1.0)
        low = max(low - 0.150, 0.0)
    else:
        low = min(low + 0.30, 1.0)
        high = max(high - 0.30, 0.0)
    probs = [high, low]
    print("ADJUSTED PROBABILITIES:", probs)
    label = "HIGH RISK" if probs[0] > probs[1] else "LOW RISK"
    print("FINAL LABEL:", label)
    print("================================================\n")
    return f"RESULT:{label}|HIGH_PROB:{probs[0]:.4f}|LOW_PROB:{probs[1]:.4f}"
# =========================================================
# GRADIO WRAPPER
# =========================================================
class _PredictCallable:
    """Callable adapter handed to Gradio.

    Forwards all UI inputs to ``_run_predict`` and converts any uncaught
    exception into an ``ERROR:``-prefixed string so the interface never
    shows a raw traceback to the user.
    """

    def __call__(
        self,
        text: str,
        daily_steps: float,
        sleep_hours: float,
        active_minutes: float,
        daily_calories: float,
        hr_avg_24h: float,
        hrv_score_avg: float,
        stress_level_avg: float,
        sleep_hrv_score: float,
    ) -> str:
        try:
            return _run_predict(
                text, daily_steps, sleep_hours, active_minutes,
                daily_calories, hr_avg_24h, hrv_score_avg,
                stress_level_avg, sleep_hrv_score,
            )
        except Exception as exc:
            # Log the full traceback server-side, return a short UI message.
            import traceback
            traceback.print_exc()
            return f"ERROR:{exc}"
# Single shared instance used as the Gradio callback.
sota_predict = _PredictCallable()
# =========================================================
# GRADIO INTERFACE
# =========================================================
# Build the Gradio UI: one text box plus eight numeric wearable inputs,
# single text output carrying the pipe-delimited result string.
demo = gr.Interface(
    fn=sota_predict,
    inputs=[
        gr.Textbox(
            label="Social Media Text",
            lines=4,
            placeholder="e.g. I feel tired and alone lately...",
        ),
        gr.Number(label="Daily Steps", value=0),
        gr.Number(label="Sleep Hours", value=0),
        gr.Number(label="Active Minutes", value=0),
        gr.Number(label="Daily Calories", value=0),
        gr.Number(label="Average Heart Rate (24h)", value=0),
        gr.Number(label="HRV Score", value=0),
        gr.Number(label="Stress Level", value=0),
        gr.Number(label="Sleep HRV Score", value=0),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    title="🧠 Multimodal Depression Detection",
    description=(
        "Combines social media text and wearable sensor data "
        "via a cross-attention multimodal deep learning model.\n\n"
        "⚠️ For research purposes only – not a medical diagnosis tool."
    ),
    # NOTE(review): api_name on gr.Interface requires a recent Gradio
    # (file header says 6.x) — verify it is accepted by the pinned version.
    api_name="predict",
)
demo.launch()