""" Published baseline models for DailyAct-5M benchmark. ASFormer: Transformer for Action Segmentation (Yi et al., BMVC 2021) - Multi-stage encoder-decoder transformer with dilated attention - For temporal action segmentation (Exp 2) and contact detection (Exp 3) TinyHAR: Lightweight Deep Learning Model for HAR (Zhou et al., ISWC 2022 Best Paper) - Multi-scale temporal convolution + cross-channel attention + temporal pooling - Implemented as backbone in models.py for scene recognition (Exp 1) """ import math import torch import torch.nn as nn import torch.nn.functional as F # ============================================================ # Positional Encoding (shared) # ============================================================ class PositionalEncoding1D(nn.Module): """Sinusoidal positional encoding.""" def __init__(self, d_model, dropout=0.1, max_len=10000): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) ) pe[:, 0::2] = torch.sin(position * div_term) if d_model % 2 == 1: pe[:, 1::2] = torch.cos(position * div_term[:-1]) else: pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # (1, max_len, d_model) self.register_buffer('pe', pe) def forward(self, x): x = x + self.pe[:, :x.size(1)] return self.dropout(x) # ============================================================ # ASFormer (Yi et al., BMVC 2021) # ============================================================ class ConvFeedForward(nn.Module): """Position-wise convolution feed-forward used in ASFormer.""" def __init__(self, d_model, kernel_size=3, dropout=0.1): super().__init__() self.norm = nn.LayerNorm(d_model) self.conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size, padding=kernel_size // 2) self.conv2 = nn.Conv1d(d_model * 2, d_model, 1) self.dropout = nn.Dropout(dropout) def forward(self, x): # x: (B, T, D) residual = x x = self.norm(x) x = x.permute(0, 2, 1) # (B, D, T) x = self.dropout(F.relu(self.conv1(x))) x = self.dropout(self.conv2(x)) x = x.permute(0, 2, 1) # (B, T, D) return residual + x class DilatedAttention(nn.Module): """Multi-head self-attention with dilated temporal mask. At dilation d and window w, position t attends to positions {t + k*d : k in [-w, w]}, creating a hierarchical receptive field. 
""" def __init__(self, d_model, dilation, num_heads=1, dropout=0.1, window_size=5): super().__init__() self.d_model = d_model self.dilation = dilation self.window_size = window_size self.num_heads = num_heads self.head_dim = d_model // num_heads self.norm = nn.LayerNorm(d_model) self.qkv = nn.Linear(d_model, 3 * d_model) self.out_proj = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) # Cache for dilated masks self._mask_cache = {} def _get_dilated_mask(self, T, device): """Create or retrieve cached dilated attention mask.""" key = (T, self.dilation, self.window_size, device) if key not in self._mask_cache: positions = torch.arange(T, device=device) diff = positions.unsqueeze(1) - positions.unsqueeze(0) # (T, T) mask = torch.zeros(T, T, dtype=torch.bool, device=device) for w in range(-self.window_size, self.window_size + 1): mask |= (diff == w * self.dilation) self._mask_cache[key] = mask return self._mask_cache[key] def forward(self, x, cross_kv=None): # x: (B, T, D) B, T, D = x.shape residual = x x = self.norm(x) if cross_kv is not None: q = self.qkv(x)[:, :, :D] # only use Q from x kv = self.qkv(cross_kv)[:, :, D:] # K, V from cross_kv q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) k = kv[:, :, :D].view(B, T, self.num_heads, self.head_dim).transpose(1, 2) v = kv[:, :, D:].view(B, T, self.num_heads, self.head_dim).transpose(1, 2) else: qkv = self.qkv(x).view(B, T, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, T, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] scale = self.head_dim ** -0.5 attn = (q @ k.transpose(-2, -1)) * scale # (B, H, T, T) # Apply dilated attention mask dilated_mask = self._get_dilated_mask(T, x.device) # (T, T) attn = attn.masked_fill(~dilated_mask.unsqueeze(0).unsqueeze(0), float('-inf')) attn = F.softmax(attn, dim=-1) attn = self.dropout(attn) out = (attn @ v).transpose(1, 2).reshape(B, T, D) out = self.out_proj(out) return residual + self.dropout(out) class ASFormerEncoderBlock(nn.Module): """Single encoder block: dilated self-attention + conv feed-forward.""" def __init__(self, d_model, dilation, num_heads=1, kernel_size=3, dropout=0.1, window_size=5): super().__init__() self.self_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size) self.ffn = ConvFeedForward(d_model, kernel_size, dropout) def forward(self, x): x = self.self_attn(x) x = self.ffn(x) return x class ASFormerDecoderBlock(nn.Module): """Single decoder block: self-attention + cross-attention + conv feed-forward.""" def __init__(self, d_model, dilation, num_heads=1, kernel_size=3, dropout=0.1, window_size=5): super().__init__() self.self_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size) self.cross_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size) self.ffn = ConvFeedForward(d_model, kernel_size, dropout) def forward(self, x, enc_features): x = self.self_attn(x) x = self.cross_attn(x, cross_kv=enc_features) x = self.ffn(x) return x class ASFormerEncoder(nn.Module): """ASFormer encoder: projection + N dilated attention layers + output head.""" def __init__(self, input_dim, d_model, num_classes, num_layers=5, num_heads=1, kernel_size=3, dropout=0.1, window_size=5): super().__init__() self.input_proj = nn.Conv1d(input_dim, d_model, 1) self.pos_enc = PositionalEncoding1D(d_model, dropout) self.layers = nn.ModuleList([ ASFormerEncoderBlock(d_model, 2 ** i, num_heads, kernel_size, dropout, window_size) for i in range(num_layers) ]) self.output_proj = nn.Conv1d(d_model, 

class ASFormerEncoderBlock(nn.Module):
    """Single encoder block: dilated self-attention + conv feed-forward."""

    def __init__(self, d_model, dilation, num_heads=1, kernel_size=3,
                 dropout=0.1, window_size=5):
        super().__init__()
        self.self_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size)
        self.ffn = ConvFeedForward(d_model, kernel_size, dropout)

    def forward(self, x):
        x = self.self_attn(x)
        x = self.ffn(x)
        return x


class ASFormerDecoderBlock(nn.Module):
    """Single decoder block: self-attention + cross-attention + conv feed-forward."""

    def __init__(self, d_model, dilation, num_heads=1, kernel_size=3,
                 dropout=0.1, window_size=5):
        super().__init__()
        self.self_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size)
        self.cross_attn = DilatedAttention(d_model, dilation, num_heads, dropout, window_size)
        self.ffn = ConvFeedForward(d_model, kernel_size, dropout)

    def forward(self, x, enc_features):
        x = self.self_attn(x)
        x = self.cross_attn(x, cross_kv=enc_features)
        x = self.ffn(x)
        return x


class ASFormerEncoder(nn.Module):
    """ASFormer encoder: projection + N dilated attention layers + output head."""

    def __init__(self, input_dim, d_model, num_classes, num_layers=5, num_heads=1,
                 kernel_size=3, dropout=0.1, window_size=5):
        super().__init__()
        self.input_proj = nn.Conv1d(input_dim, d_model, 1)
        self.pos_enc = PositionalEncoding1D(d_model, dropout)
        self.layers = nn.ModuleList([
            ASFormerEncoderBlock(d_model, 2 ** i, num_heads, kernel_size, dropout, window_size)
            for i in range(num_layers)
        ])
        self.output_proj = nn.Conv1d(d_model, num_classes, 1)

    def forward(self, x):
        # x: (B, T, C)
        x = x.permute(0, 2, 1)  # (B, C, T)
        x = self.input_proj(x)  # (B, d_model, T)
        x = x.permute(0, 2, 1)  # (B, T, d_model)
        x = self.pos_enc(x)
        for layer in self.layers:
            x = layer(x)
        features = x
        logits = self.output_proj(x.permute(0, 2, 1)).permute(0, 2, 1)  # (B, T, num_classes)
        return features, logits


class ASFormerDecoder(nn.Module):
    """ASFormer decoder: refinement stage with cross-attention to encoder."""

    def __init__(self, input_dim, d_model, num_classes, num_layers=5, num_heads=1,
                 kernel_size=3, dropout=0.1, window_size=5):
        super().__init__()
        self.input_proj = nn.Conv1d(input_dim, d_model, 1)
        self.pos_enc = PositionalEncoding1D(d_model, dropout)
        self.layers = nn.ModuleList([
            ASFormerDecoderBlock(d_model, 2 ** i, num_heads, kernel_size, dropout, window_size)
            for i in range(num_layers)
        ])
        self.output_proj = nn.Conv1d(d_model, num_classes, 1)

    def forward(self, dec_input, enc_features):
        # dec_input: (B, T, input_dim), enc_features: (B, T, d_model)
        x = dec_input.permute(0, 2, 1)
        x = self.input_proj(x)
        x = x.permute(0, 2, 1)
        x = self.pos_enc(x)
        for layer in self.layers:
            x = layer(x, enc_features)
        logits = self.output_proj(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x, logits


class ASFormer(nn.Module):
    """ASFormer: Transformer for Action Segmentation (Yi et al., BMVC 2021).

    Multi-stage encoder-decoder transformer for frame-level action segmentation.
    Returns a list of per-stage logits for multi-stage training (same interface
    as MSTCN).

    Args:
        input_dim: Input feature dimension
        num_classes: Number of action classes
        hidden_dim: Hidden dimension (d_model)
        num_layers: Number of attention layers per stage
            (dilations 1, 2, ..., 2^(num_layers - 1))
        num_decoders: Number of decoder (refinement) stages
        num_heads: Number of attention heads
        kernel_size: Feed-forward convolution kernel size
        dropout: Dropout rate
        window_size: Dilated attention window size
    """

    def __init__(self, input_dim, num_classes, hidden_dim=64, num_layers=5,
                 num_decoders=3, num_heads=1, kernel_size=3, dropout=0.1,
                 window_size=5):
        super().__init__()
        self.encoder = ASFormerEncoder(
            input_dim, hidden_dim, num_classes, num_layers, num_heads,
            kernel_size, dropout, window_size
        )
        self.decoders = nn.ModuleList([
            ASFormerDecoder(
                num_classes, hidden_dim, num_classes, num_layers, num_heads,
                kernel_size, dropout, window_size
            )
            for _ in range(num_decoders)
        ])

    def forward(self, x):
        # x: (B, T, C)
        outputs = []
        enc_features, enc_logits = self.encoder(x)
        outputs.append(enc_logits)
        for decoder in self.decoders:
            dec_input = F.softmax(outputs[-1], dim=-1).detach()
            _, dec_logits = decoder(dec_input, enc_features)
            outputs.append(dec_logits)
        return outputs  # list of (B, T, num_classes), compatible with MSTCN interface


class ASFormerContact(nn.Module):
    """ASFormer adapted for binary contact detection (Exp 3).

    Wraps ASFormer to return only the final-stage output (B, T, 2), compatible
    with the exp3 training loop. The multi-stage refinement still runs
    internally, but only the last stage's logits are returned.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=5, num_decoders=2,
                 num_heads=1, dropout=0.1):
        super().__init__()
        self.asformer = ASFormer(
            input_dim, num_classes=2, hidden_dim=hidden_dim,
            num_layers=num_layers, num_decoders=num_decoders,
            num_heads=num_heads, dropout=dropout
        )

    def forward(self, x):
        # x: (B, T, C) -> (B, T, 2)
        outputs = self.asformer(x)
        return outputs[-1]  # Return final stage only
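
# Minimal smoke test (our addition; shapes and sizes below are arbitrary).
# Checks that the full model yields one logits tensor per stage and that the
# contact wrapper returns only the final stage.
if __name__ == "__main__":
    B, T, C, K = 2, 128, 30, 11  # batch, frames, input channels, classes

    model = ASFormer(input_dim=C, num_classes=K, num_decoders=3)
    stage_logits = model(torch.randn(B, T, C))
    assert len(stage_logits) == 4  # encoder stage + 3 decoder stages
    assert all(out.shape == (B, T, K) for out in stage_logits)

    contact = ASFormerContact(input_dim=C, num_decoders=2)
    out = contact(torch.randn(B, T, C))
    assert out.shape == (B, T, 2)
    print("ASFormer smoke test passed.")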