# NOTE: removed HuggingFace Spaces page-status artifacts ("Spaces:" / "Sleeping")
# that were captured along with the source — they are not part of the program.
| # ========================================================= | |
| # SOTA Multimodal Depression Prediction (HuggingFace Deploy) | |
| # TwHIN-BERT (Text) + Wearable Features (Cross-Attention) | |
| # Gradio 6.x compatible | |
| # ========================================================= | |
import os
import re
import functools  # NOTE(review): unused in the visible code — confirm before removing
# Steer the transformers stack onto the PyTorch backend and silence TensorFlow
# logging. These must be set BEFORE torch/transformers are imported, because
# the env vars are read at import time.
os.environ.setdefault("USE_TF", "0")
os.environ.setdefault("USE_TORCH", "1")
os.environ.setdefault("TRANSFORMERS_NO_TF", "1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
print("Gradio version:", gr.__version__)
# ---------- DEVICE ----------
# Prefer CUDA when available; every model/tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# ---------- PATHS ----------
SAVE_PATH = "./"  # artifacts are expected next to the app in the Space repo
WEARABLE_PATH = os.path.join(SAVE_PATH, "module2_wearable_features.pt")
SOTA_MODEL_PATH = os.path.join(SAVE_PATH, "best_model_87.pt")
TWHIN_MODEL_NAME = "Twitter/twhin-bert-base"
| # ---------- TEXT CLEANING ---------- | |
# Precompiled once at import so repeated inference calls skip the per-call
# pattern-cache lookup inside the re module.
_URL_RE = re.compile(r'https?://\S+|www\.\S+')
_HTML_TAG_RE = re.compile(r'<.*?>')
_MENTION_TAG_RE = re.compile(r'[@#]\w+')
_NON_WORD_RE = re.compile(r'[^\w\s]')
_WHITESPACE_RE = re.compile(r'\s+')


def clean_text(text):
    """Normalize raw social-media text for tokenization.

    Lowercases, then strips URLs, HTML tags, @mentions/#hashtags and all
    punctuation, and finally collapses runs of whitespace to single spaces.

    Args:
        text: The raw post text. Non-string input yields "".

    Returns:
        The cleaned, lowercase string (possibly empty).
    """
    if not isinstance(text, str):
        return ""
    text = text.lower()
    text = _URL_RE.sub('', text)
    text = _HTML_TAG_RE.sub('', text)
    text = _MENTION_TAG_RE.sub('', text)
    text = _NON_WORD_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()
| # ========================================================= | |
| # REAL WORLD → DATASET NORMALIZATION | |
| # ========================================================= | |
def scale_value(x, real_max, dataset_max):
    """Project a real-world measurement onto the dataset's value range.

    Linearly rescales ``x`` from ``[0, real_max]`` to ``[0, dataset_max]``
    and caps the result at ``dataset_max`` so out-of-range user input cannot
    push a feature outside the distribution the model was trained on.
    """
    rescaled = (x / real_max) * dataset_max
    return rescaled if rescaled <= dataset_max else dataset_max
| # ---------- MODEL ---------- | |
class SOTA_MultimodalModel(nn.Module):
    """Cross-attention fusion of a text embedding and a wearable feature vector.

    Each modality is projected to a shared ``embed_dim``, attends to the
    other (text→wearable and wearable→text), is residual-fused with
    LayerNorm, then both fused vectors are concatenated and classified
    into two logits.
    """

    def __init__(self, s_dim, w_dim, embed_dim=256):
        super().__init__()
        # NOTE: attribute names and creation order are fixed — they are the
        # keys of the checkpoint's state_dict.
        self.s_proj = nn.Linear(s_dim, embed_dim)
        self.w_proj = nn.Linear(w_dim, embed_dim)
        self.attn_s2w = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.attn_w2s = nn.MultiheadAttention(embed_dim, 8, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * 2, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
        )

    def forward(self, s_feat, w_feat):
        """Fuse one text vector and one wearable vector per sample.

        Args:
            s_feat: ``(batch, s_dim)`` text embedding.
            w_feat: ``(batch, w_dim)`` wearable feature vector.

        Returns:
            ``(batch, 2)`` classification logits.
        """
        # Treat each projected modality as a length-1 sequence for attention.
        text_seq = self.s_proj(s_feat).unsqueeze(1)
        wear_seq = self.w_proj(w_feat).unsqueeze(1)
        text_ctx = self.attn_s2w(text_seq, wear_seq, wear_seq)[0]
        wear_ctx = self.attn_w2s(wear_seq, text_seq, text_seq)[0]
        # Residual add + LayerNorm, then drop the singleton sequence axis.
        text_fused = self.ln1(text_seq + text_ctx).squeeze(1)
        wear_fused = self.ln2(wear_seq + wear_ctx).squeeze(1)
        return self.classifier(torch.cat([text_fused, wear_fused], dim=1))
# ---------- LOAD TEXT MODEL ----------
print("Loading TwHIN-BERT...")
tokenizer = AutoTokenizer.from_pretrained(TWHIN_MODEL_NAME)
# output_hidden_states=True so inference can read hidden_states[-1] for
# [CLS]-token pooling; eval() disables dropout for deterministic features.
twhin_model = AutoModel.from_pretrained(
    TWHIN_MODEL_NAME,
    output_hidden_states=True
).to(device).eval()
print("TwHIN-BERT loaded")
# ---------- LOAD WEARABLE FEATURES ----------
if not os.path.exists(WEARABLE_PATH):
    raise FileNotFoundError(f"Missing: {WEARABLE_PATH}")
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints (weights_only=True would be safer if the torch version allows).
wearable_features = torch.load(WEARABLE_PATH, map_location=device)
if wearable_features.ndim == 1:
    # Promote a single feature vector to a (1, w_dim) matrix.
    wearable_features = wearable_features.unsqueeze(0)
w_dim = wearable_features.shape[1]
# Fit the scaler on the stored wearable features so that user input at
# inference time is standardized into the same distribution.
scaler = StandardScaler()
scaler.fit(wearable_features.cpu().numpy())
print(f"Scaler ready w_dim={w_dim}")
# ---------- LOAD MULTIMODAL MODEL ----------
if not os.path.exists(SOTA_MODEL_PATH):
    raise FileNotFoundError(f"Missing: {SOTA_MODEL_PATH}")
s_dim = 768  # hidden size of the TwHIN-BERT base encoder output used above
multimodal_model = SOTA_MultimodalModel(s_dim, w_dim)
multimodal_model.load_state_dict(
    torch.load(SOTA_MODEL_PATH, map_location=device)
)
multimodal_model.to(device).eval()
print("Multimodal model loaded")
| # ========================================================= | |
| # PREDICT IMPLEMENTATION | |
| # ========================================================= | |
def _run_predict(
    text,
    daily_steps,
    sleep_hours,
    active_minutes,
    daily_calories,
    hr_avg_24h,
    hrv_score_avg,
    stress_level_avg,
    sleep_hrv_score
):
    """Run one end-to-end multimodal inference.

    Encodes the social-media text with TwHIN-BERT, rescales and standardizes
    the eight wearable readings, feeds both through the cross-attention model,
    and returns a pipe-delimited string
    ``"RESULT:<label>|HIGH_PROB:x|LOW_PROB:y"`` (or ``"ERROR:..."`` on bad
    input). The verbose prints are intentional server-side logging.

    Relies on module globals: tokenizer, twhin_model, scaler,
    multimodal_model, device.
    """
    print("\n================ NEW INFERENCE =================")
    # The text branch is mandatory — reject empty/whitespace-only input early.
    if not text or str(text).strip() == "":
        return "ERROR:Please enter social media text."
    # ---------- RAW INPUT ----------
    print("RAW USER INPUT")
    print({
        "text": text,
        "daily_steps": daily_steps,
        "sleep_hours": sleep_hours,
        "active_minutes": active_minutes,
        "daily_calories": daily_calories,
        "hr_avg_24h": hr_avg_24h,
        "hrv_score_avg": hrv_score_avg,
        "stress_level_avg": stress_level_avg,
        "sleep_hrv_score": sleep_hrv_score
    })
    # Gradio Number inputs may arrive as None or str — coerce all to float.
    try:
        daily_steps = float(daily_steps)
        sleep_hours = float(sleep_hours)
        active_minutes = float(active_minutes)
        daily_calories = float(daily_calories)
        hr_avg_24h = float(hr_avg_24h)
        hrv_score_avg = float(hrv_score_avg)
        stress_level_avg = float(stress_level_avg)
        sleep_hrv_score = float(sleep_hrv_score)
    except (ValueError, TypeError) as e:
        return f"ERROR:Invalid numeric input – {e}"
    # ---------- NORMALIZATION ----------
    # Map real-world magnitudes into the much smaller ranges seen in the
    # training data. The (real_max, dataset_max) pairs are hard-coded here;
    # the remaining three features are passed through unchanged.
    daily_steps = scale_value(daily_steps, 15000, 33)
    sleep_hours = scale_value(sleep_hours, 10, 5)
    hr_avg_24h = scale_value(hr_avg_24h, 70, 1.3)
    active_minutes = scale_value(active_minutes, 180, 1.8)
    daily_calories = scale_value(daily_calories, 3500, 18)
    print("AFTER DATASET NORMALIZATION")
    print({
        "daily_steps": daily_steps,
        "sleep_hours": sleep_hours,
        "hr_avg_24h": hr_avg_24h,
        "active_minutes": active_minutes,
        "daily_calories": daily_calories
    })
    # Feature order must match the column order the scaler was fitted on.
    nums = [
        daily_steps,
        sleep_hours,
        active_minutes,
        daily_calories,
        hr_avg_24h,
        hrv_score_avg,
        stress_level_avg,
        sleep_hrv_score
    ]
    print("FEATURE VECTOR BEFORE SCALER")
    print(nums)
    # ---------- TEXT FEATURES ----------
    with torch.no_grad():
        enc = tokenizer(
            clean_text(str(text)),
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        out = twhin_model(**enc)
        # [CLS] token of the last hidden layer as the sentence embedding.
        s_feat = out.hidden_states[-1][:, 0, :]
    print("TEXT FEATURE SHAPE:", s_feat.shape)
    # ---------- WEARABLE FEATURES ----------
    # Standardize with the scaler fitted on the stored wearable features.
    scaled = scaler.transform([nums])
    print("AFTER STANDARD SCALER")
    print(scaled)
    w_feat = torch.tensor(
        scaled,
        dtype=torch.float32
    ).to(device)
    print("WEARABLE FEATURE TENSOR:", w_feat)
    # ---------- MODEL PREDICTION ----------
    with torch.no_grad():
        logits = multimodal_model(s_feat, w_feat)
        print("MODEL LOGITS:", logits)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        print("SOFTMAX PROBABILITIES:", probs)
    # Assumes class index 0 = HIGH risk and 1 = LOW risk — TODO confirm
    # against the label encoding used at training time.
    high = probs[0]
    low = probs[1]
    # HACK: post-hoc probability adjustment. NOTE(review): each branch shifts
    # mass toward the *losing* class and can flip the argmax (e.g.
    # high=0.7/low=0.3 becomes 0.4/0.6 → LOW), and after clamping the pair
    # may no longer sum to 1 — verify this is intended de-biasing and not an
    # inverted condition.
    if low > high:
        high = min(high + 0.150, 1.0)
        low = max(low - 0.150, 0.0)
    else:
        low = min(low + 0.30, 1.0)
        high = max(high - 0.30, 0.0)
    probs = [high, low]
    print("ADJUSTED PROBABILITIES:", probs)
    label = "HIGH RISK" if probs[0] > probs[1] else "LOW RISK"
    print("FINAL LABEL:", label)
    print("================================================\n")
    return f"RESULT:{label}|HIGH_PROB:{probs[0]:.4f}|LOW_PROB:{probs[1]:.4f}"
| # ========================================================= | |
| # GRADIO WRAPPER | |
| # ========================================================= | |
class _PredictCallable:
    """Exception-safe callable handed to Gradio as the prediction function.

    Delegates to ``_run_predict`` and converts any uncaught exception into an
    ``"ERROR:<message>"`` string (after printing the traceback server-side),
    so the UI never surfaces a raw stack trace.
    """

    def __call__(
        self,
        text: str,
        daily_steps: float,
        sleep_hours: float,
        active_minutes: float,
        daily_calories: float,
        hr_avg_24h: float,
        hrv_score_avg: float,
        stress_level_avg: float,
        sleep_hrv_score: float,
    ) -> str:
        try:
            return _run_predict(
                text, daily_steps, sleep_hours, active_minutes,
                daily_calories, hr_avg_24h, hrv_score_avg,
                stress_level_avg, sleep_hrv_score,
            )
        except Exception as exc:  # boundary handler: log, then report as text
            import traceback
            traceback.print_exc()
            return f"ERROR:{exc}"


sota_predict = _PredictCallable()
| # ========================================================= | |
| # GRADIO INTERFACE | |
| # ========================================================= | |
# Gradio UI: one text box plus eight numeric wearable inputs, in the exact
# positional order of _run_predict's parameters; single text output carrying
# the "RESULT:...|HIGH_PROB:...|LOW_PROB:..." (or "ERROR:...") string.
demo = gr.Interface(
    fn=sota_predict,
    inputs=[
        gr.Textbox(
            label="Social Media Text",
            lines=4,
            placeholder="e.g. I feel tired and alone lately...",
        ),
        gr.Number(label="Daily Steps", value=0),
        gr.Number(label="Sleep Hours", value=0),
        gr.Number(label="Active Minutes", value=0),
        gr.Number(label="Daily Calories", value=0),
        gr.Number(label="Average Heart Rate (24h)", value=0),
        gr.Number(label="HRV Score", value=0),
        gr.Number(label="Stress Level", value=0),
        gr.Number(label="Sleep HRV Score", value=0),
    ],
    outputs=gr.Textbox(label="Prediction Result"),
    title="🧠 Multimodal Depression Detection",
    description=(
        "Combines social media text and wearable sensor data "
        "via a cross-attention multimodal deep learning model.\n\n"
        "⚠️ For research purposes only – not a medical diagnosis tool."
    ),
    # NOTE(review): confirm the pinned Gradio 6.x release accepts api_name
    # as an Interface constructor argument.
    api_name="predict",
)
demo.launch()