"""
Published baselines for T1 Scene Recognition, reproduced on DailyAct-5M.
Each method accepts a concatenated feature tensor (B, T, F_total) where F_total
is the sum of the active modality dims; the per-modality slices are recorded in
the `modality_dims` dict. Each method then uses the subset of modalities its
original paper intended.
All methods output an (B, num_classes) logit tensor.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

def _slice(x, mod_dims, wanted):
    """Slice the concatenated feature tensor to keep only the `wanted`
    modalities. `mod_dims` is an ordered dict; slices are returned in its
    iteration order. Returns {name: tensor(B, T, d_name)}."""
    parts = {}
    offset = 0
    for name, d in mod_dims.items():
        if name in wanted:
            parts[name] = x[..., offset:offset + d]
        offset += d
    assert len(parts) > 0, f"None of {wanted} in {list(mod_dims.keys())}"
    return parts
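
# Hedged usage sketch for `_slice` (ours, for illustration): the modality
# names and dims below are assumptions, not the real DailyAct-5M layout.
def _demo_slice():
    mod_dims = {'mocap': 620, 'emg': 8}          # assumed (ordered) layout
    x = torch.randn(4, 100, 620 + 8)             # (B, T, F_total)
    parts = _slice(x, mod_dims, wanted={'emg'})  # keep the EMG slice only
    assert parts['emg'].shape == (4, 100, 8)
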
# ---------------------------------------------------------------------------
# 1) ST-GCN (Yan et al., AAAI 2018)
# Spatio-temporal graph CNN for skeleton action recognition.
# We treat the 52-joint MoCap skeleton as the graph.
# ---------------------------------------------------------------------------
class STGCNBlock(nn.Module):
    def __init__(self, in_ch, out_ch, n_joints, stride=1, dropout=0.2):
        super().__init__()
        # Spatial graph conv: learnable adjacency (fully learned, no handcrafted A)
        self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints))
        self.spatial = nn.Conv2d(in_ch, out_ch, kernel_size=(1, 1), bias=False)
        self.spatial_bn = nn.BatchNorm2d(out_ch)
        self.temporal = nn.Conv2d(out_ch, out_ch, kernel_size=(9, 1),
                                  padding=(4, 0), stride=(stride, 1))
        self.temporal_bn = nn.BatchNorm2d(out_ch)
        self.dropout = nn.Dropout(dropout)
        if in_ch != out_ch or stride != 1:
            self.res = nn.Conv2d(in_ch, out_ch, kernel_size=1,
                                 stride=(stride, 1))
        else:
            self.res = nn.Identity()

    def forward(self, x):
        # x: (B, C, T, V)
        res = self.res(x)
        # spatial: aggregate along joints via A
        h = self.spatial(x)
        h = torch.einsum('bctv,vw->bctw', h, F.softmax(self.A, dim=-1))
        h = self.spatial_bn(h)
        h = F.relu(h)
        # temporal
        h = self.temporal(h)
        h = self.temporal_bn(h)
        h = self.dropout(h)
        return F.relu(h + res)

class STGCN(nn.Module):
    """ST-GCN on the MoCap skeleton. We assume the MoCap modality is 620-dim
    (hip-relative positions + velocity features) and map it onto 52 joints."""

    def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52):
        super().__init__()
        self.n_joints = n_joints
        # MoCap feat is (T, 620): 52 markers * 4 cols (xyz + quat type) plus
        # velocity features. Rather than slicing out the 3 * 52 = 156 primary
        # coords, we learn a projection to n_joints * 3 per-joint coordinates.
        self.coord_dim = 3  # treat each joint as having 3 coords (XYZ)
        self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim)
        self.blocks = nn.ModuleList([
            STGCNBlock(self.coord_dim, hidden, n_joints),
            STGCNBlock(hidden, hidden, n_joints),
            STGCNBlock(hidden, hidden * 2, n_joints, stride=2),
            STGCNBlock(hidden * 2, hidden * 2, n_joints),
            STGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2),
            STGCNBlock(hidden * 4, hidden * 4, n_joints),
        ])
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden * 4, num_classes),
        )

    def forward(self, x_mocap, mask=None):
        # x_mocap: (B, T, feat_dim_mocap)
        B, T, _ = x_mocap.shape
        h = self.proj_in(x_mocap)  # (B, T, n_joints * 3)
        h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2)  # (B, C, T, V)
        for blk in self.blocks:
            h = blk(h)
        # Global mean pool over time & joints (with mask if provided)
        if mask is not None:
            # mask: (B, T); h: (B, C, T', V) where T' < T due to stride.
            # Truncating the mask to T' approximates the strided mask.
            T_ = h.shape[2]
            m = mask[:, :T_].float().unsqueeze(1).unsqueeze(-1)  # (B, 1, T', 1)
            # Denominator counts valid (t, v) positions: valid frames * V joints.
            h = (h * m).sum(dim=(2, 3)) / (m.sum(dim=(2, 3)) * h.shape[3] + 1e-8)
        else:
            h = h.mean(dim=(2, 3))
        return self.head(h)
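
# Minimal shape-check sketch for STGCN (ours, not from the original repo);
# the 620-dim MoCap feature and 18 scene classes are illustrative assumptions.
def _demo_stgcn():
    model = STGCN(feat_dim_mocap=620, num_classes=18)
    x = torch.randn(2, 64, 620)                 # (B, T, F_mocap)
    mask = torch.ones(2, 64, dtype=torch.bool)  # all frames valid
    assert model(x, mask).shape == (2, 18)
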
# ---------------------------------------------------------------------------
# 2) CTR-GCN (Chen et al., ICCV 2021)
# Channel-wise Topology Refinement GCN: learns a refined adjacency matrix
# per channel group; a state-of-the-art skeleton action recognition method.
# ---------------------------------------------------------------------------
class CTRGC(nn.Module):
    """Simplified CTR-GC block: learnable per-channel topology refinement."""

    def __init__(self, in_ch, out_ch, n_joints, rel_reduction=4):
        super().__init__()
        self.n_joints = n_joints
        self.conv1 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1)
        self.conv2 = nn.Conv2d(in_ch, out_ch // rel_reduction, 1)
        self.conv3 = nn.Conv2d(in_ch, out_ch, 1)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.A = nn.Parameter(torch.eye(n_joints) + 0.1 * torch.randn(n_joints, n_joints))

    def forward(self, x):
        # x: (B, C, T, V)
        q = self.conv1(x).mean(dim=2)  # (B, C', V)
        k = self.conv2(x).mean(dim=2)  # (B, C', V)
        v = self.conv3(x)              # (B, C_out, T, V)
        # Channel-specific topology refinement
        topology = F.softmax(torch.tanh(q.unsqueeze(-1) - k.unsqueeze(-2)), dim=-1)
        # topology: (B, C', V, V); we average across channels to get a shared (B, V, V)
        topology = topology.mean(dim=1)
        A = self.A.unsqueeze(0) + self.alpha * topology
        # apply A to v
        out = torch.einsum('bctv,bvw->bctw', v, A)
        return out

class CTRGCNBlock(nn.Module):
    def __init__(self, in_ch, out_ch, n_joints, stride=1):
        super().__init__()
        self.gc = CTRGC(in_ch, out_ch, n_joints)
        self.bn = nn.BatchNorm2d(out_ch)
        self.tcn = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, (9, 1), padding=(4, 0), stride=(stride, 1)),
            nn.BatchNorm2d(out_ch),
        )
        if in_ch != out_ch or stride != 1:
            self.res = nn.Conv2d(in_ch, out_ch, 1, stride=(stride, 1))
        else:
            self.res = nn.Identity()

    def forward(self, x):
        res = self.res(x)
        h = self.gc(x)
        h = self.bn(h)
        h = F.relu(h)
        h = self.tcn(h)
        return F.relu(h + res)

class CTRGCN(nn.Module):
    def __init__(self, feat_dim_mocap, num_classes, hidden=64, n_joints=52):
        super().__init__()
        self.n_joints = n_joints
        self.coord_dim = 3
        self.proj_in = nn.Linear(feat_dim_mocap, n_joints * self.coord_dim)
        self.blocks = nn.ModuleList([
            CTRGCNBlock(self.coord_dim, hidden, n_joints),
            CTRGCNBlock(hidden, hidden, n_joints),
            CTRGCNBlock(hidden, hidden * 2, n_joints, stride=2),
            CTRGCNBlock(hidden * 2, hidden * 4, n_joints, stride=2),
        ])
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(hidden * 4, num_classes),
        )

    def forward(self, x_mocap, mask=None):
        B, T, _ = x_mocap.shape
        h = self.proj_in(x_mocap)
        h = h.reshape(B, T, self.n_joints, self.coord_dim).permute(0, 3, 1, 2)
        for blk in self.blocks:
            h = blk(h)
        # Unmasked mean pool; `mask` is accepted only for a uniform baseline API
        h = h.mean(dim=(2, 3))
        return self.head(h)
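
# Shape-check sketch for CTRGCN under the same illustrative assumptions
# (620-dim MoCap feature, 18 classes) as the STGCN demo above.
def _demo_ctrgcn():
    model = CTRGCN(feat_dim_mocap=620, num_classes=18)
    x = torch.randn(2, 64, 620)
    assert model(x).shape == (2, 18)
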
# ---------------------------------------------------------------------------
# 3) LIMU-BERT (Xu et al., SenSys 2021)
# IMU self-supervised pretraining via masked reconstruction + fine-tuning.
# We implement a simplified variant: a BERT-style encoder trained with the
# supervised objective only (no masked-reconstruction pretraining head).
# ---------------------------------------------------------------------------
class LIMUBertEncoder(nn.Module):
    def __init__(self, feat_dim_imu, hidden=128, n_layers=4, n_heads=4, dropout=0.1):
        super().__init__()
        self.in_proj = nn.Linear(feat_dim_imu, hidden)
        self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
        nn.init.trunc_normal_(self.pos, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
            dropout=dropout, batch_first=True, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x, mask):
        # x: (B, T, F); mask: (B, T) bool, True = valid frame
        T = x.size(1)
        h = self.in_proj(x) + self.pos[:, :T, :]
        h = self.encoder(h, src_key_padding_mask=~mask)
        return h

class LIMUBert(nn.Module):
    """Supervised-only variant: encoder + classifier head. The paper's
    pretraining is a masked-reconstruction objective; for simplicity we
    report the supervised-only baseline here."""

    def __init__(self, feat_dim_imu, num_classes, hidden=128, n_layers=4,
                 n_heads=4, dropout=0.1):
        super().__init__()
        self.encoder = LIMUBertEncoder(feat_dim_imu, hidden, n_layers, n_heads, dropout)
        self.head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x_imu, mask):
        h = self.encoder(x_imu, mask)
        # Masked mean pool over time
        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)
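
# Shape-check sketch for the supervised LIMU-BERT variant; the 6-channel
# IMU feature (3-axis accel + gyro) and 18 classes are assumptions.
def _demo_limubert():
    model = LIMUBert(feat_dim_imu=6, num_classes=18)
    x = torch.randn(2, 128, 6)
    mask = torch.ones(2, 128, dtype=torch.bool)  # bool mask, True = valid
    assert model(x, mask).shape == (2, 18)
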
# ---------------------------------------------------------------------------
# 4) EMG-CNN (standard 1D CNN baseline from sEMG classification literature)
# E.g. Atzori et al. — multi-layer CNN with moving-window input.
# ---------------------------------------------------------------------------
class EMGCNN(nn.Module):
    def __init__(self, feat_dim_emg, num_classes, hidden=64):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(feat_dim_emg, hidden, 7, padding=3),
            nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(0.3),
            nn.Conv1d(hidden, hidden * 2, 5, padding=2),
            nn.BatchNorm1d(hidden * 2), nn.ReLU(), nn.Dropout(0.3),
            nn.Conv1d(hidden * 2, hidden * 4, 3, padding=1),
            nn.BatchNorm1d(hidden * 4), nn.ReLU(),
        )
        self.head = nn.Linear(hidden * 4, num_classes)

    def forward(self, x_emg, mask):
        # (B, T, C_emg) -> (B, C_emg, T) for Conv1d
        h = self.cnn(x_emg.transpose(1, 2))
        # Masked pool; resample the mask if the CNN ever changes T
        m = mask.unsqueeze(1).float()
        T_ = h.size(2)
        if m.size(2) != T_:
            m = F.adaptive_avg_pool1d(m, T_)
            m = (m > 0.5).float()
        pooled = (h * m).sum(dim=2) / m.sum(dim=2).clamp(min=1.0)
        return self.head(pooled)
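
# Shape-check sketch for EMGCNN; 8 EMG channels (e.g. one forearm armband)
# and 18 classes are illustrative assumptions.
def _demo_emgcnn():
    model = EMGCNN(feat_dim_emg=8, num_classes=18)
    x = torch.randn(2, 200, 8)
    mask = torch.ones(2, 200, dtype=torch.bool)
    assert model(x, mask).shape == (2, 18)
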
# ---------------------------------------------------------------------------
# 5) ActionSense baseline (DelPreto et al., NeurIPS 2022)
# Per-modality 2-layer MLP + shared bidirectional LSTM + classifier.
# ---------------------------------------------------------------------------
class ActionSenseLSTM(nn.Module):
    def __init__(self, modality_dims: dict, num_classes, hidden=128):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = modality_dims
        self.per_mod = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(d, hidden), nn.ReLU(), nn.Dropout(0.2),
                nn.Linear(hidden, hidden), nn.ReLU(),
            ) for name, d in modality_dims.items()
        })
        concat_dim = hidden * len(modality_dims)
        self.lstm = nn.LSTM(concat_dim, hidden, num_layers=2,
                            batch_first=True, bidirectional=True, dropout=0.2)
        self.head = nn.Linear(hidden * 2, num_classes)

    def forward(self, x, mask):
        # x: (B, T, F_total); slice by modality
        offset = 0
        feats = []
        for name in self.mod_names:
            d = self.mod_dims[name]
            x_m = x[..., offset:offset + d]
            offset += d
            feats.append(self.per_mod[name](x_m))
        h = torch.cat(feats, dim=-1)  # (B, T, hidden * M)
        h, _ = self.lstm(h)
        # Masked mean pool over time
        m = mask.unsqueeze(-1).float()
        pooled = (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)
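
# Shape-check sketch for the ActionSense-style baseline with two assumed
# modalities; dims are illustrative.
def _demo_actionsense():
    model = ActionSenseLSTM({'emg': 8, 'imu': 6}, num_classes=18)
    x = torch.randn(2, 64, 8 + 6)  # concat in modality_dims order
    mask = torch.ones(2, 64, dtype=torch.bool)
    assert model(x, mask).shape == (2, 18)
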
# ---------------------------------------------------------------------------
# 6) MulT (Multimodal Transformer, Tsai et al., ACL 2019)
# Core idea: cross-modal attention between every pair of modalities.
# For a 3-modality input (A, B, C), produce
# {A->B, A->C, B->A, B->C, C->A, C->B} via directed cross-attention.
# ---------------------------------------------------------------------------
class CrossModalTransformer(nn.Module):
    def __init__(self, d_model, n_heads=4, n_layers=2, dropout=0.1):
        super().__init__()
        # TransformerDecoderLayer = self-attn on q + cross-attn into kv + FFN
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
                dropout=dropout, batch_first=True, activation='gelu',
            ) for _ in range(n_layers)
        ])

    def forward(self, q, kv, q_mask, kv_mask):
        # q: (B, T_q, D), kv: (B, T_kv, D); masks are bool, True = valid
        h = q
        for layer in self.layers:
            h = layer(h, kv,
                      tgt_key_padding_mask=~q_mask,
                      memory_key_padding_mask=~kv_mask)
        return h

class MulT(nn.Module):
    """Multimodal Transformer. Uses MoCap + EMG + IMU as 3 modalities
    (EyeTrack/Pressure omitted to match the original 3-modality design)."""

    def __init__(self, modality_dims: dict, num_classes, d_model=128,
                 n_layers=2, n_heads=4, dropout=0.1):
        super().__init__()
        self.mod_names = [m for m in ['mocap', 'emg', 'imu'] if m in modality_dims]
        if len(self.mod_names) < 2:
            self.mod_names = list(modality_dims.keys())[:3]
        self.mod_dims = {m: modality_dims[m] for m in self.mod_names}
        self.in_proj = nn.ModuleDict({
            m: nn.Linear(d, d_model) for m, d in self.mod_dims.items()
        })
        # Directed pairwise cross-attention, one module per ordered pair
        self.cross = nn.ModuleDict({
            f"{a}_to_{b}": CrossModalTransformer(d_model, n_heads, n_layers, dropout)
            for a in self.mod_names for b in self.mod_names if a != b
        })
        # Self-attention after cross-attention, one encoder per target modality
        self.self_tx = nn.ModuleDict({
            m: nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=d_model, nhead=n_heads,
                    dim_feedforward=4 * d_model, dropout=dropout,
                    batch_first=True, activation='gelu',
                ), num_layers=1,
            ) for m in self.mod_names
        })
        # M target modalities x M streams each (self + cross), pooled and concatenated
        total_dim = d_model * len(self.mod_names) * len(self.mod_names)
        self.head = nn.Sequential(
            nn.LayerNorm(total_dim),
            nn.Dropout(dropout),
            nn.Linear(total_dim, num_classes),
        )

    def forward(self, x, mask):
        # Slice modalities from x. NOTE: we do not know the full modality_dims
        # order here, so the caller must pass x containing exactly
        # self.mod_names, concatenated in that order.
        offset = 0
        projs = {}
        for m in self.mod_names:
            d = self.mod_dims[m]
            projs[m] = self.in_proj[m](x[..., offset:offset + d])
            offset += d
        # Cross-attention: each modality attends to every other modality
        fused = {m: [] for m in self.mod_names}
        for a in self.mod_names:
            for b in self.mod_names:
                if a == b:
                    fused[a].append(projs[a])
                else:
                    out = self.cross[f"{a}_to_{b}"](projs[a], projs[b], mask, mask)
                    fused[a].append(out)
        # Self-attention + masked mean pool per stream. The original MulT
        # concatenates the attended streams per target modality before its
        # self-attention; we keep the streams separate and pool each one.
        pad = ~mask
        w = mask.unsqueeze(-1).float()
        pooled = []
        for a in self.mod_names:
            for rep in fused[a]:
                rep = self.self_tx[a](rep, src_key_padding_mask=pad)
                pooled.append((rep * w).sum(dim=1) / w.sum(dim=1).clamp(min=1.0))
        h = torch.cat(pooled, dim=-1)  # (B, d_model * M * M)
        return self.head(h)
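
# Shape-check sketch for MulT: x must contain exactly the three expected
# modalities concatenated in self.mod_names order (dims are assumptions).
def _demo_mult():
    dims = {'mocap': 620, 'emg': 8, 'imu': 6}
    model = MulT(dims, num_classes=18)
    x = torch.randn(2, 32, sum(dims.values()))
    mask = torch.ones(2, 32, dtype=torch.bool)
    assert model(x, mask).shape == (2, 18)
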
# ---------------------------------------------------------------------------
# 7) Perceiver (Jaegle et al., ICML 2021; the IO variant is ICLR 2022)
# Cross-attention from a fixed-size latent query set to all input tokens,
# repeated for a few iterations.
# ---------------------------------------------------------------------------
class PerceiverBlock(nn.Module):
    def __init__(self, latent_dim, n_heads, dropout):
        super().__init__()
        self.ca = nn.MultiheadAttention(
            latent_dim, n_heads, dropout=dropout, batch_first=True,
        )
        self.norm1 = nn.LayerNorm(latent_dim)
        self.sa = nn.TransformerEncoderLayer(
            d_model=latent_dim, nhead=n_heads,
            dim_feedforward=4 * latent_dim, dropout=dropout,
            batch_first=True, activation='gelu',
        )

    def forward(self, latents, inputs, input_kpm):
        # Cross-attn: latents attend to inputs
        h, _ = self.ca(latents, inputs, inputs, key_padding_mask=input_kpm)
        latents = self.norm1(latents + h)
        # Self-attn on latents
        latents = self.sa(latents)
        return latents

class PerceiverIO(nn.Module):
    """Perceiver with N learnable latent queries; supports any modality mix."""

    def __init__(self, modality_dims: dict, num_classes,
                 latent_dim=128, n_latents=32, n_layers=3, n_heads=4, dropout=0.1):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = modality_dims
        # Per-modality input projection to latent_dim, with modality-id embedding
        self.in_proj = nn.ModuleDict({
            m: nn.Linear(d, latent_dim) for m, d in modality_dims.items()
        })
        self.mod_emb = nn.Parameter(torch.randn(len(self.mod_names), latent_dim) * 0.02)
        # Positional encoding (shared across modalities)
        self.pos = nn.Parameter(torch.zeros(1, 4096, latent_dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        # Learnable latents
        self.latents = nn.Parameter(torch.randn(n_latents, latent_dim) * 0.02)
        self.blocks = nn.ModuleList([
            PerceiverBlock(latent_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        self.head = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes),
        )

    def forward(self, x, mask):
        B, T, _ = x.shape
        # Project each modality and add its modality-id embedding
        offset = 0
        tokens = []
        for i, m in enumerate(self.mod_names):
            d = self.mod_dims[m]
            tok = self.in_proj[m](x[..., offset:offset + d])  # (B, T, D)
            tok = tok + self.mod_emb[i]
            offset += d
            tokens.append(tok)
        # Early fusion in latent space: average the per-modality token streams
        # at each timestep, then add the shared positional encoding.
        # (Concatenating the streams along time would also work, but multiplies
        # the input length by the number of modalities.)
        h = torch.stack(tokens, dim=2).mean(dim=2)  # (B, T, D)
        h = h + self.pos[:, :T, :]
        input_kpm = ~mask  # (B, T), True = ignore
        # Iterative cross-attention from latents to the fused inputs
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)  # (B, N, D)
        for blk in self.blocks:
            latents = blk(latents, h, input_kpm)
        # Mean-pool latents
        pooled = latents.mean(dim=1)
        return self.head(pooled)
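
# Shape-check sketch for PerceiverIO over an arbitrary modality mix; the
# four modalities and their dims below are illustrative assumptions.
def _demo_perceiver():
    dims = {'mocap': 620, 'emg': 8, 'imu': 6, 'eyetrack': 4}
    model = PerceiverIO(dims, num_classes=18)
    x = torch.randn(2, 64, sum(dims.values()))
    mask = torch.ones(2, 64, dtype=torch.bool)
    assert model(x, mask).shape == (2, 18)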