| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
class CrossModalAttention(nn.Module):
    """Cross-modal attention for fusing vision, audio, and text features.

    Projects a query modality and a key/value modality into a shared space,
    attends across them with multi-head attention, and applies a residual
    connection (around the *unprojected* query) followed by layer norm.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Per-modality linear maps applied before the shared attention module.
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, query, key_value):
        """Attend from `query` over `key_value`.

        Args:
            query: (B, seq_len_q, embed_dim) tensor.
            key_value: (B, seq_len_kv, embed_dim) tensor.

        Returns:
            Tuple of the fused output (B, seq_len_q, embed_dim) and the
            attention weights produced by ``nn.MultiheadAttention``.
        """
        attended, weights = self.multihead_attn(
            self.query_proj(query),
            self.key_proj(key_value),
            self.value_proj(key_value),
        )
        # Residual connection uses the original query, then LayerNorm.
        fused = self.norm(query + self.dropout(attended))
        return fused, weights
| |
|
class TemporalTransformer(nn.Module):
    """Stack of Transformer encoder layers for modeling sequences across time windows."""

    def __init__(self, embed_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_heads,
                    dim_feedforward=4 * embed_dim,
                    dropout=0.1,
                    batch_first=True,
                )
            )
        self.layers = nn.ModuleList(encoder_layers)

        # Final normalization applied after the full encoder stack.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode a fused feature sequence.

        Args:
            x: (B, seq_len, embed_dim) sequence of fused features over time.

        Returns:
            (B, seq_len, embed_dim) tensor, layer-normalized after all layers.
        """
        out = x
        for encoder_layer in self.layers:
            out = encoder_layer(out)
        return self.norm(out)
