"""
SyncFuse — our proposed method for T1 scene recognition.
Four components (all toggleable via args for ablation):
(1) Modality dropout: per-sample independent Bernoulli(p=0.3) drop on each
modality during training; at test time all modalities
are active. Keeps at least 1 modality.
(2) Pretrained transfer: each per-modality backbone is optionally loaded from
an independently pretrained single-modality
checkpoint and frozen during fine-tuning.
(3) Cross-modal temporal-shift attention:
a late cross-attention block where EMG queries
attend to MoCap keys/values at a LEARNED temporal
                      offset Δ (Gumbel-softmax over {-3,...,+3} bins at
                      20 Hz ≈ ±150 ms; see max_shift=3 below). Motivated by
                      the paper's case-study finding that EMG leads motion
                      by ~20 ms (a sub-frame offset).
(4) Learnable late fusion:
per-modality classifier logits are combined with a
learnable softmax-weighted average (temperature is
also learned). Equivalent to `late_agg='learned'`
in the repo's existing LateFusionModel.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
def masked_mean(x, mask):
m = mask.unsqueeze(-1).float()
return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
# ---------------------------------------------------------------------------
# Per-modality Transformer branch (same as repo's TransformerBackbone)
# ---------------------------------------------------------------------------
class ModTransformer(nn.Module):
def __init__(self, feat_dim, hidden=128, n_layers=2, n_heads=4, dropout=0.1):
super().__init__()
self.in_proj = nn.Linear(feat_dim, hidden)
self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
nn.init.trunc_normal_(self.pos, std=0.02)
layer = nn.TransformerEncoderLayer(
d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
dropout=dropout, batch_first=True, activation='gelu',
)
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
self.output_dim = hidden
def forward(self, x, mask):
# x: (B, T, feat_dim)
T = x.size(1)
h = self.in_proj(x) + self.pos[:, :T, :]
h = self.encoder(h, src_key_padding_mask=~mask)
return h # (B, T, hidden) — token-level, NOT pooled
# ---------------------------------------------------------------------------
# (3) Cross-modal temporal-shift attention
# ---------------------------------------------------------------------------
class TemporalShiftAttention(nn.Module):
"""Multi-head attention where queries are temporally shifted by a learned
offset Δ from the keys. Δ is drawn from a discrete set {-3,...,+3} via
straight-through Gumbel-softmax: we sample ONE shift per forward pass,
but the softmax weights flow gradient back through shift_logits.
At 20 Hz bins, ±3 ≈ ±150 ms, which brackets the paper's ~20 ms EMG-motion
lead. Memory cost is ~1 attention pass (not 7)."""
def __init__(self, d_model, n_heads=4, dropout=0.1, max_shift=3,
gumbel_tau=1.0):
super().__init__()
self.max_shift = max_shift
self.shifts = list(range(-max_shift, max_shift + 1))
self.shift_logits = nn.Parameter(torch.zeros(len(self.shifts)))
self.tau = gumbel_tau
self.attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True,
)
self.norm = nn.LayerNorm(d_model)
def _shift_tensor(self, x, shift, mask):
if shift == 0:
return x, mask
B, T, D = x.shape
if shift > 0:
pad = torch.zeros(B, shift, D, device=x.device, dtype=x.dtype)
x_s = torch.cat([x[:, shift:, :], pad], dim=1)
m_s = torch.cat([mask[:, shift:],
torch.zeros(B, shift, device=mask.device, dtype=torch.bool)],
dim=1)
else:
s = -shift
pad = torch.zeros(B, s, D, device=x.device, dtype=x.dtype)
x_s = torch.cat([pad, x[:, :-s, :]], dim=1)
m_s = torch.cat([torch.zeros(B, s, device=mask.device, dtype=torch.bool),
mask[:, :-s]], dim=1)
return x_s, m_s
def forward(self, q_tokens, kv_tokens, q_mask, kv_mask, hard=False):
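        # Shape contract (inferred from the call sites in SyncFuse.forward):
        #   q_tokens, kv_tokens: (B, T, d_model); q_mask, kv_mask: (B, T) bool, True = valid.
        #   Returns (B, T, d_model): LayerNorm(q_tokens + cross-attention output).
        # Note: q_mask is accepted for interface symmetry but unused here; padded
        # query positions are discarded later by masked_mean pooling.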
if hard or not self.training:
# Eval: take the argmax shift
with torch.no_grad():
idx = self.shift_logits.argmax().item()
shift = self.shifts[idx]
shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
key_padding_mask=~shifted_mask)
return self.norm(q_tokens + out)
# Training: straight-through Gumbel-softmax to sample 1 shift,
# with gradient flowing via softmax weights.
one_hot = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True)
# pick the sampled shift (argmax of the hard one-hot)
idx = int(one_hot.argmax().item())
shift = self.shifts[idx]
shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
key_padding_mask=~shifted_mask)
        # one_hot[idx] is exactly 1.0 in the forward pass (hard sample); multiplying
        # by it lets the straight-through gradient reach shift_logits.
        out = out * one_hot[idx]
return self.norm(q_tokens + out)
# ---------------------------------------------------------------------------
# SyncFuse main model
# ---------------------------------------------------------------------------
class SyncFuse(nn.Module):
def __init__(self, modality_dims: dict, num_classes, hidden=128, n_heads=4,
n_layers=2, dropout=0.1,
use_xmod_shift=True, use_learned_late=True):
super().__init__()
self.mod_names = list(modality_dims.keys())
self.mod_dims = modality_dims
self.use_xmod_shift = use_xmod_shift
self.use_learned_late = use_learned_late
self.branches = nn.ModuleDict({
m: ModTransformer(d, hidden, n_layers, n_heads, dropout)
for m, d in modality_dims.items()
})
self.classifiers = nn.ModuleDict({
m: nn.Sequential(nn.LayerNorm(hidden), nn.Dropout(dropout),
nn.Linear(hidden, num_classes))
for m in self.mod_names
})
# Cross-modal temporal-shift: apply to EMG branch attending to MoCap
# (and symmetrically MoCap->EMG), only when both modalities are present.
if use_xmod_shift and 'emg' in self.mod_names and 'mocap' in self.mod_names:
self.xmod_emg2mocap = TemporalShiftAttention(hidden, n_heads, dropout)
self.xmod_mocap2emg = TemporalShiftAttention(hidden, n_heads, dropout)
else:
self.xmod_emg2mocap = None
self.xmod_mocap2emg = None
if use_learned_late:
self.late_logits = nn.Parameter(torch.zeros(len(self.mod_names)))
self.late_temperature = nn.Parameter(torch.ones(1))
def load_pretrained(self, pretrain_paths: dict, freeze=True):
"""Load pretrained single-modality checkpoints into branches.
pretrain_paths: {modality_name: path_to_checkpoint_state_dict}."""
        for m, path in pretrain_paths.items():
            if m not in self.branches:
                continue
            try:
                sd = torch.load(path, weights_only=True, map_location='cpu')
            except TypeError:  # older torch without the `weights_only` kwarg
                sd = torch.load(path, map_location='cpu')
            # Map SingleModel keys ("backbone.X.*") -> branch keys
            prefix = 'backbone.'
            mapped = {}
            for k, v in sd.items():
                if k.startswith(prefix):
                    new_k = k[len(prefix):]
                    if new_k in self.branches[m].state_dict():
                        mapped[new_k] = v
if mapped:
self.branches[m].load_state_dict(mapped, strict=False)
if freeze:
for p in self.branches[m].parameters():
p.requires_grad = False
print(f" [SyncFuse] loaded {len(mapped)} tensors into branch '{m}' (frozen={freeze})")
def forward(self, x, mask, mod_dropout_p=0.0, training_time=True):
"""
x: (B, T, F_total) concatenated features
mask: (B, T)
mod_dropout_p: probability of dropping each modality (training only)
"""
B, T, _ = x.shape
# Slice modality features
offset = 0
feats = {}
for m in self.mod_names:
d = self.mod_dims[m]
feats[m] = x[..., offset:offset + d]
offset += d
# (1) Modality dropout — per sample, independent per modality
active = {m: torch.ones(B, dtype=torch.bool, device=x.device) for m in self.mod_names}
if training_time and self.training and mod_dropout_p > 0:
drop_map = {m: (torch.rand(B, device=x.device) < mod_dropout_p)
for m in self.mod_names}
all_dropped = torch.stack([drop_map[m] for m in self.mod_names], dim=0).all(dim=0) # (B,)
if all_dropped.any():
# for all-dropped samples, un-drop one random modality
rescue_idx = torch.randint(0, len(self.mod_names),
(all_dropped.sum().item(),),
device=x.device)
                j = 0
                for b in range(B):
                    if all_dropped[b]:
                        r = self.mod_names[rescue_idx[j].item()]
                        drop_map[r][b] = False
                        j += 1
for m in self.mod_names:
active[m] = ~drop_map[m]
# zero out dropped features for that branch
feats[m] = feats[m] * active[m].view(B, 1, 1).float()
# Per-modality encoding
tokens = {}
for m in self.mod_names:
tokens[m] = self.branches[m](feats[m], mask) # (B, T, hidden)
        # (3) Cross-modal temporal-shift (bidirectional EMG <-> MoCap).
        # Both directions attend to the pre-update tokens so the exchange is symmetric.
        if self.xmod_emg2mocap is not None:
            emg_tok, mocap_tok = tokens['emg'], tokens['mocap']
            tokens['emg'] = self.xmod_emg2mocap(
                emg_tok, mocap_tok, mask, mask,
                hard=not self.training,
            )
            tokens['mocap'] = self.xmod_mocap2emg(
                mocap_tok, emg_tok, mask, mask,
                hard=not self.training,
            )
# Pool and classify per modality
logits_per = []
for m in self.mod_names:
pooled = masked_mean(tokens[m], mask)
logits_per.append(self.classifiers[m](pooled))
stacked = torch.stack(logits_per, dim=0) # (M, B, C)
        # Exclude logits from dropped modalities (their branches only saw zeroed inputs)
if training_time and self.training and mod_dropout_p > 0:
act_mask = torch.stack([active[m].float() for m in self.mod_names], dim=0) # (M, B)
# Re-normalize weights across active modalities
if self.use_learned_late:
w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
w = w.view(-1, 1) * act_mask # (M, B)
w = w / w.sum(dim=0, keepdim=True).clamp(min=1e-6)
out = (stacked * w.unsqueeze(-1)).sum(dim=0)
else:
w = act_mask / act_mask.sum(dim=0, keepdim=True).clamp(min=1e-6)
out = (stacked * w.unsqueeze(-1)).sum(dim=0)
else:
# (4) Learnable late fusion (or simple mean)
if self.use_learned_late:
w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
out = (stacked * w.view(-1, 1, 1)).sum(dim=0)
else:
out = stacked.mean(dim=0)
return out
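if __name__ == "__main__":
    # Minimal smoke test (illustrative only: modality dims, class count, and
    # sequence length are placeholder assumptions, not values from the repo).
    torch.manual_seed(0)
    dims = {'emg': 16, 'mocap': 66}
    model = SyncFuse(dims, num_classes=10, hidden=128)
    B, T = 4, 40
    x = torch.randn(B, T, sum(dims.values()))
    mask = torch.ones(B, T, dtype=torch.bool)
    mask[:, 30:] = False                                # simulate padded tails
    model.train()
    out_train = model(x, mask, mod_dropout_p=0.3)       # exercises (1), (3), (4)
    model.eval()
    with torch.no_grad():
        out_eval = model(x, mask)                       # all modalities active
    print(out_train.shape, out_eval.shape)              # both should be (4, 10)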