"""
SyncFuse — our proposed method for T1 scene recognition.

Four components (all toggleable via args for ablation):

 (1) Modality dropout:    per-sample independent Bernoulli(p=0.3) drop on each
                          modality during training; at test time all modalities
                          are active. Keeps at least 1 modality.
 (2) Pretrained transfer: each per-modality backbone is optionally loaded from
                          an independently pretrained single-modality
                          checkpoint and frozen during fine-tuning.
 (3) Cross-modal temporal-shift attention:
                          a late cross-attention block where EMG queries
                          attend to MoCap keys/values at a LEARNED temporal
                          offset Δ (straight-through Gumbel-softmax over
                          {-3,...,+3} bins at 20 Hz ≈ ±150 ms). Motivated by
                          the paper's case-study finding that EMG leads
                          motion by ~20 ms (sub-frame at 20 Hz).
 (4) Learnable late fusion:
                          per-modality classifier logits are combined with a
                          learnable softmax-weighted average (temperature is
                          also learned). Equivalent to `late_agg='learned'`
                          in the repo's existing LateFusionModel.
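
A smoke test under __main__ at the bottom of this file exercises all four
components on random data.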
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_mean(x, mask):
    """Mean over the time axis, counting only valid positions.
    x: (B, T, D); mask: (B, T) boolean, True = valid timestep."""
    m = mask.unsqueeze(-1).float()
    return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


# ---------------------------------------------------------------------------
# Per-modality Transformer branch (same as repo's TransformerBackbone)
# ---------------------------------------------------------------------------

class ModTransformer(nn.Module):
    def __init__(self, feat_dim, hidden=128, n_layers=2, n_heads=4, dropout=0.1):
        super().__init__()
        self.in_proj = nn.Linear(feat_dim, hidden)
        self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
        nn.init.trunc_normal_(self.pos, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
            dropout=dropout, batch_first=True, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.output_dim = hidden

    def forward(self, x, mask):
        # x: (B, T, feat_dim); mask: (B, T) boolean, True = valid timestep
        T = x.size(1)
        h = self.in_proj(x) + self.pos[:, :T, :]
        # nn.TransformerEncoder expects True = PAD in src_key_padding_mask,
        # hence the inversion.
        h = self.encoder(h, src_key_padding_mask=~mask)
        return h  # (B, T, hidden) — token-level, NOT pooled


# ---------------------------------------------------------------------------
# (3) Cross-modal temporal-shift attention
# ---------------------------------------------------------------------------

class TemporalShiftAttention(nn.Module):
    """Multi-head attention where the keys/values are temporally shifted by a
    learned offset Δ relative to the queries. Δ is drawn from the discrete
    set {-max_shift,...,+max_shift} (default ±3) via straight-through
    Gumbel-softmax: ONE shift is sampled per forward pass, and the gradient
    flows back into shift_logits through the softmax relaxation.

    At 20 Hz bins, ±3 ≈ ±150 ms, which brackets the paper's ~20 ms EMG-motion
    lead. Memory cost is ~1 attention pass, not 2*max_shift + 1."""
    def __init__(self, d_model, n_heads=4, dropout=0.1, max_shift=3,
                 gumbel_tau=1.0):
        super().__init__()
        self.max_shift = max_shift
        self.shifts = list(range(-max_shift, max_shift + 1))
        self.shift_logits = nn.Parameter(torch.zeros(len(self.shifts)))
        self.tau = gumbel_tau
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)

    def _shift_tensor(self, x, shift, mask):
        # Uniform semantics for both signs: x_s[t] = x[t + shift], i.e. a
        # positive shift looks AHEAD in x and a negative shift looks BACK.
        # Vacated positions are zero-padded and marked invalid in the mask.
        if shift == 0:
            return x, mask
        B, T, D = x.shape
        if shift > 0:
            pad = torch.zeros(B, shift, D, device=x.device, dtype=x.dtype)
            x_s = torch.cat([x[:, shift:, :], pad], dim=1)
            m_s = torch.cat([mask[:, shift:],
                             torch.zeros(B, shift, device=mask.device, dtype=torch.bool)],
                            dim=1)
        else:
            s = -shift
            pad = torch.zeros(B, s, D, device=x.device, dtype=x.dtype)
            x_s = torch.cat([pad, x[:, :-s, :]], dim=1)
            m_s = torch.cat([torch.zeros(B, s, device=mask.device, dtype=torch.bool),
                             mask[:, :-s]], dim=1)
        return x_s, m_s

    def forward(self, q_tokens, kv_tokens, q_mask, kv_mask, hard=False):
        # q_mask is accepted for interface symmetry but unused here: padded
        # query positions are discarded later by masked_mean pooling.
        if hard or not self.training:
            # Eval: deterministically use the argmax shift.
            idx = int(self.shift_logits.argmax().item())
            shift = self.shifts[idx]
            shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
            out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                               key_padding_mask=~shifted_mask)
            return self.norm(q_tokens + out)

        # Training: straight-through Gumbel-softmax samples ONE shift per
        # forward pass; the hard one-hot equals 1.0 in the forward direction
        # but carries the soft-softmax gradient back into shift_logits.
        one_hot = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True)
        idx = int(one_hot.argmax().item())  # the sampled shift
        shift = self.shifts[idx]
        shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
        out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                           key_padding_mask=~shifted_mask)
        # Multiply by the straight-through one-hot entry (value 1.0) so the
        # gradient reaches shift_logits via the Gumbel-softmax relaxation.
        out = out * one_hot[idx]
        return self.norm(q_tokens + out)
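

# Minimal smoke test for TemporalShiftAttention — a sketch on random tensors
# with illustrative shapes; it is not part of the training pipeline.
def _demo_temporal_shift():
    attn = TemporalShiftAttention(d_model=128, n_heads=4)
    q = torch.randn(2, 50, 128)
    kv = torch.randn(2, 50, 128)
    m = torch.ones(2, 50, dtype=torch.bool)
    out = attn(q, kv, m, m, hard=True)  # hard=True -> deterministic argmax shift
    assert out.shape == (2, 50, 128)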


# ---------------------------------------------------------------------------
# SyncFuse main model
# ---------------------------------------------------------------------------

class SyncFuse(nn.Module):
    def __init__(self, modality_dims: dict, num_classes, hidden=128, n_heads=4,
                 n_layers=2, dropout=0.1,
                 use_xmod_shift=True, use_learned_late=True):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = modality_dims
        self.use_xmod_shift = use_xmod_shift
        self.use_learned_late = use_learned_late

        self.branches = nn.ModuleDict({
            m: ModTransformer(d, hidden, n_layers, n_heads, dropout)
            for m, d in modality_dims.items()
        })
        self.classifiers = nn.ModuleDict({
            m: nn.Sequential(nn.LayerNorm(hidden), nn.Dropout(dropout),
                             nn.Linear(hidden, num_classes))
            for m in self.mod_names
        })

        # Cross-modal temporal-shift: apply to EMG branch attending to MoCap
        # (and symmetrically MoCap->EMG), only when both modalities are present.
        if use_xmod_shift and 'emg' in self.mod_names and 'mocap' in self.mod_names:
            self.xmod_emg2mocap = TemporalShiftAttention(hidden, n_heads, dropout)
            self.xmod_mocap2emg = TemporalShiftAttention(hidden, n_heads, dropout)
        else:
            self.xmod_emg2mocap = None
            self.xmod_mocap2emg = None

        if use_learned_late:
            self.late_logits = nn.Parameter(torch.zeros(len(self.mod_names)))
            self.late_temperature = nn.Parameter(torch.ones(1))

    def load_pretrained(self, pretrain_paths: dict, freeze=True):
        """Load pretrained single-modality checkpoints into branches.
        pretrain_paths: {modality_name: path_to_checkpoint_state_dict}."""
        for m, path in pretrain_paths.items():
            if m not in self.branches:
                continue
            try:
                sd = torch.load(path, weights_only=True, map_location='cpu')
            except TypeError:
                # Older torch versions have no weights_only argument.
                sd = torch.load(path, map_location='cpu')
            # Map SingleModel keys ("backbone.X.*") -> branch keys
            prefix = 'backbone.'
            mapped = {}
            for k, v in sd.items():
                if k.startswith(prefix):
                    new_k = k[len(prefix):]
                    if new_k in self.branches[m].state_dict():
                        mapped[new_k] = v
            if mapped:
                self.branches[m].load_state_dict(mapped, strict=False)
                if freeze:
                    for p in self.branches[m].parameters():
                        p.requires_grad = False
                print(f"  [SyncFuse] loaded {len(mapped)} tensors into branch '{m}' (frozen={freeze})")
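
    # Usage sketch (checkpoint paths here are hypothetical placeholders):
    #   model.load_pretrained({'emg': 'ckpts/emg_single.pt',
    #                          'mocap': 'ckpts/mocap_single.pt'}, freeze=True)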

    def forward(self, x, mask, mod_dropout_p=0.0, training_time=True):
        """
        x:    (B, T, F_total) features concatenated in the order of
              self.mod_names
        mask: (B, T) boolean, True = valid timestep
        mod_dropout_p: probability of dropping each modality; applied only
              when both training_time and self.training are True
        """
        B, T, _ = x.shape

        # Slice modality features
        offset = 0
        feats = {}
        for m in self.mod_names:
            d = self.mod_dims[m]
            feats[m] = x[..., offset:offset + d]
            offset += d

        # (1) Modality dropout — per sample, independent per modality
        active = {m: torch.ones(B, dtype=torch.bool, device=x.device) for m in self.mod_names}
        if training_time and self.training and mod_dropout_p > 0:
            drop_map = {m: (torch.rand(B, device=x.device) < mod_dropout_p)
                        for m in self.mod_names}
            all_dropped = torch.stack([drop_map[m] for m in self.mod_names], dim=0).all(dim=0)  # (B,)
            if all_dropped.any():
                # For samples where every modality was dropped, un-drop one
                # modality chosen uniformly at random so each sample keeps
                # at least one active modality.
                rows = all_dropped.nonzero(as_tuple=True)[0]
                rescue_idx = torch.randint(0, len(self.mod_names),
                                           (rows.numel(),), device=x.device)
                for j, b in zip(rescue_idx.tolist(), rows.tolist()):
                    drop_map[self.mod_names[j]][b] = False
            for m in self.mod_names:
                active[m] = ~drop_map[m]
                # zero out dropped features for that branch
                feats[m] = feats[m] * active[m].view(B, 1, 1).float()

        # Per-modality encoding
        tokens = {}
        for m in self.mod_names:
            tokens[m] = self.branches[m](feats[m], mask)  # (B, T, hidden)

        # (3) Cross-modal temporal-shift (bidirectional EMG <-> MoCap).
        # Both directions read the pre-update tokens so the two attention
        # passes are symmetric; otherwise mocap would attend to already
        # fused emg tokens.
        if self.xmod_emg2mocap is not None:
            emg_in, mocap_in = tokens['emg'], tokens['mocap']
            tokens['emg'] = self.xmod_emg2mocap(
                emg_in, mocap_in, mask, mask, hard=not self.training,
            )
            tokens['mocap'] = self.xmod_mocap2emg(
                mocap_in, emg_in, mask, mask, hard=not self.training,
            )

        # Pool and classify per modality
        logits_per = []
        for m in self.mod_names:
            pooled = masked_mean(tokens[m], mask)
            logits_per.append(self.classifiers[m](pooled))
        stacked = torch.stack(logits_per, dim=0)  # (M, B, C)

        # Mask out logits from dropped modalities (so they don't dominate)
        if training_time and self.training and mod_dropout_p > 0:
            act_mask = torch.stack([active[m].float() for m in self.mod_names], dim=0)  # (M, B)
            # Re-normalize weights across active modalities
            if self.use_learned_late:
                w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
                w = w.view(-1, 1) * act_mask  # (M, B)
                w = w / w.sum(dim=0, keepdim=True).clamp(min=1e-6)
                out = (stacked * w.unsqueeze(-1)).sum(dim=0)
            else:
                w = act_mask / act_mask.sum(dim=0, keepdim=True).clamp(min=1e-6)
                out = (stacked * w.unsqueeze(-1)).sum(dim=0)
        else:
            # (4) Learnable late fusion (or simple mean)
            if self.use_learned_late:
                w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
                out = (stacked * w.view(-1, 1, 1)).sum(dim=0)
            else:
                out = stacked.mean(dim=0)
        return out
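

# Smoke test — a sketch on random data. The modality dims below (8-ch EMG,
# 66-dim MoCap) and all shapes are illustrative assumptions, not the
# dataset's real feature sizes.
if __name__ == '__main__':
    torch.manual_seed(0)
    _demo_temporal_shift()

    dims = {'emg': 8, 'mocap': 66}
    model = SyncFuse(dims, num_classes=5,
                     use_xmod_shift=True, use_learned_late=True)
    x = torch.randn(4, 100, sum(dims.values()))  # (B, T, F_total)
    mask = torch.ones(4, 100, dtype=torch.bool)
    mask[0, 60:] = False                         # one partially padded sample

    model.train()
    out = model(x, mask, mod_dropout_p=0.3)      # (1) modality dropout active
    print('train logits:', tuple(out.shape))     # expect (4, 5)

    model.eval()
    with torch.no_grad():
        out = model(x, mask)                     # all modalities active
    print('eval  logits:', tuple(out.shape))     # expect (4, 5)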