# PULSE-code/experiments/nets/models_seqpred.py
"""
Models for T10 Triplet Next-Action Prediction.
Classes in this file:
* TripletHead — shared head module producing (verb_fine, verb_composite,
                noun, hand) logits from a pooled feature vector.
* DeepConvLSTMTriplet — single-flow CNN+LSTM baseline (concatenates all
                available modalities along the feature axis).
* DailyActFormer — our full-modality cross-modal Transformer that keeps
                each modality in its own stem, fuses via a modality
                token, and runs a causal temporal Transformer. Supports
                the anticipatory auxiliary loss mentioned in the paper
                plan (currently as a stub; enabled later in training).
* Sensor-adapted published baselines (RULSTMTriplet, FUTRTriplet,
                AFFTTriplet, HandFormerTriplet, ActionLLMSurrogate) plus a
                `build_model` factory, defined further below.
All models take:
x: dict[mod_name -> (B, T, F_mod)]
mask: BoolTensor (B, T)
and return a dict:
{'verb_fine': (B, NUM_VERB_FINE),
'verb_composite': (B, NUM_VERB_COMPOSITE),
'noun': (B, NUM_NOUN),
'hand': (B, NUM_HAND)}
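
Usage sketch (mirrors the smoke test at the bottom of this file):

    model = build_model("dailyactformer", {"imu": 180, "emg": 8})
    out = model({"imu": torch.randn(2, 160, 180),
                 "emg": torch.randn(2, 160, 8)},
                mask=torch.ones(2, 160, dtype=torch.bool))
    # out["verb_fine"]: (2, NUM_VERB_FINE), etc.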
"""
from __future__ import annotations
import math
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
# Importable from either (a) neurips26 root, or (b) frozen row/code/ folder.
_THIS = Path(__file__).resolve()
sys.path.insert(0, str(_THIS.parent))
sys.path.insert(0, str(_THIS.parent.parent))
try:
from experiments.taxonomy import (
NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
)
except ModuleNotFoundError:
from taxonomy import (
NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
)
# ---------------------------------------------------------------------------
# Shared triplet head
# ---------------------------------------------------------------------------
class _PrevActionConcat(nn.Module):
"""Embeds the previous-segment (verb_composite, noun) ground-truth labels
and concatenates them to a pooled feature vector. Used by every model
when `use_prev_action=True`. The +1 vocab slot is the BOS / no-prev
sentinel emitted by the dataset for the first kept segment of each
recording. Output dim added to pooled = 2 * prev_emb_dim."""
def __init__(self, prev_emb_dim: int = 32):
        super().__init__()
        # Use the module-level taxonomy constants (resolved by the try/except
        # import above) so this works from either import root.
        self.vc_emb = nn.Embedding(NUM_VERB_COMPOSITE + 1, prev_emb_dim)
        self.n_emb = nn.Embedding(NUM_NOUN + 1, prev_emb_dim)
self.out_dim = 2 * prev_emb_dim
def forward(self, pooled: torch.Tensor,
prev_v_comp: Optional[torch.Tensor] = None,
prev_noun: Optional[torch.Tensor] = None) -> torch.Tensor:
if prev_v_comp is None or prev_noun is None:
B = pooled.size(0)
prev_v_comp = torch.full((B,), self.vc_emb.num_embeddings - 1,
dtype=torch.long, device=pooled.device)
prev_noun = torch.full((B,), self.n_emb.num_embeddings - 1,
dtype=torch.long, device=pooled.device)
pe = torch.cat([self.vc_emb(prev_v_comp), self.n_emb(prev_noun)], dim=-1)
return torch.cat([pooled, pe], dim=-1)
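# Shape sketch: with the default prev_emb_dim=32, a pooled (B, D) vector
# becomes (B, D + 64); missing labels fall back to the last embedding row,
# i.e. the BOS / no-prev sentinel.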
class TripletHead(nn.Module):
def __init__(self, feat_dim: int, hidden: int = 256, dropout: float = 0.2):
super().__init__()
self.norm = nn.LayerNorm(feat_dim)
self.trunk = nn.Sequential(
nn.Linear(feat_dim, hidden),
nn.GELU(),
nn.Dropout(dropout),
)
self.verb_fine = nn.Linear(hidden, NUM_VERB_FINE)
self.verb_composite = nn.Linear(hidden, NUM_VERB_COMPOSITE)
self.noun = nn.Linear(hidden, NUM_NOUN)
self.hand = nn.Linear(hidden, NUM_HAND)
def forward(self, feat: torch.Tensor) -> Dict[str, torch.Tensor]:
h = self.trunk(self.norm(feat))
return {
"verb_fine": self.verb_fine(h),
"verb_composite": self.verb_composite(h),
"noun": self.noun(h),
"hand": self.hand(h),
}
def _masked_mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Mean over the time axis of `h` (B, T, D) using a boolean mask (B, T)."""
m = mask.to(h.dtype).unsqueeze(-1)
return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
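# e.g. for h (B, T, D) and a row of `mask` with k valid frames, the mean is
# taken over those k frames only; an all-padding row divides by the clamp
# floor of 1.0 and yields a zero vector.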
# ---------------------------------------------------------------------------
# Baseline: DeepConvLSTM (Ordóñez & Roggen 2016) adapted for triplet prediction
# ---------------------------------------------------------------------------
class DeepConvLSTMTriplet(nn.Module):
"""Single-flow CNN+LSTM. Concatenates per-modality features on F axis."""
def __init__(
self,
modality_dims: Dict[str, int],
conv_filters: int = 64,
conv_kernel: int = 5,
num_conv_layers: int = 4,
lstm_hidden: int = 128,
num_lstm_layers: int = 2,
dropout: float = 0.2,
head_hidden: int = 256,
use_prev_action: bool = False,
prev_emb_dim: int = 32,
):
super().__init__()
self.modality_dims = dict(modality_dims)
self.use_prev_action = use_prev_action
in_ch = sum(modality_dims.values())
convs: List[nn.Module] = []
c = in_ch
for i in range(num_conv_layers):
convs.append(nn.Sequential(
nn.Conv1d(c, conv_filters, conv_kernel, padding=conv_kernel // 2),
nn.BatchNorm1d(conv_filters),
nn.ReLU(),
nn.Dropout(dropout if i < num_conv_layers - 1 else dropout + 0.1),
))
c = conv_filters
self.convs = nn.Sequential(*convs)
self.lstm = nn.LSTM(
conv_filters, lstm_hidden, num_layers=num_lstm_layers,
batch_first=True, bidirectional=False,
dropout=dropout if num_lstm_layers > 1 else 0.0,
)
head_in = lstm_hidden
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
head_in += self.prev_concat.out_dim
else:
self.prev_concat = None
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
def forward(
self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
prev_v_comp: Optional[torch.Tensor] = None,
prev_noun: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
        # Concatenate modalities in the fixed `modality_dims` order; relying on
        # the dict-iteration order of `x` could permute conv input channels.
        feats = torch.cat([x[m] for m in self.modality_dims if m in x],
                          dim=-1).transpose(1, 2)
        feats = self.convs(feats).transpose(1, 2)
        _, (h_n, _) = self.lstm(feats)
        pooled = h_n[-1]  # final hidden state; `mask` is unused in this baseline
if self.use_prev_action:
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
return self.head(pooled)
# ---------------------------------------------------------------------------
# Our model: DailyActFormer
# ---------------------------------------------------------------------------
class _ModalityStem(nn.Module):
"""Multi-scale 1-D conv stem (kernels 3, 5, 9) per modality.
Borrowed from HandFormer (the top-1 baseline on T10 recognition): three
parallel convolutions capture fast (k=3, ~0.15s @ 20Hz), medium (k=5),
and slow (k=9, ~0.45s) temporal patterns. Output is a 1×1 fusion of
the three branches, projected back to d_model.
"""
def __init__(self, in_dim: int, d_model: int, kernels=(3, 5, 9),
dropout: float = 0.1):
super().__init__()
self.kernels = kernels
self.branches = nn.ModuleList([
nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
])
self.merge = nn.Sequential(
nn.GELU(),
nn.Conv1d(d_model * len(kernels), d_model, 1),
)
self.norm = nn.LayerNorm(d_model)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, T, F_in) -> (B, F_in, T) for conv1d
z = x.transpose(1, 2)
multi = [c(z) for c in self.branches] # each (B, D, T)
h = self.merge(torch.cat(multi, dim=1)).transpose(1, 2) # (B, T, D)
return self.drop(self.norm(h))
class _QueryPool(nn.Module):
"""Learnable-query cross-attention pooling (replaces mean pool).
Inspired by FUTR (the top-5 baseline winner): a single learnable query
cross-attends to the entire encoder output, producing one summary vector.
Compared to a plain mean pool this lets the model weight informative
frames more heavily.
"""
def __init__(self, d_model: int, n_heads: int = 4, dropout: float = 0.1):
super().__init__()
self.q = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.q, std=0.02)
self.attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True,
)
self.norm = nn.LayerNorm(d_model)
def forward(self, h: torch.Tensor, key_padding_mask: Optional[torch.Tensor]):
# h: (B, T, D); key_padding_mask: (B, T) where True = pad-to-mask-out
B = h.size(0)
q = self.q.expand(B, -1, -1)
out, _ = self.attn(q, h, h, key_padding_mask=key_padding_mask,
need_weights=False)
return self.norm(out.squeeze(1))
class _CrossModalTemporalShift(nn.Module):
"""Cross-modal temporal-shift attention between two modalities.
    Motivation (paper case study, §sec:grasp-phase-main): EMG activation leads
    motion onset by ~20ms in our 100Hz recordings. After the 5x downsample to
    20Hz that lead is ~0.4 frames (sub-frame), but per-subject variability
    plus slack in our segment annotations introduces a few frames of drift
    that a fixed alignment cannot capture.
We learn a discrete temporal shift Δ ∈ {-max_shift, …, +max_shift} frames
applied to one of the two modalities (EMG by default), so the shifted
tokens align with the other branch (MoCap) before cross-modal fusion. The
shift is sampled via straight-through Gumbel-softmax during training; at
inference we take the argmax (deterministic).
Inputs are per-modality token sequences (B, T, D). Outputs the same shape.
Only the `shift_modality` branch is shifted; other modalities pass through.
"""
def __init__(self, max_shift: int = 3, tau: float = 1.0):
super().__init__()
self.max_shift = max_shift
self.tau = tau
# Logits over 2*max_shift+1 categorical shift candidates.
self.shift_logits = nn.Parameter(torch.zeros(2 * max_shift + 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, D). Pick one shift candidate: straight-through Gumbel-
        # softmax (hard one-hot, soft gradients) in training; argmax at eval.
if self.training:
w = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True, dim=-1)
else:
w = F.one_hot(self.shift_logits.argmax(),
num_classes=2 * self.max_shift + 1).float()
shifted = []
for i, s in enumerate(range(-self.max_shift, self.max_shift + 1)):
shifted.append(w[i] * torch.roll(x, shifts=s, dims=1))
return torch.stack(shifted, dim=0).sum(dim=0)
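# Note: torch.roll wraps around the time boundary, so up to `max_shift` frames
# leak across the window edge; with max_shift=3 on a T=160 window (as in the
# smoke test below) that touches under 2% of frames.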
class _CausalTransformerBlock(nn.Module):
"""Standard Transformer encoder block with a strictly causal attention mask."""
def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0,
dropout: float = 0.1):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
mlp_dim = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(mlp_dim, d_model), nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
h = self.norm1(x)
h, _ = self.attn(h, h, h, attn_mask=attn_mask,
key_padding_mask=key_padding_mask, need_weights=False)
x = x + h
x = x + self.mlp(self.norm2(x))
return x
class DailyActFormer(nn.Module):
"""Cross-modal Transformer that uses every available modality.
Architecture outline:
per-modality stem → learnable modality embedding →
concat across time (each frame -> M modality tokens) →
        single-layer cross-modal fusion attention (compresses M→1 per frame) →
temporal Transformer (bidirectional by default; causal when
`causal=True` for anticipation-style next-action prediction)
→ pooled → TripletHead
    For simplicity the fusion step is attention pooling with a single
    learnable query, rather than a full cross-modal block. This keeps the
    parameter count modest (2–4 M range with d_model=128).
"""
def __init__(
self,
modality_dims: Dict[str, int],
d_model: int = 128,
n_layers: int = 4,
n_heads: int = 4,
dropout: float = 0.1,
head_hidden: int = 256,
max_T: int = 256,
causal: bool = False,
xshift_modality: Optional[str] = "emg",
xshift_max: int = 3,
use_prev_action: bool = False,
prev_emb_dim: int = 32,
):
super().__init__()
self.modalities = list(modality_dims.keys())
self.causal = causal
self.use_prev_action = use_prev_action
# Prev-action concat (shared helper)
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
self._prev_extra_dim = self.prev_concat.out_dim
else:
self.prev_concat = None
self._prev_extra_dim = 0
# 0) Cross-modal temporal-shift block on one branch (EMG by default).
# Disabled if `xshift_modality` is None or not present.
if xshift_modality is not None and xshift_modality in modality_dims:
self.xshift_modality = xshift_modality
self.xshift = _CrossModalTemporalShift(max_shift=xshift_max)
else:
self.xshift_modality = None
self.xshift = None
# 1) per-modality 1-D conv stems (each produces d_model features/frame)
self.stems = nn.ModuleDict({
m: _ModalityStem(F, d_model, dropout=dropout)
for m, F in modality_dims.items()
})
# 2) modality embedding (broadcast-add to per-modality tokens)
self.modality_embed = nn.Parameter(
torch.zeros(len(self.modalities), d_model)
)
nn.init.trunc_normal_(self.modality_embed, std=0.02)
# 3) per-frame cross-modal fusion: use a single learnable query token
self.fusion_q = nn.Parameter(torch.zeros(1, 1, d_model))
self.fusion_kv = nn.LayerNorm(d_model)
self.fusion_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
# 4) positional embedding along time (post-fusion)
self.pos_embed = nn.Parameter(torch.zeros(1, max_T, d_model))
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.max_T = max_T
# 5) causal temporal Transformer
self.temporal_norm = nn.LayerNorm(d_model)
self.temporal = nn.ModuleList([
_CausalTransformerBlock(d_model, n_heads, dropout=dropout)
for _ in range(n_layers)
])
# 6) Pool: learnable-query cross-attention (replaces mean pool, FUTR-style)
self.pool = _QueryPool(d_model, n_heads=n_heads, dropout=dropout)
# 7) triplet head: input dim = d_model + (optional prev-action embed)
head_in = d_model + self._prev_extra_dim
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
nn.init.trunc_normal_(self.fusion_q, std=0.02)
# ---- helpers ----
def _causal_mask(self, T: int, device) -> torch.Tensor:
# MultiheadAttention wants additive mask with -inf above diag.
m = torch.full((T, T), float("-inf"), device=device)
m.triu_(diagonal=1)
return m
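    # e.g. T=3 yields the additive mask (row = query, col = key):
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]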
# ---- forward ----
def forward(
self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
prev_v_comp: Optional[torch.Tensor] = None,
prev_noun: Optional[torch.Tensor] = None,
return_features: bool = False,
) -> Dict[str, torch.Tensor]:
# Stems: per-modality token streams
stem_tokens: List[torch.Tensor] = []
mods_in = [m for m in self.modalities if m in x]
if not mods_in:
raise ValueError("No modality from the model signature was provided.")
        for m in mods_in:
h = self.stems[m](x[m]) # (B, T, D)
# Cross-modal temporal shift: apply to one branch (e.g. EMG) so it
            # aligns with the others before fusion. Implements the paper's
            # SyncFuse novelty (sub-frame anticipatory coupling of EMG/MoCap).
if self.xshift is not None and m == self.xshift_modality:
h = self.xshift(h)
h = h + self.modality_embed[self.modalities.index(m)]
stem_tokens.append(h)
# Cross-modal fusion: per-frame, attend learnable query over the M stacked
# modality tokens. Output is (B, T, D).
B, T, D = stem_tokens[0].shape
# stack -> (B, T, M, D) -> reshape as (B*T, M, D)
stacked = torch.stack(stem_tokens, dim=2) # (B, T, M, D)
M = stacked.size(2)
stacked = stacked.reshape(B * T, M, D)
kv = self.fusion_kv(stacked)
q = self.fusion_q.expand(B * T, -1, -1)
fused, _ = self.fusion_attn(q, kv, kv, need_weights=False)
fused = fused.reshape(B, T, D) # (B, T, D)
# Positional embedding + causal temporal Transformer
if T > self.max_T:
raise ValueError(f"T={T} exceeds max_T={self.max_T}")
h = fused + self.pos_embed[:, :T, :]
h = self.temporal_norm(h)
attn_mask = self._causal_mask(T, h.device) if self.causal else None
key_padding = ~mask if mask is not None else None
for block in self.temporal:
h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding)
# Pool: learnable-query cross-attention (FUTR-style) over valid frames
pooled = self.pool(h, key_padding_mask=key_padding)
# Optional: condition on previous segment's labels
if self.use_prev_action:
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
logits = self.head(pooled)
if return_features:
logits["_pooled"] = pooled
return logits
# ===========================================================================
# Published baselines, sensor-adapted. Each keeps the original paper's key
# idea (rolling+unrolling LSTM for RULSTM, causal encoder–decoder for FUTR,
# early modality-token fusion for AFFT, etc.) but swaps the RGB/feature input
# for our multimodal sensor streams, and the classification head for our
# shared TripletHead.
# ===========================================================================
# ---------------------------------------------------------------------------
# RULSTM (Furnari & Farinella, TPAMI 2020) — sensor-adapted
# Per-modality rolling LSTM summarises the past, a second unrolling LSTM
# takes R-LSTM state and walks `future_steps` steps forward to mimic
# anticipation without needing future sensor data. Fusion is late: each
# modality produces logits, we average them.
# ---------------------------------------------------------------------------
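# Shape sketch of one branch: the rolling LSTM consumes (B, T, F_in) and we
# keep its final hidden state (B, H); the LSTMCell is then stepped
# `future_steps` times on its own output to "walk" toward the target segment.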
class _RULSTMBranch(nn.Module):
def __init__(self, in_dim: int, hidden: int, future_steps: int,
dropout: float = 0.2):
super().__init__()
self.future_steps = future_steps
self.rolling = nn.LSTM(in_dim, hidden, batch_first=True)
self.unrolling = nn.LSTMCell(hidden, hidden)
self.drop = nn.Dropout(dropout)
self.out_dim = hidden
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# x: (B, T, F_in), mask: (B, T)
        # Pack-free: the LSTM runs over padded frames too, so h_n also sees
        # pad steps; an approximation we accept since we pool from h_n.
        _, (h_n, c_n) = self.rolling(x)  # (1, B, H)
        h = h_n.squeeze(0)
        c = c_n.squeeze(0)
inp = h
for _ in range(self.future_steps):
h, c = self.unrolling(inp, (h, c))
inp = h
return self.drop(h)
class RULSTMTriplet(nn.Module):
def __init__(self, modality_dims: Dict[str, int], hidden: int = 128,
future_steps: int = 8, dropout: float = 0.2,
head_hidden: int = 256,
use_prev_action: bool = False, prev_emb_dim: int = 32):
super().__init__()
self.use_prev_action = use_prev_action
self.branches = nn.ModuleDict({
m: _RULSTMBranch(F, hidden, future_steps, dropout)
for m, F in modality_dims.items()
})
head_in = hidden
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
head_in += self.prev_concat.out_dim
else:
self.prev_concat = None
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
feats = []
for m in x:
feats.append(self.branches[m](x[m], mask))
fused = torch.stack(feats, dim=0).mean(dim=0)
if self.use_prev_action:
fused = self.prev_concat(fused, prev_v_comp, prev_noun)
return self.head(fused)
# ---------------------------------------------------------------------------
# FUTR (Gong et al., CVPR 2022) — sensor-adapted
# Transformer encoder over observation frames (with per-frame feature from
# concat(modalities)). A decoder query attends over the encoder memory to
# produce a single future-action embedding which is fed into the triplet
# head. No autoregressive decoding — we only predict 1 target segment.
# ---------------------------------------------------------------------------
class FUTRTriplet(nn.Module):
def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
head_hidden: int = 256, max_T: int = 256,
use_prev_action: bool = False, prev_emb_dim: int = 32):
super().__init__()
self.use_prev_action = use_prev_action
in_dim = sum(modality_dims.values())
self.in_proj = nn.Linear(in_dim, d_model)
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
nn.init.trunc_normal_(self.pos, std=0.02)
self.max_T = max_T
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
dropout=dropout, batch_first=True, activation="gelu",
)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
self.future_q = nn.Parameter(torch.zeros(1, 1, d_model))
nn.init.trunc_normal_(self.future_q, std=0.02)
self.cross_attn = nn.MultiheadAttention(
d_model, n_heads, dropout=dropout, batch_first=True,
)
head_in = d_model
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
head_in += self.prev_concat.out_dim
else:
self.prev_concat = None
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
feats = torch.cat([x[m] for m in x], dim=-1)
B, T, _ = feats.shape
if T > self.max_T:
raise ValueError(f"T={T} exceeds FUTR max_T={self.max_T}")
h = self.in_proj(feats) + self.pos[:, :T, :]
h = self.encoder(h, src_key_padding_mask=~mask)
q = self.future_q.expand(B, -1, -1)
out, _ = self.cross_attn(q, h, h, key_padding_mask=~mask,
need_weights=False)
pooled = out.squeeze(1)
if self.use_prev_action:
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
return self.head(pooled)
# ---------------------------------------------------------------------------
# AFFT (Zhong et al., WACV 2023) — sensor-adapted
# Per-modality tokens (one per frame per modality) are concatenated into a
# long token sequence of length T*M and passed through an encoder with
# causal temporal attention so the model must anticipate strictly from the
# past. Fusion happens "anticipatively" inside the attention.
# ---------------------------------------------------------------------------
class AFFTTriplet(nn.Module):
def __init__(self, modality_dims: Dict[str, int], d_model: int = 96,
n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
head_hidden: int = 256, max_T: int = 256,
use_prev_action: bool = False, prev_emb_dim: int = 32):
super().__init__()
self.use_prev_action = use_prev_action
self.modalities = list(modality_dims.keys())
self.stems = nn.ModuleDict({
m: nn.Linear(F, d_model) for m, F in modality_dims.items()
})
self.mod_embed = nn.Parameter(
torch.zeros(len(self.modalities), d_model)
)
nn.init.trunc_normal_(self.mod_embed, std=0.02)
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
nn.init.trunc_normal_(self.pos, std=0.02)
self.max_T = max_T
self.d_model = d_model
self.blocks = nn.ModuleList([
_CausalTransformerBlock(d_model, n_heads, dropout=dropout)
for _ in range(n_layers)
])
head_in = d_model
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
head_in += self.prev_concat.out_dim
else:
self.prev_concat = None
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
def _expand_causal_mask(self, T: int, M: int, device) -> torch.Tensor:
# Token layout: [m0_t0, m1_t0, ..., mM_t0, m0_t1, ..., mM_t(T-1)]
# Token at (m, t) can attend to all (m', t') with t' <= t.
ts = torch.arange(T, device=device).unsqueeze(1).expand(-1, M).reshape(-1)
return ts[:, None] < ts[None, :] # True where future (mask out)
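    # e.g. T=2, M=2 gives ts = [0, 0, 1, 1] and the (4, 4) bool matrix
    #   [[F, F, T, T],
    #    [F, F, T, T],
    #    [F, F, F, F],
    #    [F, F, F, F]]
    # i.e. frame-0 tokens never attend to frame-1 tokens, in any modality.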
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
# Build per-frame token streams.
        mods = [m for m in self.modalities if m in x]
        if not mods:
            raise ValueError("No modality from the model signature was provided.")
        per_mod_tokens = []
        B, T, _ = x[mods[0]].shape
        for m in mods:
            h = self.stems[m](x[m]) + self.mod_embed[self.modalities.index(m)]
            per_mod_tokens.append(h)
stacked = torch.stack(per_mod_tokens, dim=2)
M = stacked.size(2)
tokens = stacked.reshape(B, T * M, self.d_model)
if T > self.max_T:
raise ValueError(f"T={T} exceeds AFFT max_T={self.max_T}")
pos_per_frame = self.pos[:, :T, :].unsqueeze(2).expand(-1, -1, M, -1)
tokens = tokens + pos_per_frame.reshape(1, T * M, self.d_model)
attn_mask = self._expand_causal_mask(T, M, tokens.device)
        attn_mask = torch.zeros(
            T * M, T * M, device=tokens.device,
        ).masked_fill_(attn_mask, float("-inf"))
kp = (~mask).unsqueeze(2).expand(-1, -1, M).reshape(B, T * M)
for blk in self.blocks:
tokens = blk(tokens, attn_mask=attn_mask, key_padding_mask=kp)
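        # Pool the M modality tokens of the final frame. Note this indexes the
        # last frame regardless of `mask`; with right-padded batches the slice
        # would cover pad tokens.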
last_slice = tokens[:, -M:, :]
pooled = last_slice.mean(dim=1)
if self.use_prev_action:
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
return self.head(pooled)
# ---------------------------------------------------------------------------
# HandFormer (Shamil et al., ECCV 2024) — sensor-adapted
# Originally on 3D hand poses. We feed it only the MoCap modality (which
# contains 10 fingertip joints). Multi-scale 1-D conv over time, followed
# by a Transformer. If MoCap is not in `modalities`, falls back to whatever
# is provided (but then it's no longer the paper's "pose-only" setup).
# ---------------------------------------------------------------------------
class HandFormerTriplet(nn.Module):
def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
n_heads: int = 4, n_layers: int = 3, kernels=(3, 5, 9),
dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256,
use_prev_action: bool = False, prev_emb_dim: int = 32):
super().__init__()
self.use_prev_action = use_prev_action
in_dim = sum(modality_dims.values())
self.multi_conv = nn.ModuleList([
nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
])
self.conv_merge = nn.Conv1d(d_model * len(kernels), d_model, 1)
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
nn.init.trunc_normal_(self.pos, std=0.02)
self.max_T = max_T
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
dropout=dropout, batch_first=True, activation="gelu",
)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
head_in = d_model
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
head_in += self.prev_concat.out_dim
else:
self.prev_concat = None
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
multi = [c(feats) for c in self.multi_conv]
h = self.conv_merge(torch.cat(multi, dim=1))
h = h.transpose(1, 2)
T = h.size(1)
if T > self.max_T:
raise ValueError(f"T={T} exceeds HandFormer max_T={self.max_T}")
h = h + self.pos[:, :T, :]
h = self.encoder(h, src_key_padding_mask=~mask)
pooled = _masked_mean_pool(h, mask)
if self.use_prev_action:
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
return self.head(pooled)
# ---------------------------------------------------------------------------
# Placeholder ActionLLM — a conv-stem sensor encoder + a 2-layer Transformer
# trained from scratch as a surrogate. The *full* LoRA+Qwen version lives in
# `train_pred.py` and can be wired in later if the surrogate is too weak.
# ---------------------------------------------------------------------------
class ActionLLMSurrogate(nn.Module):
def __init__(self, modality_dims: Dict[str, int], d_model: int = 192,
n_heads: int = 6, n_layers: int = 2, dropout: float = 0.1,
head_hidden: int = 256, max_T: int = 256,
use_prev_action: bool = False, prev_emb_dim: int = 32):
super().__init__()
self.use_prev_action = use_prev_action
in_dim = sum(modality_dims.values())
self.stem = nn.Sequential(
nn.Conv1d(in_dim, d_model, 5, padding=2),
nn.GELU(),
nn.Conv1d(d_model, d_model, 5, padding=2),
)
self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
nn.init.trunc_normal_(self.pos, std=0.02)
self.max_T = max_T
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
dropout=dropout, batch_first=True, activation="gelu",
)
self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
head_in = d_model
if use_prev_action:
self.prev_concat = _PrevActionConcat(prev_emb_dim)
head_in += self.prev_concat.out_dim
else:
self.prev_concat = None
self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)
def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
h = self.stem(feats).transpose(1, 2)
T = h.size(1)
if T > self.max_T:
raise ValueError(f"T={T} exceeds ActionLLM max_T={self.max_T}")
h = h + self.pos[:, :T, :]
h = self.encoder(h, src_key_padding_mask=~mask)
pooled = _masked_mean_pool(h, mask)
if self.use_prev_action:
pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
return self.head(pooled)
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def build_model(
name: str, modality_dims: Dict[str, int], **kwargs,
) -> nn.Module:
name = name.lower()
if name in ("deepconvlstm", "dcl"):
return DeepConvLSTMTriplet(modality_dims, **kwargs)
if name in ("dailyactformer", "ours", "daf"):
return DailyActFormer(modality_dims, **kwargs)
if name in ("rulstm",):
return RULSTMTriplet(modality_dims, **kwargs)
if name in ("futr",):
return FUTRTriplet(modality_dims, **kwargs)
if name in ("afft",):
return AFFTTriplet(modality_dims, **kwargs)
if name in ("handformer",):
return HandFormerTriplet(modality_dims, **kwargs)
if name in ("actionllm",):
return ActionLLMSurrogate(modality_dims, **kwargs)
raise ValueError(f"Unknown model: {name}")
# ---------------------------------------------------------------------------
# Smoke-test: build each model, run a random batch, check output shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
B, T = 2, 160
dims = {"imu": 180, "emg": 8, "eyetrack": 24}
x = {m: torch.randn(B, T, d) for m, d in dims.items()}
mask = torch.ones(B, T, dtype=torch.bool)
for name in ("deepconvlstm", "dailyactformer", "rulstm", "futr", "afft",
"handformer", "actionllm"):
model = build_model(name, dims)
n_params = sum(p.numel() for p in model.parameters())
out = model(x, mask)
print(f"{name:16s} params={n_params:>10,} shapes="
f"vf={tuple(out['verb_fine'].shape)} "
f"vc={tuple(out['verb_composite'].shape)} "
f"n={tuple(out['noun'].shape)} "
f"h={tuple(out['hand'].shape)}")