File size: 6,734 Bytes
25d0747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossModalAttention(nn.Module):
    """
    Let one modality's features attend over another's.

    The ``query`` modality gathers information from the ``key_value``
    modality through multi-head attention; the result is merged back
    into the query stream via a residual connection followed by
    layer normalization.
    """
    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Learned projections into a shared attention space.
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, query, key_value):
        """
        Args:
            query: ``(B, seq_len_q, embed_dim)`` tensor.
            key_value: ``(B, seq_len_kv, embed_dim)`` tensor.

        Returns:
            Tuple of the attended ``(B, seq_len_q, embed_dim)`` output
            and the attention weight map from ``nn.MultiheadAttention``.
        """
        attended, weights = self.multihead_attn(
            self.query_proj(query),
            self.key_proj(key_value),
            self.value_proj(key_value),
        )

        # Residual add + post-norm; dropout applies only to the attention branch.
        fused = self.norm(query + self.dropout(attended))
        return fused, weights

class TemporalTransformer(nn.Module):
    """
    Stack of transformer encoder layers that models a sequence of
    fused feature vectors across time windows.
    """
    def __init__(self, embed_dim=256, num_layers=4, num_heads=8):
        super().__init__()
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_heads,
                    dim_feedforward=4 * embed_dim,
                    dropout=0.1,
                    batch_first=True,
                )
            )
        self.layers = nn.ModuleList(encoder_layers)

        # Final normalization applied after the full stack.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x: ``(B, seq_len, embed_dim)`` sequence of fused features.

        Returns:
            ``(B, seq_len, embed_dim)`` encoded sequence.
        """
        out = x
        for encoder_layer in self.layers:
            out = encoder_layer(out)
        return self.norm(out)

class MultiModalFusion(nn.Module):
    """
    Complete fusion network combining vision, audio, text with temporal modeling.

    Projects each modality into a shared embedding space, fuses them via
    pairwise cross-modal attention, optionally models the fused sequence over
    time with a transformer, and emits emotion/intent logits plus engagement,
    confidence, and per-modality contribution estimates.
    """
    def __init__(self, vision_dim=768, audio_dim=128, text_dim=768, embed_dim=256,
                 num_emotions=7, num_intents=5):
        super().__init__()
        self.embed_dim = embed_dim

        # Modality projectors into the shared embedding space.
        self.vision_proj = nn.Linear(vision_dim, embed_dim)
        self.audio_proj = nn.Linear(audio_dim, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)

        # Cyclic pairwise cross-modal attention (V<-A, A<-T, T<-V).
        self.vision_to_audio_attn = CrossModalAttention(embed_dim)
        self.audio_to_text_attn = CrossModalAttention(embed_dim)
        self.text_to_vision_attn = CrossModalAttention(embed_dim)

        # Temporal modeling over fused per-timestep features.
        self.temporal_transformer = TemporalTransformer(embed_dim)

        # Dynamic modality weighting (softmax-normalized at use time).
        self.modality_weights = nn.Parameter(torch.ones(3))  # vision, audio, text

        # Output heads.
        self.emotion_classifier = nn.Linear(embed_dim, num_emotions)
        self.intent_classifier = nn.Linear(embed_dim, num_intents)
        self.engagement_regressor = nn.Linear(embed_dim, 1)
        self.confidence_regressor = nn.Linear(embed_dim, 1)

        # Modality contribution estimator: consumes the three per-sample
        # modality embeddings concatenated along the feature axis.
        self.contribution_estimator = nn.Linear(embed_dim * 3, 3)  # weights for each modality

    def forward(self, vision_features, audio_features, text_features, temporal_seq=False):
        """
        Args:
            vision_features: (B, vision_dim) or (B, T, vision_dim)
            audio_features: (B, audio_dim) or (B, T, audio_dim)
            text_features: (B, text_dim) or (B, T, text_dim)
            temporal_seq: whether inputs are temporal sequences

        Returns:
            dict with 'emotion' (B, num_emotions) logits, 'intent'
            (B, num_intents) logits, 'engagement' and 'confidence' (B,)
            sigmoid scores, and 'contributions' (B, 3) softmax weights.
        """
        # Project to common embedding space.
        v_proj = self.vision_proj(vision_features)  # (B, embed_dim) or (B, T, embed_dim)
        a_proj = self.audio_proj(audio_features)
        t_proj = self.text_proj(text_features)

        if temporal_seq:
            B, T, _ = v_proj.shape

            # Treat each timestep as an independent length-1 sequence for
            # cross-modal attention: (B*T, 1, embed_dim).
            v_flat = v_proj.reshape(B * T, 1, -1)
            a_flat = a_proj.reshape(B * T, 1, -1)
            t_flat = t_proj.reshape(B * T, 1, -1)

            # Cross-modal attention.
            v_attn, _ = self.vision_to_audio_attn(v_flat, a_flat)
            a_attn, _ = self.audio_to_text_attn(a_flat, t_flat)
            t_attn, _ = self.text_to_vision_attn(t_flat, v_flat)

            # Combine attended features (unweighted average in the temporal
            # path, matching the original design).
            fused = (v_attn + a_attn + t_attn) / 3  # (B*T, 1, embed_dim)

            # Restore the temporal axis and run the transformer.
            temporal_out = self.temporal_transformer(fused.view(B, T, -1))  # (B, T, embed_dim)

            # Pool the temporal dimension: take the last timestep.
            pooled = temporal_out[:, -1, :]  # (B, embed_dim)

            # Per-sample modality summaries: average over time -> (B, embed_dim).
            v_sum = v_proj.mean(dim=1)
            a_sum = a_proj.mean(dim=1)
            t_sum = t_proj.mean(dim=1)

        else:
            # Single-timestep fusion via cross-modal attention.
            v_attn, _ = self.vision_to_audio_attn(v_proj.unsqueeze(1), a_proj.unsqueeze(1))
            a_attn, _ = self.audio_to_text_attn(a_proj.unsqueeze(1), t_proj.unsqueeze(1))
            t_attn, _ = self.text_to_vision_attn(t_proj.unsqueeze(1), v_proj.unsqueeze(1))

            # Weighted fusion with learned, softmax-normalized modality weights.
            weights = F.softmax(self.modality_weights, dim=0)
            fused = weights[0] * v_attn.squeeze(1) + \
                   weights[1] * a_attn.squeeze(1) + \
                   weights[2] * t_attn.squeeze(1)

            pooled = fused

            # Per-sample modality summaries are just the projections themselves.
            v_sum, a_sum, t_sum = v_proj, a_proj, t_proj

        # Output predictions.
        emotion_logits = self.emotion_classifier(pooled)
        intent_logits = self.intent_classifier(pooled)
        # squeeze(-1) (not bare squeeze()) so a batch of size 1 keeps its
        # batch dimension instead of collapsing to a scalar.
        engagement = torch.sigmoid(self.engagement_regressor(pooled)).squeeze(-1)
        confidence = torch.sigmoid(self.confidence_regressor(pooled)).squeeze(-1)

        # Modality contributions.
        # BUG FIX: the previous code reduced over the embedding dim in the
        # temporal path (yielding (B, T, 3) — a shape mismatch against the
        # embed_dim*3 Linear input, i.e. a runtime error) and over the batch
        # dim otherwise (mixing samples so every row got the same,
        # batch-dependent estimate). Feed the full per-sample embeddings,
        # concatenated to (B, embed_dim*3), as the estimator expects.
        contributions = torch.softmax(
            self.contribution_estimator(torch.cat([v_sum, a_sum, t_sum], dim=-1)),
            dim=-1,
        )  # (B, 3)

        return {
            'emotion': emotion_logits,
            'intent': intent_logits,
            'engagement': engagement,
            'confidence': confidence,
            'contributions': contributions
        }

    def get_modality_weights(self):
        """
        Return normalized modality weights for explainability.
        """
        return F.softmax(self.modality_weights, dim=0)