import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossModalAttention(nn.Module):
"""
Cross-modal attention mechanism for fusing vision, audio, and text features.
"""
def __init__(self, embed_dim=256, num_heads=8):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_proj = nn.Linear(embed_dim, embed_dim)
self.key_proj = nn.Linear(embed_dim, embed_dim)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
self.norm = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(0.1)
def forward(self, query, key_value):
"""
query: (B, seq_len_q, embed_dim)
key_value: (B, seq_len_kv, embed_dim)
"""
# Project to attention space
q = self.query_proj(query)
k = self.key_proj(key_value)
v = self.value_proj(key_value)
# Multi-head attention
attn_output, attn_weights = self.multihead_attn(q, k, v)
# Residual connection and normalization
output = self.norm(query + self.dropout(attn_output))
return output, attn_weights
class TemporalTransformer(nn.Module):
"""
Temporal transformer for modeling sequences across time windows.
"""
def __init__(self, embed_dim=256, num_layers=4, num_heads=8):
super().__init__()
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * 4,
dropout=0.1,
batch_first=True
) for _ in range(num_layers)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
"""
x: (B, seq_len, embed_dim) - sequence of fused features over time
"""
for layer in self.layers:
x = layer(x)
return self.norm(x)
class MultiModalFusion(nn.Module):
    """
    Complete fusion network combining vision, audio, text with temporal modeling.

    Projects each modality into a shared embedding space, fuses them with
    pairwise cross-modal attention, optionally models temporal sequences with
    a transformer, and emits emotion/intent logits plus engagement,
    confidence, and per-modality contribution scores.
    """
    def __init__(self, vision_dim=768, audio_dim=128, text_dim=768, embed_dim=256,
                 num_emotions=7, num_intents=5):
        super().__init__()
        self.embed_dim = embed_dim
        # Modality projectors into the shared embedding space
        self.vision_proj = nn.Linear(vision_dim, embed_dim)
        self.audio_proj = nn.Linear(audio_dim, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)
        # Cross-modal attention layers (cyclic: v->a, a->t, t->v)
        self.vision_to_audio_attn = CrossModalAttention(embed_dim)
        self.audio_to_text_attn = CrossModalAttention(embed_dim)
        self.text_to_vision_attn = CrossModalAttention(embed_dim)
        # Temporal modeling
        self.temporal_transformer = TemporalTransformer(embed_dim)
        # Dynamic modality weighting (softmax-normalized at use sites)
        self.modality_weights = nn.Parameter(torch.ones(3))  # vision, audio, text
        # Output heads
        self.emotion_classifier = nn.Linear(embed_dim, num_emotions)
        self.intent_classifier = nn.Linear(embed_dim, num_intents)
        self.engagement_regressor = nn.Linear(embed_dim, 1)
        self.confidence_regressor = nn.Linear(embed_dim, 1)
        # Modality contribution estimator: concatenated per-modality
        # embeddings (embed_dim * 3) -> one weight per modality
        self.contribution_estimator = nn.Linear(embed_dim * 3, 3)
    def forward(self, vision_features, audio_features, text_features, temporal_seq=False):
        """
        Args:
            vision_features: (B, vision_dim) or (B, T, vision_dim)
            audio_features: (B, audio_dim) or (B, T, audio_dim)
            text_features: (B, text_dim) or (B, T, text_dim)
            temporal_seq: whether inputs are temporal sequences (B, T, *)
        Returns:
            dict with 'emotion' (B, num_emotions), 'intent' (B, num_intents),
            'engagement' (B,), 'confidence' (B,), 'contributions' (B, 3).
        """
        # Project to common embedding space
        v_proj = self.vision_proj(vision_features)  # (B, embed_dim) or (B, T, embed_dim)
        a_proj = self.audio_proj(audio_features)
        t_proj = self.text_proj(text_features)
        if temporal_seq:
            # Handle temporal sequences
            B, T, _ = v_proj.shape
            # Reshape for attention: treat each timestep independently, (B*T, 1, embed_dim)
            v_flat = v_proj.view(B * T, 1, -1)
            a_flat = a_proj.view(B * T, 1, -1)
            t_flat = t_proj.view(B * T, 1, -1)
            # Cross-modal attention
            v_attn, _ = self.vision_to_audio_attn(v_flat, a_flat)
            a_attn, _ = self.audio_to_text_attn(a_flat, t_flat)
            t_attn, _ = self.text_to_vision_attn(t_flat, v_flat)
            # Combine attended features with a uniform average.
            # NOTE(review): this path ignores self.modality_weights (used only
            # in the single-step path) — confirm that asymmetry is intended.
            fused = (v_attn + a_attn + t_attn) / 3  # (B*T, 1, embed_dim)
            # Reshape back to temporal: (B, T, embed_dim)
            fused = fused.view(B, T, -1)
            # Temporal transformer
            temporal_out = self.temporal_transformer(fused)  # (B, T, embed_dim)
            # Pool temporal dimension: take the last timestep
            pooled = temporal_out[:, -1, :]  # (B, embed_dim)
        else:
            # Single timestep fusion: add a length-1 sequence axis for attention
            v_attn, _ = self.vision_to_audio_attn(v_proj.unsqueeze(1), a_proj.unsqueeze(1))
            a_attn, _ = self.audio_to_text_attn(a_proj.unsqueeze(1), t_proj.unsqueeze(1))
            t_attn, _ = self.text_to_vision_attn(t_proj.unsqueeze(1), v_proj.unsqueeze(1))
            # Weighted fusion with learned, softmax-normalized modality weights
            weights = F.softmax(self.modality_weights, dim=0)
            fused = (weights[0] * v_attn.squeeze(1)
                     + weights[1] * a_attn.squeeze(1)
                     + weights[2] * t_attn.squeeze(1))
            pooled = fused
        # Output predictions
        emotion_logits = self.emotion_classifier(pooled)
        intent_logits = self.intent_classifier(pooled)
        engagement = torch.sigmoid(self.engagement_regressor(pooled))
        confidence = torch.sigmoid(self.confidence_regressor(pooled))
        # Modality contributions.
        # BUG FIX: the original pooled with mean(dim=-1) (temporal) or
        # mean(dim=0) (single-step), collapsing the embedding or the batch
        # dimension — the temporal path then fed (B, T, 3) into a
        # Linear(embed_dim*3, 3) and crashed, and the single-step path mixed
        # samples across the batch. Pool only the time axis so every sample
        # keeps its full per-modality embedding.
        if temporal_seq:
            v_summary = v_proj.mean(dim=1)  # (B, embed_dim)
            a_summary = a_proj.mean(dim=1)
            t_summary = t_proj.mean(dim=1)
        else:
            v_summary, a_summary, t_summary = v_proj, a_proj, t_proj
        contributions = torch.softmax(
            self.contribution_estimator(
                torch.cat([v_summary, a_summary, t_summary], dim=-1)  # (B, embed_dim*3)
            ),
            dim=-1,
        )  # (B, 3)
        return {
            'emotion': emotion_logits,
            'intent': intent_logits,
            # squeeze only the trailing feature axis (not a bare .squeeze(),
            # which would also drop the batch axis when B == 1)
            'engagement': engagement.squeeze(-1),
            'confidence': confidence.squeeze(-1),
            'contributions': contributions,
        }
    def get_modality_weights(self):
        """
        Return softmax-normalized modality weights (vision, audio, text)
        for explainability.
        """
        return F.softmax(self.modality_weights, dim=0)