Spaces:
Sleeping
Sleeping
"""
The trainable part of this project: a fusion head that combines pretrained
audio / text / face embeddings into a mental-health risk prediction.

Two fusion strategies are included:
  - ConcatFusion: a simple, strong baseline (concatenate + MLP).
  - AttentionFusion: each modality attends to the others before pooling
    (use this as the "novel" contribution / for the ablation table).
"""
| import torch | |
| import torch.nn as nn | |
# Size of each per-modality embedding (audio, text, face); used as the
# default ``embed_dim`` of both fusion heads below.
EMBED_DIM = 256

# Number of output logits: 3 risk classes (low / moderate / high).
# Set this to 1 to predict a PHQ score (regression) instead.
NUM_CLASSES = 3
class ConcatFusion(nn.Module):
    """Baseline fusion: concatenate the three modality embeddings, then classify with an MLP.

    Args:
        embed_dim: size of each modality embedding; the MLP input is ``3 * embed_dim``.
        num_classes: number of output logits.
        hidden: width of the first hidden layer; the second hidden layer is ``hidden // 2``.
        dropout: dropout probability applied after each hidden-layer ReLU.
    """

    def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES, hidden=256, dropout=0.3):
        super().__init__()
        half = hidden // 2
        layers = [
            nn.Linear(3 * embed_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, half),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half, num_classes),
        ]
        # Keep the attribute name ``net`` and this exact layer order: the
        # state_dict keys (``net.0.weight``, ...) depend on both.
        self.net = nn.Sequential(*layers)

    def forward(self, audio, text, face):
        """Return logits for the embeddings joined along the last dimension.

        Leading dimensions are preserved; the last dimension becomes ``num_classes``.
        """
        fused = torch.cat((audio, text, face), dim=-1)
        return self.net(fused)
class AttentionFusion(nn.Module):
    """
    Treats the 3 modality embeddings as a sequence of 3 tokens and runs them
    through a small multi-head self-attention block so each modality can be
    re-weighted based on the others (e.g. down-weight a noisy face signal if
    audio+text strongly agree) before pooling and classifying.

    Args:
        embed_dim: size of each modality embedding (nn.MultiheadAttention
            requires it to be divisible by ``num_heads``).
        num_classes: number of output logits.
        num_heads: number of attention heads.
        dropout: dropout on the attention weights and inside the classifier head.
    """

    def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES, num_heads=4, dropout=0.3):
        super().__init__()
        # One learned vector per modality (audio, text, face), added to the tokens
        # so attention can tell which modality each token came from.
        self.modality_embed = nn.Parameter(torch.randn(3, embed_dim) * 0.02)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(),
            nn.Linear(embed_dim * 2, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.pool_weights = nn.Linear(embed_dim, 1)  # learned attention pooling
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, audio, text, face, return_weights=False):
        """
        Args:
            audio, text, face: per-modality embeddings of shape [B, embed_dim]
                (they are stacked into [B, 3, D], which torch.bmm below requires).
            return_weights: if True, also return the per-modality pooling weights.

        Returns:
            logits of shape [B, num_classes], or ``(logits, pool_scores)`` when
            ``return_weights`` is True, where ``pool_scores`` is [B, 3] in
            audio/text/face order and sums to 1 over the modalities.
        """
        # [B, 3, D]
        tokens = torch.stack([audio, text, face], dim=1) + self.modality_embed.unsqueeze(0)
        # The attention map is never used, so don't compute it: need_weights=False
        # lets PyTorch take its optimized scaled_dot_product_attention path.
        attn_out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        x = self.norm1(tokens + attn_out)
        x = self.norm2(x + self.ffn(x))
        # learned weighted pooling across the 3 modality tokens
        pool_scores = torch.softmax(self.pool_weights(x).squeeze(-1), dim=-1)  # [B, 3]
        pooled = torch.bmm(pool_scores.unsqueeze(1), x).squeeze(1)  # [B, D]
        logits = self.classifier(pooled)
        if return_weights:
            return logits, pool_scores  # pool_scores = per-modality contribution
        return logits
def build_model(kind: str = "attention", **kwargs):
    """Construct a fusion head by name.

    Args:
        kind: ``"concat"`` for ConcatFusion or ``"attention"`` for AttentionFusion.
        **kwargs: forwarded unchanged to the chosen class's constructor.

    Returns:
        A freshly initialised fusion module.

    Raises:
        ValueError: if ``kind`` is not one of the supported names. Previously any
            unrecognised string silently fell back to AttentionFusion, so a typo
            could mislabel ablation results.
    """
    if kind == "concat":
        return ConcatFusion(**kwargs)
    if kind == "attention":
        return AttentionFusion(**kwargs)
    raise ValueError(f"Unknown fusion kind {kind!r}; expected 'concat' or 'attention'.")