""" Models for T10 Triplet Next-Action Prediction. Two classes live here: * TripletHead — shared head module producing (verb_fine, verb_composite, noun, hand) logits from a pooled feature vector. * DeepConvLSTMTriplet — single-flow CNN+LSTM baseline (concatenates all available modalities along the feature axis). * DailyActFormer — our full-modality cross-modal Transformer that keeps each modality in its own stem, fuses via a modality token, and runs a causal temporal Transformer. Supports the anticipatory auxiliary loss mentioned in the paper plan (currently as a stub; enabled later in training). All models take: x: dict[mod_name -> (B, T, F_mod)] mask: BoolTensor (B, T) and return a dict: {'verb_fine': (B, NUM_VERB_FINE), 'verb_composite': (B, NUM_VERB_COMPOSITE), 'noun': (B, NUM_NOUN), 'hand': (B, NUM_HAND)} """ from __future__ import annotations import math import sys from pathlib import Path from typing import Dict, List, Optional, Sequence import torch import torch.nn as nn import torch.nn.functional as F # Importable from either (a) neurips26 root, or (b) frozen row/code/ folder. _THIS = Path(__file__).resolve() sys.path.insert(0, str(_THIS.parent)) sys.path.insert(0, str(_THIS.parent.parent)) try: from experiments.taxonomy import ( NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND, ) except ModuleNotFoundError: from taxonomy import ( NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND, ) # --------------------------------------------------------------------------- # Shared triplet head # --------------------------------------------------------------------------- class _PrevActionConcat(nn.Module): """Embeds the previous-segment (verb_composite, noun) ground-truth labels and concatenates them to a pooled feature vector. Used by every model when `use_prev_action=True`. The +1 vocab slot is the BOS / no-prev sentinel emitted by the dataset for the first kept segment of each recording. Output dim added to pooled = 2 * prev_emb_dim.""" def __init__(self, prev_emb_dim: int = 32): super().__init__() from taxonomy import NUM_VERB_COMPOSITE as _NVC, NUM_NOUN as _NN # noqa self.vc_emb = nn.Embedding(_NVC + 1, prev_emb_dim) self.n_emb = nn.Embedding(_NN + 1, prev_emb_dim) self.out_dim = 2 * prev_emb_dim def forward(self, pooled: torch.Tensor, prev_v_comp: Optional[torch.Tensor] = None, prev_noun: Optional[torch.Tensor] = None) -> torch.Tensor: if prev_v_comp is None or prev_noun is None: B = pooled.size(0) prev_v_comp = torch.full((B,), self.vc_emb.num_embeddings - 1, dtype=torch.long, device=pooled.device) prev_noun = torch.full((B,), self.n_emb.num_embeddings - 1, dtype=torch.long, device=pooled.device) pe = torch.cat([self.vc_emb(prev_v_comp), self.n_emb(prev_noun)], dim=-1) return torch.cat([pooled, pe], dim=-1) class TripletHead(nn.Module): def __init__(self, feat_dim: int, hidden: int = 256, dropout: float = 0.2): super().__init__() self.norm = nn.LayerNorm(feat_dim) self.trunk = nn.Sequential( nn.Linear(feat_dim, hidden), nn.GELU(), nn.Dropout(dropout), ) self.verb_fine = nn.Linear(hidden, NUM_VERB_FINE) self.verb_composite = nn.Linear(hidden, NUM_VERB_COMPOSITE) self.noun = nn.Linear(hidden, NUM_NOUN) self.hand = nn.Linear(hidden, NUM_HAND) def forward(self, feat: torch.Tensor) -> Dict[str, torch.Tensor]: h = self.trunk(self.norm(feat)) return { "verb_fine": self.verb_fine(h), "verb_composite": self.verb_composite(h), "noun": self.noun(h), "hand": self.hand(h), } def _masked_mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """Mean over the time axis of `h` (B, T, D) using a boolean mask (B, T).""" m = mask.to(h.dtype).unsqueeze(-1) return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0) # --------------------------------------------------------------------------- # Baseline: DeepConvLSTM (Ordonez & Roggen 2016) adapted for triplet prediction # --------------------------------------------------------------------------- class DeepConvLSTMTriplet(nn.Module): """Single-flow CNN+LSTM. Concatenates per-modality features on F axis.""" def __init__( self, modality_dims: Dict[str, int], conv_filters: int = 64, conv_kernel: int = 5, num_conv_layers: int = 4, lstm_hidden: int = 128, num_lstm_layers: int = 2, dropout: float = 0.2, head_hidden: int = 256, use_prev_action: bool = False, prev_emb_dim: int = 32, ): super().__init__() self.modality_dims = dict(modality_dims) self.use_prev_action = use_prev_action in_ch = sum(modality_dims.values()) convs: List[nn.Module] = [] c = in_ch for i in range(num_conv_layers): convs.append(nn.Sequential( nn.Conv1d(c, conv_filters, conv_kernel, padding=conv_kernel // 2), nn.BatchNorm1d(conv_filters), nn.ReLU(), nn.Dropout(dropout if i < num_conv_layers - 1 else dropout + 0.1), )) c = conv_filters self.convs = nn.Sequential(*convs) self.lstm = nn.LSTM( conv_filters, lstm_hidden, num_layers=num_lstm_layers, batch_first=True, bidirectional=False, dropout=dropout if num_lstm_layers > 1 else 0.0, ) head_in = lstm_hidden if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) head_in += self.prev_concat.out_dim else: self.prev_concat = None self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) def forward( self, x: Dict[str, torch.Tensor], mask: torch.Tensor, prev_v_comp: Optional[torch.Tensor] = None, prev_noun: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2) feats = self.convs(feats).transpose(1, 2) out, (h_n, _) = self.lstm(feats) pooled = h_n[-1] if self.use_prev_action: pooled = self.prev_concat(pooled, prev_v_comp, prev_noun) return self.head(pooled) # --------------------------------------------------------------------------- # Our model: DailyActFormer # --------------------------------------------------------------------------- class _ModalityStem(nn.Module): """Multi-scale 1-D conv stem (kernels 3, 5, 9) per modality. Borrowed from HandFormer (the top-1 baseline on T10 recognition): three parallel convolutions capture fast (k=3, ~0.15s @ 20Hz), medium (k=5), and slow (k=9, ~0.45s) temporal patterns. Output is a 1×1 fusion of the three branches, projected back to d_model. """ def __init__(self, in_dim: int, d_model: int, kernels=(3, 5, 9), dropout: float = 0.1): super().__init__() self.kernels = kernels self.branches = nn.ModuleList([ nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels ]) self.merge = nn.Sequential( nn.GELU(), nn.Conv1d(d_model * len(kernels), d_model, 1), ) self.norm = nn.LayerNorm(d_model) self.drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, T, F_in) -> (B, F_in, T) for conv1d z = x.transpose(1, 2) multi = [c(z) for c in self.branches] # each (B, D, T) h = self.merge(torch.cat(multi, dim=1)).transpose(1, 2) # (B, T, D) return self.drop(self.norm(h)) class _QueryPool(nn.Module): """Learnable-query cross-attention pooling (replaces mean pool). Inspired by FUTR (the top-5 baseline winner): a single learnable query cross-attends to the entire encoder output, producing one summary vector. Compared to a plain mean pool this lets the model weight informative frames more heavily. """ def __init__(self, d_model: int, n_heads: int = 4, dropout: float = 0.1): super().__init__() self.q = nn.Parameter(torch.zeros(1, 1, d_model)) nn.init.trunc_normal_(self.q, std=0.02) self.attn = nn.MultiheadAttention( d_model, n_heads, dropout=dropout, batch_first=True, ) self.norm = nn.LayerNorm(d_model) def forward(self, h: torch.Tensor, key_padding_mask: Optional[torch.Tensor]): # h: (B, T, D); key_padding_mask: (B, T) where True = pad-to-mask-out B = h.size(0) q = self.q.expand(B, -1, -1) out, _ = self.attn(q, h, h, key_padding_mask=key_padding_mask, need_weights=False) return self.norm(out.squeeze(1)) class _CrossModalTemporalShift(nn.Module): """Cross-modal temporal-shift attention between two modalities. Motivation (paper case study, §sec:grasp-phase-main): EMG activation leads motion onset by a sub-frame ~20ms in our 100Hz recordings. After the 5x downsample to 20Hz, that lag is ~0.4 frames, but per-subject variability plus slack in our segment annotations introduces a few frames of drift that a fixed alignment cannot capture. We learn a discrete temporal shift Δ ∈ {-max_shift, …, +max_shift} frames applied to one of the two modalities (EMG by default), so the shifted tokens align with the other branch (MoCap) before cross-modal fusion. The shift is sampled via straight-through Gumbel-softmax during training; at inference we take the argmax (deterministic). Inputs are per-modality token sequences (B, T, D). Outputs the same shape. Only the `shift_modality` branch is shifted; other modalities pass through. """ def __init__(self, max_shift: int = 3, tau: float = 1.0): super().__init__() self.max_shift = max_shift self.tau = tau # Logits over 2*max_shift+1 categorical shift candidates. self.shift_logits = nn.Parameter(torch.zeros(2 * max_shift + 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, T, D); produce a shifted version that's a soft-blend over # the shift dimension. Hard at inference, gumbel-softmax at training. if self.training: w = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True, dim=-1) else: w = F.one_hot(self.shift_logits.argmax(), num_classes=2 * self.max_shift + 1).float() shifted = [] for i, s in enumerate(range(-self.max_shift, self.max_shift + 1)): shifted.append(w[i] * torch.roll(x, shifts=s, dims=1)) return torch.stack(shifted, dim=0).sum(dim=0) class _CausalTransformerBlock(nn.Module): """Standard Transformer encoder block with a strictly causal attention mask.""" def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.1): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) mlp_dim = int(d_model * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(d_model, mlp_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_dim, d_model), nn.Dropout(dropout), ) def forward(self, x: torch.Tensor, attn_mask: torch.Tensor, key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor: h = self.norm1(x) h, _ = self.attn(h, h, h, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) x = x + h x = x + self.mlp(self.norm2(x)) return x class DailyActFormer(nn.Module): """Cross-modal Transformer that uses every available modality. Architecture outline: per-modality stem → learnable modality embedding → concat across time (each frame -> M modality tokens) → 1 fusion-layer cross-modal attention (compress M→1 per frame) → temporal Transformer (bidirectional by default; causal when `causal=True` for anticipation-style next-action prediction) → pooled → TripletHead For simplicity the fusion step is an attention pooling with learnable queries, rather than a full cross-modal block. This keeps the parameter count modest (2–4 M range with d_model=128). """ def __init__( self, modality_dims: Dict[str, int], d_model: int = 128, n_layers: int = 4, n_heads: int = 4, dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256, causal: bool = False, xshift_modality: Optional[str] = "emg", xshift_max: int = 3, use_prev_action: bool = False, prev_emb_dim: int = 32, ): super().__init__() self.modalities = list(modality_dims.keys()) self.causal = causal self.use_prev_action = use_prev_action # Prev-action concat (shared helper) if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) self._prev_extra_dim = self.prev_concat.out_dim else: self.prev_concat = None self._prev_extra_dim = 0 # 0) Cross-modal temporal-shift block on one branch (EMG by default). # Disabled if `xshift_modality` is None or not present. if xshift_modality is not None and xshift_modality in modality_dims: self.xshift_modality = xshift_modality self.xshift = _CrossModalTemporalShift(max_shift=xshift_max) else: self.xshift_modality = None self.xshift = None # 1) per-modality 1-D conv stems (each produces d_model features/frame) self.stems = nn.ModuleDict({ m: _ModalityStem(F, d_model, dropout=dropout) for m, F in modality_dims.items() }) # 2) modality embedding (broadcast-add to per-modality tokens) self.modality_embed = nn.Parameter( torch.zeros(len(self.modalities), d_model) ) nn.init.trunc_normal_(self.modality_embed, std=0.02) # 3) per-frame cross-modal fusion: use a single learnable query token self.fusion_q = nn.Parameter(torch.zeros(1, 1, d_model)) self.fusion_kv = nn.LayerNorm(d_model) self.fusion_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True) # 4) positional embedding along time (post-fusion) self.pos_embed = nn.Parameter(torch.zeros(1, max_T, d_model)) nn.init.trunc_normal_(self.pos_embed, std=0.02) self.max_T = max_T # 5) causal temporal Transformer self.temporal_norm = nn.LayerNorm(d_model) self.temporal = nn.ModuleList([ _CausalTransformerBlock(d_model, n_heads, dropout=dropout) for _ in range(n_layers) ]) # 6) Pool: learnable-query cross-attention (replaces mean pool, FUTR-style) self.pool = _QueryPool(d_model, n_heads=n_heads, dropout=dropout) # 7) triplet head: input dim = d_model + (optional prev-action embed) head_in = d_model + self._prev_extra_dim self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) nn.init.trunc_normal_(self.fusion_q, std=0.02) # ---- helpers ---- def _causal_mask(self, T: int, device) -> torch.Tensor: # MultiheadAttention wants additive mask with -inf above diag. m = torch.full((T, T), float("-inf"), device=device) m.triu_(diagonal=1) return m # ---- forward ---- def forward( self, x: Dict[str, torch.Tensor], mask: torch.Tensor, prev_v_comp: Optional[torch.Tensor] = None, prev_noun: Optional[torch.Tensor] = None, return_features: bool = False, ) -> Dict[str, torch.Tensor]: # Stems: per-modality token streams stem_tokens: List[torch.Tensor] = [] mods_in = [m for m in self.modalities if m in x] if not mods_in: raise ValueError("No modality from the model signature was provided.") for i, m in enumerate(mods_in): h = self.stems[m](x[m]) # (B, T, D) # Cross-modal temporal shift: apply to one branch (e.g. EMG) so it # aligns with the others before fusion. Implements paper SyncFuse's # main novelty (sub-frame anticipatory coupling between EMG/MoCap). if self.xshift is not None and m == self.xshift_modality: h = self.xshift(h) h = h + self.modality_embed[self.modalities.index(m)] stem_tokens.append(h) # Cross-modal fusion: per-frame, attend learnable query over the M stacked # modality tokens. Output is (B, T, D). B, T, D = stem_tokens[0].shape # stack -> (B, T, M, D) -> reshape as (B*T, M, D) stacked = torch.stack(stem_tokens, dim=2) # (B, T, M, D) M = stacked.size(2) stacked = stacked.reshape(B * T, M, D) kv = self.fusion_kv(stacked) q = self.fusion_q.expand(B * T, -1, -1) fused, _ = self.fusion_attn(q, kv, kv, need_weights=False) fused = fused.reshape(B, T, D) # (B, T, D) # Positional embedding + causal temporal Transformer if T > self.max_T: raise ValueError(f"T={T} exceeds max_T={self.max_T}") h = fused + self.pos_embed[:, :T, :] h = self.temporal_norm(h) attn_mask = self._causal_mask(T, h.device) if self.causal else None key_padding = ~mask if mask is not None else None for block in self.temporal: h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding) # Pool: learnable-query cross-attention (FUTR-style) over valid frames pooled = self.pool(h, key_padding_mask=key_padding) # Optional: condition on previous segment's labels if self.use_prev_action: pooled = self.prev_concat(pooled, prev_v_comp, prev_noun) logits = self.head(pooled) if return_features: logits["_pooled"] = pooled return logits # =========================================================================== # Published baselines, sensor-adapted. Each keeps the original paper's key # idea (rolling+unrolling LSTM for RULSTM, causal encoder–decoder for FUTR, # early modality-token fusion for AFFT, etc.) but swaps the RGB/feature input # for our multimodal sensor streams, and the classification head for our # shared TripletHead. # =========================================================================== # --------------------------------------------------------------------------- # RULSTM (Furnari & Farinella, TPAMI 2020) — sensor-adapted # Per-modality rolling LSTM summarises the past, a second unrolling LSTM # takes R-LSTM state and walks `future_steps` steps forward to mimic # anticipation without needing future sensor data. Fusion is late: each # modality produces logits, we average them. # --------------------------------------------------------------------------- class _RULSTMBranch(nn.Module): def __init__(self, in_dim: int, hidden: int, future_steps: int, dropout: float = 0.2): super().__init__() self.future_steps = future_steps self.rolling = nn.LSTM(in_dim, hidden, batch_first=True) self.unrolling = nn.LSTMCell(hidden, hidden) self.drop = nn.Dropout(dropout) self.out_dim = hidden def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # x: (B, T, F_in), mask: (B, T) # Pack-free: LSTM on padded sequences is fine since we pool from h_n. _, (h_n, c_n) = self.rolling(x) # (1, B, H) h = h_n.squeeze(0); c = c_n.squeeze(0) inp = h for _ in range(self.future_steps): h, c = self.unrolling(inp, (h, c)) inp = h return self.drop(h) class RULSTMTriplet(nn.Module): def __init__(self, modality_dims: Dict[str, int], hidden: int = 128, future_steps: int = 8, dropout: float = 0.2, head_hidden: int = 256, use_prev_action: bool = False, prev_emb_dim: int = 32): super().__init__() self.use_prev_action = use_prev_action self.branches = nn.ModuleDict({ m: _RULSTMBranch(F, hidden, future_steps, dropout) for m, F in modality_dims.items() }) head_in = hidden if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) head_in += self.prev_concat.out_dim else: self.prev_concat = None self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) def forward(self, x, mask, prev_v_comp=None, prev_noun=None): feats = [] for m in x: feats.append(self.branches[m](x[m], mask)) fused = torch.stack(feats, dim=0).mean(dim=0) if self.use_prev_action: fused = self.prev_concat(fused, prev_v_comp, prev_noun) return self.head(fused) # --------------------------------------------------------------------------- # FUTR (Gong et al., CVPR 2022) — sensor-adapted # Transformer encoder over observation frames (with per-frame feature from # concat(modalities)). A decoder query attends over the encoder memory to # produce a single future-action embedding which is fed into the triplet # head. No autoregressive decoding — we only predict 1 target segment. # --------------------------------------------------------------------------- class FUTRTriplet(nn.Module): def __init__(self, modality_dims: Dict[str, int], d_model: int = 128, n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256, use_prev_action: bool = False, prev_emb_dim: int = 32): super().__init__() self.use_prev_action = use_prev_action in_dim = sum(modality_dims.values()) self.in_proj = nn.Linear(in_dim, d_model) self.pos = nn.Parameter(torch.zeros(1, max_T, d_model)) nn.init.trunc_normal_(self.pos, std=0.02) self.max_T = max_T enc_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True, activation="gelu", ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) self.future_q = nn.Parameter(torch.zeros(1, 1, d_model)) nn.init.trunc_normal_(self.future_q, std=0.02) self.cross_attn = nn.MultiheadAttention( d_model, n_heads, dropout=dropout, batch_first=True, ) head_in = d_model if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) head_in += self.prev_concat.out_dim else: self.prev_concat = None self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) def forward(self, x, mask, prev_v_comp=None, prev_noun=None): feats = torch.cat([x[m] for m in x], dim=-1) B, T, _ = feats.shape if T > self.max_T: raise ValueError(f"T={T} exceeds FUTR max_T={self.max_T}") h = self.in_proj(feats) + self.pos[:, :T, :] h = self.encoder(h, src_key_padding_mask=~mask) q = self.future_q.expand(B, -1, -1) out, _ = self.cross_attn(q, h, h, key_padding_mask=~mask, need_weights=False) pooled = out.squeeze(1) if self.use_prev_action: pooled = self.prev_concat(pooled, prev_v_comp, prev_noun) return self.head(pooled) # --------------------------------------------------------------------------- # AFFT (Zhong et al., WACV 2023) — sensor-adapted # Per-modality tokens (one per frame per modality) are concatenated into a # long token sequence of length T*M and passed through an encoder with # causal temporal attention so the model must anticipate strictly from the # past. Fusion happens "anticipatively" inside the attention. # --------------------------------------------------------------------------- class AFFTTriplet(nn.Module): def __init__(self, modality_dims: Dict[str, int], d_model: int = 96, n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256, use_prev_action: bool = False, prev_emb_dim: int = 32): super().__init__() self.use_prev_action = use_prev_action self.modalities = list(modality_dims.keys()) self.stems = nn.ModuleDict({ m: nn.Linear(F, d_model) for m, F in modality_dims.items() }) self.mod_embed = nn.Parameter( torch.zeros(len(self.modalities), d_model) ) nn.init.trunc_normal_(self.mod_embed, std=0.02) self.pos = nn.Parameter(torch.zeros(1, max_T, d_model)) nn.init.trunc_normal_(self.pos, std=0.02) self.max_T = max_T self.d_model = d_model self.blocks = nn.ModuleList([ _CausalTransformerBlock(d_model, n_heads, dropout=dropout) for _ in range(n_layers) ]) head_in = d_model if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) head_in += self.prev_concat.out_dim else: self.prev_concat = None self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) def _expand_causal_mask(self, T: int, M: int, device) -> torch.Tensor: # Token layout: [m0_t0, m1_t0, ..., mM_t0, m0_t1, ..., mM_t(T-1)] # Token at (m, t) can attend to all (m', t') with t' <= t. ts = torch.arange(T, device=device).unsqueeze(1).expand(-1, M).reshape(-1) return ts[:, None] < ts[None, :] # True where future (mask out) def forward(self, x, mask, prev_v_comp=None, prev_noun=None): # Build per-frame token streams. mods = [m for m in self.modalities if m in x] per_mod_tokens = [] B, T, _ = x[mods[0]].shape for i, m in enumerate(mods): h = self.stems[m](x[m]) + self.mod_embed[self.modalities.index(m)] per_mod_tokens.append(h) stacked = torch.stack(per_mod_tokens, dim=2) M = stacked.size(2) tokens = stacked.reshape(B, T * M, self.d_model) if T > self.max_T: raise ValueError(f"T={T} exceeds AFFT max_T={self.max_T}") pos_per_frame = self.pos[:, :T, :].unsqueeze(2).expand(-1, -1, M, -1) tokens = tokens + pos_per_frame.reshape(1, T * M, self.d_model) attn_mask = self._expand_causal_mask(T, M, tokens.device) attn_mask = torch.where(attn_mask, torch.tensor(float("-inf"), device=tokens.device), torch.tensor(0.0, device=tokens.device)) kp = (~mask).unsqueeze(2).expand(-1, -1, M).reshape(B, T * M) for blk in self.blocks: tokens = blk(tokens, attn_mask=attn_mask, key_padding_mask=kp) last_slice = tokens[:, -M:, :] pooled = last_slice.mean(dim=1) if self.use_prev_action: pooled = self.prev_concat(pooled, prev_v_comp, prev_noun) return self.head(pooled) # --------------------------------------------------------------------------- # HandFormer (Shamil et al., ECCV 2024) — sensor-adapted # Originally on 3D hand poses. We feed it only the MoCap modality (which # contains 10 fingertip joints). Multi-scale 1-D conv over time, followed # by a Transformer. If MoCap is not in `modalities`, falls back to whatever # is provided (but then it's no longer the paper's "pose-only" setup). # --------------------------------------------------------------------------- class HandFormerTriplet(nn.Module): def __init__(self, modality_dims: Dict[str, int], d_model: int = 128, n_heads: int = 4, n_layers: int = 3, kernels=(3, 5, 9), dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256, use_prev_action: bool = False, prev_emb_dim: int = 32): super().__init__() self.use_prev_action = use_prev_action in_dim = sum(modality_dims.values()) self.multi_conv = nn.ModuleList([ nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels ]) self.conv_merge = nn.Conv1d(d_model * len(kernels), d_model, 1) self.pos = nn.Parameter(torch.zeros(1, max_T, d_model)) nn.init.trunc_normal_(self.pos, std=0.02) self.max_T = max_T enc_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True, activation="gelu", ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) head_in = d_model if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) head_in += self.prev_concat.out_dim else: self.prev_concat = None self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) def forward(self, x, mask, prev_v_comp=None, prev_noun=None): feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2) multi = [c(feats) for c in self.multi_conv] h = self.conv_merge(torch.cat(multi, dim=1)) h = h.transpose(1, 2) T = h.size(1) if T > self.max_T: raise ValueError(f"T={T} exceeds HandFormer max_T={self.max_T}") h = h + self.pos[:, :T, :] h = self.encoder(h, src_key_padding_mask=~mask) pooled = _masked_mean_pool(h, mask) if self.use_prev_action: pooled = self.prev_concat(pooled, prev_v_comp, prev_noun) return self.head(pooled) # --------------------------------------------------------------------------- # Placeholder ActionLLM — a conv-stem sensor encoder + a 2-layer Transformer # trained from scratch as a surrogate. The *full* LoRA+Qwen version lives in # `train_pred.py` and can be wired in later if the surrogate is too weak. # --------------------------------------------------------------------------- class ActionLLMSurrogate(nn.Module): def __init__(self, modality_dims: Dict[str, int], d_model: int = 192, n_heads: int = 6, n_layers: int = 2, dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256, use_prev_action: bool = False, prev_emb_dim: int = 32): super().__init__() self.use_prev_action = use_prev_action in_dim = sum(modality_dims.values()) self.stem = nn.Sequential( nn.Conv1d(in_dim, d_model, 5, padding=2), nn.GELU(), nn.Conv1d(d_model, d_model, 5, padding=2), ) self.pos = nn.Parameter(torch.zeros(1, max_T, d_model)) nn.init.trunc_normal_(self.pos, std=0.02) self.max_T = max_T enc_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True, activation="gelu", ) self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers) head_in = d_model if use_prev_action: self.prev_concat = _PrevActionConcat(prev_emb_dim) head_in += self.prev_concat.out_dim else: self.prev_concat = None self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout) def forward(self, x, mask, prev_v_comp=None, prev_noun=None): feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2) h = self.stem(feats).transpose(1, 2) T = h.size(1) if T > self.max_T: raise ValueError(f"T={T} exceeds ActionLLM max_T={self.max_T}") h = h + self.pos[:, :T, :] h = self.encoder(h, src_key_padding_mask=~mask) pooled = _masked_mean_pool(h, mask) if self.use_prev_action: pooled = self.prev_concat(pooled, prev_v_comp, prev_noun) return self.head(pooled) # --------------------------------------------------------------------------- # Factory # --------------------------------------------------------------------------- def build_model( name: str, modality_dims: Dict[str, int], **kwargs, ) -> nn.Module: name = name.lower() if name in ("deepconvlstm", "dcl"): return DeepConvLSTMTriplet(modality_dims, **kwargs) if name in ("dailyactformer", "ours", "daf"): return DailyActFormer(modality_dims, **kwargs) if name in ("rulstm",): return RULSTMTriplet(modality_dims, **kwargs) if name in ("futr",): return FUTRTriplet(modality_dims, **kwargs) if name in ("afft",): return AFFTTriplet(modality_dims, **kwargs) if name in ("handformer",): return HandFormerTriplet(modality_dims, **kwargs) if name in ("actionllm",): return ActionLLMSurrogate(modality_dims, **kwargs) raise ValueError(f"Unknown model: {name}") # --------------------------------------------------------------------------- # Smoke-test: build each model, run a random batch, check output shapes. # --------------------------------------------------------------------------- if __name__ == "__main__": B, T = 2, 160 dims = {"imu": 180, "emg": 8, "eyetrack": 24} x = {m: torch.randn(B, T, d) for m, d in dims.items()} mask = torch.ones(B, T, dtype=torch.bool) for name in ("deepconvlstm", "dailyactformer", "rulstm", "futr", "afft", "handformer", "actionllm"): model = build_model(name, dims) n_params = sum(p.numel() for p in model.parameters()) out = model(x, mask) print(f"{name:16s} params={n_params:>10,} shapes=" f"vf={tuple(out['verb_fine'].shape)} " f"vc={tuple(out['verb_composite'].shape)} " f"n={tuple(out['noun'].shape)} " f"h={tuple(out['hand'].shape)}")