# MindSense / scripts/fusion_model.py
# Uploaded by Yashdesai07 ("Upload fusion_model.py", commit 3d55bc9, verified; 3.17 kB)
"""
The trainable part of this project: a fusion head that combines pretrained
audio / text / face embeddings into a mental-health risk prediction.
Two fusion strategies are included:
- ConcatFusion: simple, strong baseline (concatenate + MLP)
- AttentionFusion: each modality attends to the others before pooling
(use this as your "novel" contribution / for the ablation table)
"""
import torch
import torch.nn as nn
# Default width of each per-modality embedding the fusion heads consume.
EMBED_DIM: int = 256
# Default number of head outputs: 3 logits for low / moderate / high risk.
# Set this to 1 to use the heads for PHQ score regression instead.
NUM_CLASSES: int = 3
class ConcatFusion(nn.Module):
    """Early-fusion baseline: concatenate the three modality embeddings and
    classify the result with a two-hidden-layer MLP.

    Args:
        embed_dim: Width of each modality embedding; the MLP input width is
            ``3 * embed_dim``.
        num_classes: Number of output logits.
        hidden: Width of the first hidden layer; the second is ``hidden // 2``.
        dropout: Dropout probability applied after each hidden ReLU.
    """

    def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES, hidden=256, dropout=0.3):
        super().__init__()
        fused_width = 3 * embed_dim
        narrow = hidden // 2
        # Layer order (and therefore the net.0 / net.3 / net.6 state_dict keys)
        # is Linear -> ReLU -> Dropout, twice, then the output Linear.
        self.net = nn.Sequential(
            nn.Linear(fused_width, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, narrow), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(narrow, num_classes),
        )

    def forward(self, audio, text, face):
        """Concatenate the three embeddings on the last dim and return the MLP logits."""
        fused = torch.cat((audio, text, face), dim=-1)
        return self.net(fused)
class AttentionFusion(nn.Module):
    """
    Treats the 3 modality embeddings as a sequence of 3 tokens and runs them
    through a small multi-head self-attention block so each modality can be
    re-weighted based on the others (e.g. down-weight a noisy face signal if
    audio+text strongly agree) before pooling and classifying.

    Pipeline:
      1. Add a learned per-modality embedding to each token.
      2. Run one post-norm encoder layer: self-attention + residual + LayerNorm,
         then a ReLU feed-forward + residual + LayerNorm.
      3. Pool the three tokens with softmax weights.
      4. Classify with an MLP.

    Args:
        embed_dim: Width ``D`` of each modality embedding. Must be divisible by
            ``num_heads`` (an ``nn.MultiheadAttention`` requirement).
        num_classes: Number of output logits.
        num_heads: Number of attention heads.
        dropout: Dropout probability on the attention weights and inside the
            classifier MLP.
    """

    def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES, num_heads=4, dropout=0.3):
        super().__init__()
        # One learned vector per modality (audio, text, face), initialised with
        # std 0.02. Self-attention alone does not know which token is which, so
        # these vectors let it tell the modalities apart.
        self.modality_embed = nn.Parameter(torch.randn(3, embed_dim) * 0.02)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(),
            nn.Linear(embed_dim * 2, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        # Learned attention pooling: one scalar score per fused token, softmaxed
        # over the three tokens in forward().
        self.pool_weights = nn.Linear(embed_dim, 1)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes),
        )

    def forward(self, audio, text, face, return_weights=False):
        """Fuse one embedding per modality into class logits.

        Args:
            audio, text, face: Modality embeddings, assumed to be ``[B, D]`` with
                ``D == embed_dim``. They are stacked on dim 1 into a ``[B, 3, D]``
                token sequence.
            return_weights: If True, also return the per-modality pooling weights.

        Returns:
            ``logits`` of shape ``[B, num_classes]``. If ``return_weights`` is
            True, returns ``(logits, pool_scores)`` instead, where ``pool_scores``
            is ``[B, 3]`` (softmax over audio/text/face, summing to 1 per sample).
        """
        # [B, 3, D], each token tagged with its modality embedding.
        tokens = torch.stack([audio, text, face], dim=1) + self.modality_embed.unsqueeze(0)
        # Only the attended values are used. need_weights=False skips building
        # and head-averaging the attention map, and lets PyTorch take its
        # optimized scaled_dot_product_attention path.
        attn_out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        x = self.norm1(tokens + attn_out)
        x = self.norm2(x + self.ffn(x))
        # Learned weighted pooling across the 3 modality tokens.
        pool_scores = torch.softmax(self.pool_weights(x).squeeze(-1), dim=-1)  # [B, 3]
        pooled = torch.bmm(pool_scores.unsqueeze(1), x).squeeze(1)  # [B, D]
        logits = self.classifier(pooled)
        if return_weights:
            return logits, pool_scores  # pool_scores = per-modality contribution
        return logits
def build_model(kind: str = "attention", **kwargs):
    """Construct a fusion head by name.

    Args:
        kind: ``"attention"`` (default) for ``AttentionFusion`` or ``"concat"``
            for ``ConcatFusion``.
        **kwargs: Forwarded unchanged to the chosen class's constructor.

    Returns:
        A newly constructed fusion ``nn.Module``.

    Raises:
        ValueError: If ``kind`` is not a supported name. Unknown names are
            rejected rather than mapped to the attention model, so a typo
            cannot quietly train the wrong architecture in an ablation run.
    """
    if kind == "concat":
        return ConcatFusion(**kwargs)
    if kind == "attention":
        return AttentionFusion(**kwargs)
    raise ValueError(f"Unknown fusion kind {kind!r}; expected 'attention' or 'concat'.")