# EMOTIA/models/fusion.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossModalAttention(nn.Module):
"""
Cross-modal attention mechanism for fusing vision, audio, and text features.
"""
def __init__(self, embed_dim=256, num_heads=8):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_proj = nn.Linear(embed_dim, embed_dim)
self.key_proj = nn.Linear(embed_dim, embed_dim)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.norm = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(0.1)
def forward(self, query, key_value):
"""
query: (B, seq_len_q, embed_dim)
key_value: (B, seq_len_kv, embed_dim)
"""
# Project to attention space
q = self.query_proj(query)
k = self.key_proj(key_value)
v = self.value_proj(key_value)
# Multi-head attention
attn_output, attn_weights = self.multihead_attn(q, k, v)
# Residual connection and normalization
output = self.norm(query + self.dropout(attn_output))
return output, attn_weights
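# Illustrative usage sketch (shapes below are assumptions, not part of the EMOTIA pipeline):
#   attn = CrossModalAttention(embed_dim=256, num_heads=8)
#   vision_q = torch.randn(4, 1, 256)   # (B, seq_len_q, embed_dim)
#   audio_kv = torch.randn(4, 1, 256)   # (B, seq_len_kv, embed_dim)
#   fused, weights = attn(vision_q, audio_kv)  # fused: (4, 1, 256)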
class TemporalTransformer(nn.Module):
"""
Temporal transformer for modeling sequences across time windows.
"""
def __init__(self, embed_dim=256, num_layers=4, num_heads=8):
super().__init__()
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * 4,
dropout=0.1,
batch_first=True
) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
"""
x: (B, seq_len, embed_dim) - sequence of fused features over time
"""
for layer in self.layers:
x = layer(x)
return self.norm(x)
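# Illustrative usage sketch (batch and window sizes are assumptions, not from the repo):
#   temporal = TemporalTransformer(embed_dim=256, num_layers=4, num_heads=8)
#   window = torch.randn(4, 12, 256)    # (B, T, embed_dim): 12 fused timesteps
#   out = temporal(window)              # (4, 12, 256), contextualized across time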
class MultiModalFusion(nn.Module):
"""
Complete fusion network combining vision, audio, text with temporal modeling.
"""
def __init__(self, vision_dim=768, audio_dim=128, text_dim=768, embed_dim=256,
num_emotions=7, num_intents=5):
super().__init__()
self.embed_dim = embed_dim
# Modality projectors
self.vision_proj = nn.Linear(vision_dim, embed_dim)
self.audio_proj = nn.Linear(audio_dim, embed_dim)
self.text_proj = nn.Linear(text_dim, embed_dim)
# Cross-modal attention layers
self.vision_to_audio_attn = CrossModalAttention(embed_dim)
self.audio_to_text_attn = CrossModalAttention(embed_dim)
self.text_to_vision_attn = CrossModalAttention(embed_dim)
# Temporal modeling
self.temporal_transformer = TemporalTransformer(embed_dim)
# Dynamic modality weighting
self.modality_weights = nn.Parameter(torch.ones(3)) # vision, audio, text
# Output heads
self.emotion_classifier = nn.Linear(embed_dim, num_emotions)
self.intent_classifier = nn.Linear(embed_dim, num_intents)
self.engagement_regressor = nn.Linear(embed_dim, 1)
self.confidence_regressor = nn.Linear(embed_dim, 1)
# Modality contribution estimator
self.contribution_estimator = nn.Linear(embed_dim * 3, 3) # weights for each modality
def forward(self, vision_features, audio_features, text_features, temporal_seq=False):
"""
vision_features: (B, vision_dim) or (B, T, vision_dim)
audio_features: (B, audio_dim) or (B, T, audio_dim)
text_features: (B, text_dim) or (B, T, text_dim)
temporal_seq: whether inputs are temporal sequences
"""
# Project to common embedding space
v_proj = self.vision_proj(vision_features) # (B, embed_dim) or (B, T, embed_dim)
a_proj = self.audio_proj(audio_features)
t_proj = self.text_proj(text_features)
if temporal_seq:
# Handle temporal sequences
B, T, _ = v_proj.shape
# Reshape for attention: (B*T, 1, embed_dim)
v_flat = v_proj.view(B*T, 1, -1)
a_flat = a_proj.view(B*T, 1, -1)
t_flat = t_proj.view(B*T, 1, -1)
# Cross-modal attention
v_attn, _ = self.vision_to_audio_attn(v_flat, a_flat)
a_attn, _ = self.audio_to_text_attn(a_flat, t_flat)
t_attn, _ = self.text_to_vision_attn(t_flat, v_flat)
# Combine attended features
fused = (v_attn + a_attn + t_attn) / 3 # (B*T, 1, embed_dim)
# Reshape back to temporal: (B, T, embed_dim)
fused = fused.view(B, T, -1)
# Temporal transformer
temporal_out = self.temporal_transformer(fused) # (B, T, embed_dim)
            # Pool the temporal dimension by taking the last timestep
pooled = temporal_out[:, -1, :] # (B, embed_dim)
else:
# Single timestep fusion
# Cross-modal attention
v_attn, _ = self.vision_to_audio_attn(v_proj.unsqueeze(1), a_proj.unsqueeze(1))
a_attn, _ = self.audio_to_text_attn(a_proj.unsqueeze(1), t_proj.unsqueeze(1))
t_attn, _ = self.text_to_vision_attn(t_proj.unsqueeze(1), v_proj.unsqueeze(1))
# Weighted fusion
weights = F.softmax(self.modality_weights, dim=0)
fused = weights[0] * v_attn.squeeze(1) + \
weights[1] * a_attn.squeeze(1) + \
weights[2] * t_attn.squeeze(1)
pooled = fused
# Output predictions
emotion_logits = self.emotion_classifier(pooled)
intent_logits = self.intent_classifier(pooled)
engagement = torch.sigmoid(self.engagement_regressor(pooled))
confidence = torch.sigmoid(self.confidence_regressor(pooled))
        # Modality contributions: summarize each projected modality per sample,
        # then concatenate to (B, embed_dim * 3) for the contribution estimator
        if temporal_seq:
            v_summary = v_proj.mean(dim=1)  # (B, embed_dim), averaged over time
            a_summary = a_proj.mean(dim=1)
            t_summary = t_proj.mean(dim=1)
        else:
            v_summary, a_summary, t_summary = v_proj, a_proj, t_proj
        contributions = torch.softmax(
            self.contribution_estimator(
                torch.cat([v_summary, a_summary, t_summary], dim=-1)),
            dim=-1)  # (B, 3)
return {
'emotion': emotion_logits,
'intent': intent_logits,
            'engagement': engagement.squeeze(-1),
            'confidence': confidence.squeeze(-1),
            'contributions': contributions
}
def get_modality_weights(self):
"""
Return normalized modality weights for explainability.
"""
return F.softmax(self.modality_weights, dim=0)
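

# Minimal smoke test, assuming the default feature dimensions above (vision 768,
# audio 128, text 768). Batch size, window length, and label counts below are
# illustrative only and are not taken from the EMOTIA training configuration.
if __name__ == "__main__":
    model = MultiModalFusion(vision_dim=768, audio_dim=128, text_dim=768,
                             embed_dim=256, num_emotions=7, num_intents=5)
    B, T = 4, 12

    # Single-timestep fusion: one feature vector per modality per sample.
    out = model(torch.randn(B, 768), torch.randn(B, 128), torch.randn(B, 768))
    print({k: tuple(v.shape) for k, v in out.items()})

    # Temporal fusion: a window of T timesteps per modality.
    out_t = model(torch.randn(B, T, 768), torch.randn(B, T, 128),
                  torch.randn(B, T, 768), temporal_seq=True)
    print({k: tuple(v.shape) for k, v in out_t.items()})

    # Learned modality weights (vision, audio, text), normalized for explainability.
    print(model.get_modality_weights())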