""" Published baselines for T1 Scene Recognition, reproduced on DailyAct-5M. Each method accepts a concatenated feature tensor (B, T, F_total) where F_total is the sum of the active modality dims; the per-modality slices are recorded in the `modality_dims` dict. Each method then uses the subset of modalities its original paper intended. All methods output an (B, num_classes) logit tensor. """ import math import torch import torch.nn as nn import torch.nn.functional as F def _slice(x, mod_dims, wanted): """Slice the concatenated feature tensor to keep only `wanted` modalities, in the order given. mod_dims is an ordered dict. Returns {name: tensor(B,T,d_name)} plus the concat.""" parts = {} offset = 0 for name, d in mod_dims.items(): if name in wanted: parts[name] = x[..., offset:offset + d] offset += d assert len(parts) > 0, f"None of {wanted} in {list(mod_dims.keys())}" return parts # --------------------------------------------------------------------------- # 1) ST-GCN (Yan et al., AAAI 2018) # Spatio-temporal graph CNN for skeleton action recognition. # We treat the 56-joint MoCap skeleton as the graph. # --------------------------------------------------------------------------- class STGCNBlock(nn.Module): def __init__(self, in_ch, out_ch, n_joints, stride=1, dropout=0.2): super().__init__() # Spatial graph conv: learnable adjacency (fully learned, no handcrafted A) self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints)) self.spatial = nn.Conv2d(in_ch, out_ch, kernel_size=(1, 1), bias=False) self.spatial_bn = nn.BatchNorm2d(out_ch) self.temporal = nn.Conv2d(out_ch, out_ch, kernel_size=(9, 1), padding=(4, 0), stride=(stride, 1)) self.temporal_bn = nn.BatchNorm2d(out_ch) self.dropout = nn.Dropout(dropout) if in_ch != out_ch or stride != 1: self.res = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=(stride, 1)) else: self.res = nn.Identity() def forward(self, x): # x: (B, C, T, V) res = self.res(x) # spatial: aggregate along joints via A h = self.spatial(x) h = torch.einsum('bctv,vw->bctw', h, F.softmax(self.A, dim=-1)) h = self.spatial_bn(h) h = F.relu(h) # temporal h = self.temporal(h) h = self.temporal_bn(h) h = self.dropout(h) return F.relu(h + res) class STGCN(nn.Module): """ST-GCN on MoCap skeleton. We assume the MoCap modality is 620-dim (hip-relative + velocity) and reshape to ~56 joints.""" def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52): super().__init__() self.n_joints = n_joints # MoCap feat is (T, 620). 52 joints × 4 (xyz+quat_type), or we take per-joint xyz-only = 156. # In this repo, 620 = 52 markers * 4 cols + velocity features. We'll # reshape by slicing to 3*52=156 "primary" coords, padded if needed. self.coord_dim = 3 # we'll treat each joint as having 3 coords (XYZ) self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim) self.blocks = nn.ModuleList([ STGCNBlock(self.coord_dim, hidden, n_joints), STGCNBlock(hidden, hidden, n_joints), STGCNBlock(hidden, hidden * 2, n_joints, stride=2), STGCNBlock(hidden * 2, hidden * 2, n_joints), STGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2), STGCNBlock(hidden * 4, hidden * 4, n_joints), ]) self.head = nn.Sequential( nn.Dropout(0.3), nn.Linear(hidden * 4, num_classes), ) def forward(self, x_mocap, mask=None): # x_mocap: (B, T, feat_dim_mocap) B, T, _ = x_mocap.shape h = self.proj_in(x_mocap) # (B, T, n_joints * 3) h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2) # (B, C, T, V) for blk in self.blocks: h = blk(h) # Global mean pool over time & joints (with mask if provided) if mask is not None: # mask: (B, T), h: (B, C, T', V) where T' may be < T due to stride T_ = h.shape[2] m = mask[:, :T_].float().unsqueeze(1).unsqueeze(-1) # (B, 1, T', 1) h = (h * m).sum(dim=(2, 3)) / (m.sum(dim=(2, 3)) * h.shape[3] + 1e-8) else: h = h.mean(dim=(2, 3)) return self.head(h) # --------------------------------------------------------------------------- # 2) CTR-GCN (Chen et al., ICCV 2021) # Channel-wise Topology Refinement GCN — learns a separate adjacency # matrix per channel group, known as SOTA for skeleton action recognition. # --------------------------------------------------------------------------- class CTRGC(nn.Module): """Simplified CTR-GC block: learnable per-channel topology refinement.""" def __init__(self, in_ch, out_ch, n_joints, rel_reduction=4): super().__init__() self.n_joints = n_joints self.conv1 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1) self.conv2 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1) self.conv3 = nn.Conv2d(in_ch, out_ch, 1) self.alpha = nn.Parameter(torch.zeros(1)) self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints)) def forward(self, x): # x: (B, C, T, V) q = self.conv1(x).mean(dim=2) # (B, C', V) k = self.conv2(x).mean(dim=2) # (B, C', V) v = self.conv3(x) # (B, C_out, T, V) # Channel-specific topology refinement topology = F.softmax(torch.tanh(q.unsqueeze(-1) - k.unsqueeze(-2)), dim=-1) # topology: (B, C', V, V); we average across channels to get a shared (B, V, V) topology = topology.mean(dim=1) A = self.A.unsqueeze(0) + self.alpha * topology # apply A to v out = torch.einsum('bctv,bvw->bctw', v, A) return out class CTRGCNBlock(nn.Module): def __init__(self, in_ch, out_ch, n_joints, stride=1): super().__init__() self.gc = CTRGC(in_ch, out_ch, n_joints) self.bn = nn.BatchNorm2d(out_ch) self.tcn = nn.Sequential( nn.Conv2d(out_ch, out_ch, (9, 1), padding=(4, 0), stride=(stride, 1)), nn.BatchNorm2d(out_ch), ) if in_ch != out_ch or stride != 1: self.res = nn.Conv2d(in_ch, out_ch, 1, stride=(stride, 1)) else: self.res = nn.Identity() def forward(self, x): res = self.res(x) h = self.gc(x) h = self.bn(h) h = F.relu(h) h = self.tcn(h) return F.relu(h + res) class CTRGCN(nn.Module): def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52): super().__init__() self.n_joints = n_joints self.coord_dim = 3 self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim) self.blocks = nn.ModuleList([ CTRGCNBlock(self.coord_dim, hidden, n_joints), CTRGCNBlock(hidden, hidden, n_joints), CTRGCNBlock(hidden, hidden * 2, n_joints, stride=2), CTRGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2), ]) self.head = nn.Sequential( nn.Dropout(0.3), nn.Linear(hidden * 4, num_classes), ) def forward(self, x_mocap, mask=None): B, T, _ = x_mocap.shape h = self.proj_in(x_mocap) h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2) for blk in self.blocks: h = blk(h) h = h.mean(dim=(2, 3)) return self.head(h) # --------------------------------------------------------------------------- # 3) LIMU-BERT (Xu et al., SenSys 2021) # IMU self-supervised pretraining via masked reconstruction + fine-tune. # We implement a simpler variant: BERT-style encoder with optional # pretraining head. # --------------------------------------------------------------------------- class LIMUBertEncoder(nn.Module): def __init__(self, feat_dim_imu, hidden=128, n_layers=4, n_heads=4, dropout=0.1): super().__init__() self.in_proj = nn.Linear(feat_dim_imu, hidden) self.pos = nn.Parameter(torch.zeros(1, 4096, hidden)) nn.init.trunc_normal_(self.pos, std=0.02) layer = nn.TransformerEncoderLayer( d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden, dropout=dropout, batch_first=True, activation='gelu', ) self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers) def forward(self, x, mask): T = x.size(1) h = self.in_proj(x) + self.pos[:, :T, :] h = self.encoder(h, src_key_padding_mask=~mask) return h class LIMUBert(nn.Module): """Supervised-only variant: encoder + classifier head. Paper's pretraining is a masked-recon objective; for simplicity we report the supervised-only baseline here.""" def __init__(self, feat_dim_imu, num_classes, hidden=128, n_layers=4, n_heads=4, dropout=0.1): super().__init__() self.encoder = LIMUBertEncoder(feat_dim_imu, hidden, n_layers, n_heads, dropout) self.head = nn.Sequential( nn.LayerNorm(hidden), nn.Dropout(dropout), nn.Linear(hidden, num_classes), ) def forward(self, x_imu, mask): h = self.encoder(x_imu, mask) m = mask.unsqueeze(-1).float() pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) return self.head(pooled) # --------------------------------------------------------------------------- # 4) EMG-CNN (standard 1D CNN baseline from sEMG classification literature) # E.g. Atzori et al. — multi-layer CNN with moving-window input. # --------------------------------------------------------------------------- class EMGCNN(nn.Module): def __init__(self, feat_dim_emg, num_classes, hidden=64): super().__init__() self.cnn = nn.Sequential( nn.Conv1d(feat_dim_emg, hidden, 7, padding=3), nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(0.3), nn.Conv1d(hidden, hidden * 2, 5, padding=2), nn.BatchNorm1d(hidden * 2), nn.ReLU(), nn.Dropout(0.3), nn.Conv1d(hidden * 2, hidden * 4, 3, padding=1), nn.BatchNorm1d(hidden * 4), nn.ReLU(), ) self.head = nn.Linear(hidden * 4, num_classes) def forward(self, x_emg, mask): # (B, T, 8) -> (B, 8, T) for conv1d h = self.cnn(x_emg.transpose(1, 2)) # Masked pool m = mask.unsqueeze(1).float() T_ = h.size(2) if m.size(2) != T_: m = F.adaptive_avg_pool1d(m, T_) m = (m > 0.5).float() pooled = (h * m).sum(dim=2) / m.sum(dim=2).clamp(min=1.0) return self.head(pooled) # --------------------------------------------------------------------------- # 5) ActionSense baseline (DelPreto et al., NeurIPS '22) # Simple 3-layer MLP per modality + shared LSTM + classifier. # --------------------------------------------------------------------------- class ActionSenseLSTM(nn.Module): def __init__(self, modality_dims: dict, num_classes, hidden=128): super().__init__() self.mod_names = list(modality_dims.keys()) self.mod_dims = modality_dims self.per_mod = nn.ModuleDict({ name: nn.Sequential( nn.Linear(d, hidden), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden, hidden), nn.ReLU(), ) for name, d in modality_dims.items() }) concat_dim = hidden * len(modality_dims) self.lstm = nn.LSTM(concat_dim, hidden, num_layers=2, batch_first=True, bidirectional=True, dropout=0.2) self.head = nn.Linear(hidden * 2, num_classes) def forward(self, x, mask): # x: (B, T, F_total), slice by modality offset = 0 feats = [] for name in self.mod_names: d = self.mod_dims[name] x_m = x[..., offset:offset + d] offset += d feats.append(self.per_mod[name](x_m)) h = torch.cat(feats, dim=-1) # (B, T, hidden * M) h, _ = self.lstm(h) m = mask.unsqueeze(-1).float() pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) return self.head(pooled) # --------------------------------------------------------------------------- # 6) MulT (Multimodal Transformer, Tsai et al., ACL 2019) # Core idea: cross-modal attention between every pair of modalities. # For a 3-modality input (A, B, C), produce # {A->B, A->C, B->A, B->C, C->A, C->B} via directed cross-attention. # --------------------------------------------------------------------------- class CrossModalTransformer(nn.Module): def __init__(self, d_model, n_heads=4, n_layers=2, dropout=0.1): super().__init__() self.layers = nn.ModuleList([ nn.TransformerDecoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True, activation='gelu', ) for _ in range(n_layers) ]) def forward(self, q, kv, q_mask, kv_mask): # q: (B, T_q, D), kv: (B, T_kv, D) h = q for layer in self.layers: h = layer(h, kv, tgt_key_padding_mask=~q_mask, memory_key_padding_mask=~kv_mask) return h class MulT(nn.Module): """Multimodal Transformer. Uses MoCap + EMG + IMU as 3 modalities (EyeTrack/Pressure omitted to match original 3-mod paper design).""" def __init__(self, modality_dims: dict, num_classes, d_model=128, n_layers=2, n_heads=4, dropout=0.1): super().__init__() self.mod_names = [m for m in ['mocap', 'emg', 'imu'] if m in modality_dims] if len(self.mod_names) < 2: self.mod_names = list(modality_dims.keys())[:3] self.mod_dims = {m: modality_dims[m] for m in self.mod_names} self.in_proj = nn.ModuleDict({ m: nn.Linear(d, d_model) for m, d in self.mod_dims.items() }) # Pairwise cross-attention self.cross = nn.ModuleDict({ f"{a}_to_{b}": CrossModalTransformer(d_model, n_heads, n_layers, dropout) for a in self.mod_names for b in self.mod_names if a != b }) # Self-attention after cross self.self_tx = nn.ModuleDict({ m: nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True, activation='gelu', ), num_layers=1, ) for m in self.mod_names }) total_dim = d_model * len(self.mod_names) * len(self.mod_names) self.head = nn.Sequential( nn.LayerNorm(total_dim), nn.Dropout(dropout), nn.Linear(total_dim, num_classes), ) def forward(self, x, mask): # Slice modalities from x offset = 0 projs = {} # Walk through all known mod_dims to find offsets # We need the FULL modality_dims order, which we don't have here; # expect caller to already supply x with exactly mod_names in order. # Workaround: assume caller passes mod_names order matching projection. for m in self.mod_names: d = self.mod_dims[m] projs[m] = self.in_proj[m](x[..., offset:offset + d]) offset += d # Cross-attention: each modality attends to each other fused = {m: [] for m in self.mod_names} for a in self.mod_names: for b in self.mod_names: if a == b: fused[a].append(projs[a]) else: out = self.cross[f"{a}_to_{b}"](projs[a], projs[b], mask, mask) fused[a].append(out) # Self-attention + pool per modality pooled = [] for a in self.mod_names: # Concat all attended-to representations along feature dim cat = torch.cat(fused[a], dim=-1) # (B, T, D * M) # Actually re-project back to D per stream, then self-attn on stacked # Simplified: self-attention over concatenated, pool, flatten # Here we just pool each separately for i, rep in enumerate(fused[a]): rep = self.self_tx[a](rep) m = mask.unsqueeze(-1).float() p = (rep * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) pooled.append(p) h = torch.cat(pooled, dim=-1) return self.head(h) # --------------------------------------------------------------------------- # 7) Perceiver IO (Jaegle et al., ICML 2021) # Cross-attention from a fixed-size latent query set to all input tokens, # repeated for a few iterations. # --------------------------------------------------------------------------- class PerceiverBlock(nn.Module): def __init__(self, latent_dim, n_heads, dropout): super().__init__() self.ca = nn.MultiheadAttention( latent_dim, n_heads, dropout=dropout, batch_first=True, ) self.norm1 = nn.LayerNorm(latent_dim) self.sa = nn.TransformerEncoderLayer( d_model=latent_dim, nhead=n_heads, dim_feedforward=4 * latent_dim, dropout=dropout, batch_first=True, activation='gelu', ) def forward(self, latents, inputs, input_kpm): # Cross-attn: latents attend to inputs h, _ = self.ca(latents, inputs, inputs, key_padding_mask=input_kpm) latents = self.norm1(latents + h) # Self-attn on latents latents = self.sa(latents) return latents class PerceiverIO(nn.Module): """Perceiver with N learnable latent queries; supports any modality mix.""" def __init__(self, modality_dims: dict, num_classes, latent_dim=128, n_latents=32, n_layers=3, n_heads=4, dropout=0.1): super().__init__() self.mod_names = list(modality_dims.keys()) self.mod_dims = modality_dims # Per-modality input projection to latent_dim, with modality-id embedding self.in_proj = nn.ModuleDict({ m: nn.Linear(d, latent_dim) for m, d in modality_dims.items() }) self.mod_emb = nn.Parameter(torch.randn(len(self.mod_names), latent_dim) * 0.02) # Positional encoding (shared) self.pos = nn.Parameter(torch.zeros(1, 4096, latent_dim)) nn.init.trunc_normal_(self.pos, std=0.02) # Learnable latents self.latents = nn.Parameter(torch.randn(n_latents, latent_dim) * 0.02) self.blocks = nn.ModuleList([ PerceiverBlock(latent_dim, n_heads, dropout) for _ in range(n_layers) ]) self.head = nn.Sequential( nn.LayerNorm(latent_dim), nn.Linear(latent_dim, num_classes), ) def forward(self, x, mask): B, T, _ = x.shape # Project each modality + add modality embedding offset = 0 tokens = [] for i, m in enumerate(self.mod_names): d = self.mod_dims[m] tok = self.in_proj[m](x[..., offset:offset + d]) # (B, T, D) tok = tok + self.mod_emb[i] offset += d tokens.append(tok) # Concatenate along TIME dim, add shared pos enc per-modality # Each modality gets its own time sequence concatenated # Simpler: sum across modalities (like early fusion in latent space) + pos h = torch.stack(tokens, dim=2).mean(dim=2) # (B, T, D) h = h + self.pos[:, :T, :] input_kpm = ~mask # (B, T), True = ignore # Iterative cross-attention latents = self.latents.unsqueeze(0).expand(B, -1, -1) # (B, N, D) for blk in self.blocks: latents = blk(latents, h, input_kpm) # Mean-pool latents pooled = latents.mean(dim=1) return self.head(pooled)