| |
|
class MultiModalFusion(nn.Module):
    """Complete fusion network combining vision, audio, text with temporal modeling.

    Pipeline: per-modality linear projection -> cyclic cross-modal attention
    (vision->audio, audio->text, text->vision) -> optional temporal transformer
    over the fused sequence -> task heads (emotion, intent, engagement,
    confidence) plus a per-modality contribution estimate for explainability.
    """

    def __init__(self, vision_dim=768, audio_dim=128, text_dim=768, embed_dim=256,
                 num_emotions=7, num_intents=5):
        super().__init__()
        self.embed_dim = embed_dim

        # Project each modality into the shared embedding space.
        self.vision_proj = nn.Linear(vision_dim, embed_dim)
        self.audio_proj = nn.Linear(audio_dim, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)

        # Cyclic cross-modal attention: each modality queries the next.
        self.vision_to_audio_attn = CrossModalAttention(embed_dim)
        self.audio_to_text_attn = CrossModalAttention(embed_dim)
        self.text_to_vision_attn = CrossModalAttention(embed_dim)

        self.temporal_transformer = TemporalTransformer(embed_dim)

        # Learnable per-modality mixing weights (softmax-normalized at use time).
        self.modality_weights = nn.Parameter(torch.ones(3))

        # Task heads operating on the pooled fused representation.
        self.emotion_classifier = nn.Linear(embed_dim, num_emotions)
        self.intent_classifier = nn.Linear(embed_dim, num_intents)
        self.engagement_regressor = nn.Linear(embed_dim, 1)
        self.confidence_regressor = nn.Linear(embed_dim, 1)

        # Maps concatenated per-modality summaries (embed_dim * 3) to a
        # 3-way contribution distribution.
        self.contribution_estimator = nn.Linear(embed_dim * 3, 3)

    def forward(self, vision_features, audio_features, text_features, temporal_seq=False):
        """Fuse the three modalities and run all task heads.

        Args:
            vision_features: (B, vision_dim) or (B, T, vision_dim).
            audio_features: (B, audio_dim) or (B, T, audio_dim).
            text_features: (B, text_dim) or (B, T, text_dim).
            temporal_seq: whether inputs are temporal sequences (3-D tensors).

        Returns:
            dict with 'emotion' logits, 'intent' logits, 'engagement' and
            'confidence' sigmoids, and softmaxed per-modality 'contributions'.
        """
        v_proj = self.vision_proj(vision_features)
        a_proj = self.audio_proj(audio_features)
        t_proj = self.text_proj(text_features)

        if temporal_seq:
            B, T, _ = v_proj.shape

            # Treat every timestep independently for cross-modal attention:
            # flatten (B, T, D) -> (B*T, 1, D) single-token sequences.
            v_flat = v_proj.reshape(B * T, 1, -1)
            a_flat = a_proj.reshape(B * T, 1, -1)
            t_flat = t_proj.reshape(B * T, 1, -1)

            v_attn, _ = self.vision_to_audio_attn(v_flat, a_flat)
            a_attn, _ = self.audio_to_text_attn(a_flat, t_flat)
            t_attn, _ = self.text_to_vision_attn(t_flat, v_flat)

            # NOTE(review): the temporal path averages modalities uniformly,
            # while the static path below uses self.modality_weights —
            # presumably intentional; confirm.
            fused = (v_attn + a_attn + t_attn) / 3

            # Restore the (B, T, D) layout and model temporal dependencies.
            fused = fused.reshape(B, T, -1)
            temporal_out = self.temporal_transformer(fused)

            # Pool by taking the final timestep's representation.
            pooled = temporal_out[:, -1, :]
        else:
            # Static inputs: add a singleton sequence dimension for attention.
            v_attn, _ = self.vision_to_audio_attn(v_proj.unsqueeze(1), a_proj.unsqueeze(1))
            a_attn, _ = self.audio_to_text_attn(a_proj.unsqueeze(1), t_proj.unsqueeze(1))
            t_attn, _ = self.text_to_vision_attn(t_proj.unsqueeze(1), v_proj.unsqueeze(1))

            # Learned convex combination of the attended modalities.
            weights = F.softmax(self.modality_weights, dim=0)
            fused = (weights[0] * v_attn.squeeze(1)
                     + weights[1] * a_attn.squeeze(1)
                     + weights[2] * t_attn.squeeze(1))

            pooled = fused

        emotion_logits = self.emotion_classifier(pooled)
        intent_logits = self.intent_classifier(pooled)
        engagement = torch.sigmoid(self.engagement_regressor(pooled))
        confidence = torch.sigmoid(self.confidence_regressor(pooled))

        # Per-modality contribution estimate.
        # BUGFIX: the previous code used mean(dim=-1) in the temporal case,
        # which reduced the *embedding* dimension and produced (B, T, 1) per
        # modality -> a (B, T, 3) concat incompatible with the
        # Linear(embed_dim * 3, 3) estimator (runtime shape error); in the
        # static case it used mean(dim=0), averaging over the *batch* and
        # collapsing per-sample outputs. Correct behavior: pool over time in
        # the temporal case, use the (B, embed_dim) projections directly in
        # the static case.
        if temporal_seq:
            v_summary = v_proj.mean(dim=1)
            a_summary = a_proj.mean(dim=1)
            t_summary = t_proj.mean(dim=1)
        else:
            v_summary, a_summary, t_summary = v_proj, a_proj, t_proj
        contributions = torch.softmax(
            self.contribution_estimator(
                torch.cat([v_summary, a_summary, t_summary], dim=-1)
            ),
            dim=-1,
        )

        return {
            'emotion': emotion_logits,
            'intent': intent_logits,
            # NOTE(review): .squeeze() collapses the batch dimension when
            # B == 1 (scalar / (3,) outputs); kept for backward compatibility,
            # but .squeeze(-1) would be safer for downstream batch code.
            'engagement': engagement.squeeze(),
            'confidence': confidence.squeeze(),
            'contributions': contributions.squeeze(),
        }

    def get_modality_weights(self):
        """Return softmax-normalized modality weights for explainability."""
        return F.softmax(self.modality_weights, dim=0)