""" The trainable part of this project: a fusion head that combines pretrained audio / text / face embeddings into a mental-health risk prediction. Two fusion strategies are included: - ConcatFusion: simple, strong baseline (concatenate + MLP) - AttentionFusion: each modality attends to the others before pooling (use this as your "novel" contribution / for the ablation table) """ import torch import torch.nn as nn EMBED_DIM = 256 NUM_CLASSES = 3 # low / moderate / high risk -- change to 1 for PHQ regression class ConcatFusion(nn.Module): def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES, hidden=256, dropout=0.3): super().__init__() self.net = nn.Sequential( nn.Linear(embed_dim * 3, hidden), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden // 2, num_classes), ) def forward(self, audio, text, face): x = torch.cat([audio, text, face], dim=-1) return self.net(x) class AttentionFusion(nn.Module): """ Treats the 3 modality embeddings as a sequence of 3 tokens and runs them through a small multi-head self-attention block so each modality can be re-weighted based on the others (e.g. down-weight a noisy face signal if audio+text strongly agree) before pooling and classifying. """ def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES, num_heads=4, dropout=0.3): super().__init__() self.modality_embed = nn.Parameter(torch.randn(3, embed_dim) * 0.02) self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout) self.norm1 = nn.LayerNorm(embed_dim) self.ffn = nn.Sequential( nn.Linear(embed_dim, embed_dim * 2), nn.ReLU(), nn.Linear(embed_dim * 2, embed_dim), ) self.norm2 = nn.LayerNorm(embed_dim) self.pool_weights = nn.Linear(embed_dim, 1) # learned attention pooling self.classifier = nn.Sequential( nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(embed_dim // 2, num_classes), ) def forward(self, audio, text, face, return_weights=False): # [B, 3, D] tokens = torch.stack([audio, text, face], dim=1) + self.modality_embed.unsqueeze(0) attn_out, attn_weights = self.attn(tokens, tokens, tokens) x = self.norm1(tokens + attn_out) x = self.norm2(x + self.ffn(x)) # learned weighted pooling across the 3 modality tokens pool_scores = torch.softmax(self.pool_weights(x).squeeze(-1), dim=-1) # [B, 3] pooled = torch.bmm(pool_scores.unsqueeze(1), x).squeeze(1) # [B, D] logits = self.classifier(pooled) if return_weights: return logits, pool_scores # pool_scores = per-modality contribution return logits def build_model(kind: str = "attention", **kwargs): if kind == "concat": return ConcatFusion(**kwargs) return AttentionFusion(**kwargs)