""" SyncFuse — our proposed method for T1 scene recognition. Four components (all toggleable via args for ablation): (1) Modality dropout: per-sample independent Bernoulli(p=0.3) drop on each modality during training; at test time all modalities are active. Keeps at least 1 modality. (2) Pretrained transfer: each per-modality backbone is optionally loaded from an independently pretrained single-modality checkpoint and frozen during fine-tuning. (3) Cross-modal temporal-shift attention: a late cross-attention block where EMG queries attend to MoCap keys/values at a LEARNED temporal offset Δ (Gumbel-softmax over {-10,...,+10} bins at 20 Hz = ±500 ms). Motivated by the paper's case-study finding (EMG leads motion by ~20 ms sub-frame). (4) Learnable late fusion: per-modality classifier logits are combined with a learnable softmax-weighted average (temperature is also learned). Equivalent to `late_agg='learned'` in the repo's existing LateFusionModel. """ import torch import torch.nn as nn import torch.nn.functional as F import random def masked_mean(x, mask): m = mask.unsqueeze(-1).float() return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) # --------------------------------------------------------------------------- # Per-modality Transformer branch (same as repo's TransformerBackbone) # --------------------------------------------------------------------------- class ModTransformer(nn.Module): def __init__(self, feat_dim, hidden=128, n_layers=2, n_heads=4, dropout=0.1): super().__init__() self.in_proj = nn.Linear(feat_dim, hidden) self.pos = nn.Parameter(torch.zeros(1, 4096, hidden)) nn.init.trunc_normal_(self.pos, std=0.02) layer = nn.TransformerEncoderLayer( d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden, dropout=dropout, batch_first=True, activation='gelu', ) self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers) self.output_dim = hidden def forward(self, x, mask): # x: (B, T, feat_dim) T = x.size(1) h = self.in_proj(x) + self.pos[:, :T, :] h = self.encoder(h, src_key_padding_mask=~mask) return h # (B, T, hidden) — token-level, NOT pooled # --------------------------------------------------------------------------- # (3) Cross-modal temporal-shift attention # --------------------------------------------------------------------------- class TemporalShiftAttention(nn.Module): """Multi-head attention where queries are temporally shifted by a learned offset Δ from the keys. Δ is drawn from a discrete set {-3,...,+3} via straight-through Gumbel-softmax: we sample ONE shift per forward pass, but the softmax weights flow gradient back through shift_logits. At 20 Hz bins, ±3 ≈ ±150 ms, which brackets the paper's ~20 ms EMG-motion lead. 
# ---------------------------------------------------------------------------
# (3) Cross-modal temporal-shift attention
# ---------------------------------------------------------------------------
class TemporalShiftAttention(nn.Module):
    """Multi-head attention where queries are temporally shifted by a learned
    offset Δ from the keys.

    Δ is drawn from the discrete set {-max_shift, ..., +max_shift}
    (default {-3, ..., +3}) via straight-through Gumbel-softmax: we sample ONE
    shift per forward pass, and gradient flows back to shift_logits through
    the soft weights. At 20 Hz bins, ±3 ≈ ±150 ms, which brackets the paper's
    ~20 ms EMG-motion lead. Memory cost is ~1 attention pass (not 7).
    """

    def __init__(self, d_model, n_heads=4, dropout=0.1, max_shift=3, gumbel_tau=1.0):
        super().__init__()
        self.max_shift = max_shift
        self.shifts = list(range(-max_shift, max_shift + 1))
        self.shift_logits = nn.Parameter(torch.zeros(len(self.shifts)))
        self.tau = gumbel_tau
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)

    def _shift_tensor(self, x, shift, mask):
        """Shift x and its validity mask along time by `shift` frames, zero-padding the ends."""
        if shift == 0:
            return x, mask
        B, T, D = x.shape
        if shift > 0:
            pad = torch.zeros(B, shift, D, device=x.device, dtype=x.dtype)
            x_s = torch.cat([x[:, shift:, :], pad], dim=1)
            m_s = torch.cat(
                [mask[:, shift:], torch.zeros(B, shift, device=mask.device, dtype=torch.bool)],
                dim=1,
            )
        else:
            s = -shift
            pad = torch.zeros(B, s, D, device=x.device, dtype=x.dtype)
            x_s = torch.cat([pad, x[:, :-s, :]], dim=1)
            m_s = torch.cat(
                [torch.zeros(B, s, device=mask.device, dtype=torch.bool), mask[:, :-s]],
                dim=1,
            )
        return x_s, m_s

    def forward(self, q_tokens, kv_tokens, q_mask, kv_mask, hard=False):
        if hard or not self.training:
            # Eval: take the argmax shift deterministically
            with torch.no_grad():
                idx = self.shift_logits.argmax().item()
            shift = self.shifts[idx]
            shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
            out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                               key_padding_mask=~shifted_mask)
            return self.norm(q_tokens + out)

        # Training: straight-through Gumbel-softmax to sample 1 shift,
        # with gradient flowing via the soft weights.
        one_hot = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True)
        # pick the sampled shift (argmax of the hard one-hot)
        idx = int(one_hot.argmax().item())
        shift = self.shifts[idx]
        shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
        out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                           key_padding_mask=~shifted_mask)
        # scale out by the corresponding weight (1.0 in the forward pass) so
        # gradient w.r.t. shift_logits flows via the straight-through estimator
        out = out * one_hot[idx]
        return self.norm(q_tokens + out)
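
# ---------------------------------------------------------------------------
# Illustrative sketch of the temporal-shift attention in isolation. The helper
# name, dims, batch size and sequence length are placeholder assumptions, not
# repo configuration; it only shows the train-time (sampled shift) vs.
# eval-time (argmax shift) behaviour and the expected shapes.
# ---------------------------------------------------------------------------
def _demo_temporal_shift():
    xattn = TemporalShiftAttention(d_model=128, n_heads=4, max_shift=3)
    q = torch.randn(2, 50, 128)                       # e.g. EMG tokens
    kv = torch.randn(2, 50, 128)                      # e.g. MoCap tokens
    mask = torch.ones(2, 50, dtype=torch.bool)        # all frames valid

    xattn.train()
    fused_train = xattn(q, kv, mask, mask)            # one Gumbel-sampled shift
    xattn.eval()
    fused_eval = xattn(q, kv, mask, mask, hard=True)  # deterministic argmax shift

    assert fused_train.shape == (2, 50, 128)
    assert fused_eval.shape == (2, 50, 128)
    return fused_eval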
# ---------------------------------------------------------------------------
# SyncFuse main model
# ---------------------------------------------------------------------------
class SyncFuse(nn.Module):

    def __init__(self, modality_dims: dict, num_classes, hidden=128, n_heads=4,
                 n_layers=2, dropout=0.1, use_xmod_shift=True, use_learned_late=True):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = modality_dims
        self.use_xmod_shift = use_xmod_shift
        self.use_learned_late = use_learned_late

        self.branches = nn.ModuleDict({
            m: ModTransformer(d, hidden, n_layers, n_heads, dropout)
            for m, d in modality_dims.items()
        })
        self.classifiers = nn.ModuleDict({
            m: nn.Sequential(nn.LayerNorm(hidden), nn.Dropout(dropout),
                             nn.Linear(hidden, num_classes))
            for m in self.mod_names
        })

        # Cross-modal temporal-shift: apply to EMG branch attending to MoCap
        # (and symmetrically MoCap->EMG), only when both modalities are present.
        if use_xmod_shift and 'emg' in self.mod_names and 'mocap' in self.mod_names:
            self.xmod_emg2mocap = TemporalShiftAttention(hidden, n_heads, dropout)
            self.xmod_mocap2emg = TemporalShiftAttention(hidden, n_heads, dropout)
        else:
            self.xmod_emg2mocap = None
            self.xmod_mocap2emg = None

        if use_learned_late:
            self.late_logits = nn.Parameter(torch.zeros(len(self.mod_names)))
            self.late_temperature = nn.Parameter(torch.ones(1))

    def load_pretrained(self, pretrain_paths: dict, freeze=True):
        """Load pretrained single-modality checkpoints into branches.

        pretrain_paths: {modality_name: path_to_checkpoint_state_dict}.
        """
        for m, path in pretrain_paths.items():
            if m not in self.branches:
                continue
            try:
                sd = torch.load(path, weights_only=True, map_location='cpu')
            except TypeError:  # older torch without the weights_only kwarg
                sd = torch.load(path, map_location='cpu')
            # Map SingleModel keys ("backbone.X.*") -> branch keys
            mapped = {}
            for k, v in sd.items():
                if k.startswith('backbone.'):
                    new_k = k[len('backbone.'):]
                    if new_k in self.branches[m].state_dict():
                        mapped[new_k] = v
            if mapped:
                self.branches[m].load_state_dict(mapped, strict=False)
                if freeze:
                    for p in self.branches[m].parameters():
                        p.requires_grad = False
            print(f"  [SyncFuse] loaded {len(mapped)} tensors into branch '{m}' (frozen={freeze})")

    def forward(self, x, mask, mod_dropout_p=0.0, training_time=True):
        """
        x:    (B, T, F_total) concatenated per-modality features
        mask: (B, T) with True = valid frame
        mod_dropout_p: probability of dropping each modality (training only)
        """
        B, T, _ = x.shape

        # Slice modality features out of the concatenated input
        offset = 0
        feats = {}
        for m in self.mod_names:
            d = self.mod_dims[m]
            feats[m] = x[..., offset:offset + d]
            offset += d

        # (1) Modality dropout — per sample, independent per modality
        active = {m: torch.ones(B, dtype=torch.bool, device=x.device)
                  for m in self.mod_names}
        if training_time and self.training and mod_dropout_p > 0:
            drop_map = {m: (torch.rand(B, device=x.device) < mod_dropout_p)
                        for m in self.mod_names}
            all_dropped = torch.stack([drop_map[m] for m in self.mod_names], dim=0).all(dim=0)  # (B,)
            if all_dropped.any():
                # for all-dropped samples, un-drop one random modality
                rescue_idx = torch.randint(0, len(self.mod_names),
                                           (all_dropped.sum().item(),), device=x.device)
                j = 0
                for b in range(B):
                    if all_dropped[b]:
                        r = self.mod_names[rescue_idx[j].item()]
                        drop_map[r][b] = False
                        j += 1
            for m in self.mod_names:
                active[m] = ~drop_map[m]
                # zero out dropped features for that branch
                feats[m] = feats[m] * active[m].view(B, 1, 1).float()

        # Per-modality encoding
        tokens = {}
        for m in self.mod_names:
            tokens[m] = self.branches[m](feats[m], mask)  # (B, T, hidden)

        # (3) Cross-modal temporal-shift (bidirectional EMG <-> MoCap).
        # Both directions read the PRE-fusion tokens so the two passes are symmetric.
        if self.xmod_emg2mocap is not None:
            emg_in, mocap_in = tokens['emg'], tokens['mocap']
            tokens['emg'] = self.xmod_emg2mocap(
                emg_in, mocap_in, mask, mask, hard=not self.training,
            )
            tokens['mocap'] = self.xmod_mocap2emg(
                mocap_in, emg_in, mask, mask, hard=not self.training,
            )

        # Pool and classify per modality
        logits_per = []
        for m in self.mod_names:
            pooled = masked_mean(tokens[m], mask)
            logits_per.append(self.classifiers[m](pooled))
        stacked = torch.stack(logits_per, dim=0)  # (M, B, C)

        # Mask out logits from dropped modalities (so they don't dominate)
        if training_time and self.training and mod_dropout_p > 0:
            act_mask = torch.stack([active[m].float() for m in self.mod_names], dim=0)  # (M, B)
            # Re-normalize weights across active modalities
            if self.use_learned_late:
                w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
                w = w.view(-1, 1) * act_mask  # (M, B)
                w = w / w.sum(dim=0, keepdim=True).clamp(min=1e-6)
                out = (stacked * w.unsqueeze(-1)).sum(dim=0)
            else:
                w = act_mask / act_mask.sum(dim=0, keepdim=True).clamp(min=1e-6)
                out = (stacked * w.unsqueeze(-1)).sum(dim=0)
        else:
            # (4) Learnable late fusion (or simple mean)
            if self.use_learned_late:
                w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
                out = (stacked * w.view(-1, 1, 1)).sum(dim=0)
            else:
                out = stacked.mean(dim=0)

        return out
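
# ---------------------------------------------------------------------------
# Minimal end-to-end smoke test (illustrative only). The per-modality feature
# dims, num_classes, batch size and sequence length below are hypothetical
# placeholders; real values come from the dataset/feature pipeline and the
# training script, not from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dims = {'emg': 16, 'mocap': 66}                 # hypothetical feature dims
    model = SyncFuse(dims, num_classes=10)          # hypothetical class count
    B, T = 4, 80
    x = torch.randn(B, T, sum(dims.values()))       # concatenated features (B, T, F_total)
    mask = torch.ones(B, T, dtype=torch.bool)       # all frames valid

    model.train()
    logits_train = model(x, mask, mod_dropout_p=0.3)  # modality dropout active
    model.eval()
    with torch.no_grad():
        logits_eval = model(x, mask)                  # all modalities active

    print('train logits:', tuple(logits_train.shape),
          'eval logits:', tuple(logits_eval.shape))   # both (B, num_classes)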