"""Frame-level future forecasting models. Three baselines (all sharing the same forecast head signature): - TransformerForecast (our DAF-style) - FUTRForecast (Transformer encoder + parallel query decoder) - DeepConvLSTMForecast (OrdoƱez & Roggen 2016 wearable HAR backbone) All take a dict {mod: (B, T_obs, F_mod)} and output (B, T_fut, num_classes). """ from __future__ import annotations from typing import Dict, List import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Shared per-modality projection: each modality -> hidden dim d_model # --------------------------------------------------------------------------- class _PerModalityProj(nn.Module): def __init__(self, modality_dims: Dict[str, int], d_model: int): super().__init__() self.proj = nn.ModuleDict({ m: nn.Linear(d, d_model) for m, d in modality_dims.items() }) self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model)) nn.init.trunc_normal_(self.mod_emb, std=0.02) self.mods = list(modality_dims.keys()) def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: # Concatenate per-modality projections along time? Or sum? # We sum modality-projected features per time step (with modality # embedding broadcast). Equivalent to early-fusion at the d_model # space and is what a "modality-aware Transformer" typically uses. out = None for i, m in enumerate(self.mods): h = self.proj[m](x[m]) + self.mod_emb[i] out = h if out is None else out + h return out / len(self.mods) # (B, T_obs, d_model) # --------------------------------------------------------------------------- # 1. 
# ---------------------------------------------------------------------------
# 1. Transformer (DAF-style) forecast model
# ---------------------------------------------------------------------------
class TransformerForecast(nn.Module):
    """Encoder-only Transformer over the observed window, with learned
    future queries that cross-attend into the encoded history.

    Input:  dict {mod: (B, T_obs, F_mod)}
    Output: (B, T_fut, num_classes) per-future-frame logits.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        # NOTE: submodules are created in the same order as before so that
        # seeded initialization stays reproducible.
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Encode the observed window into a memory sequence.
        memory = self.encoder(self.embed(x) + self.pos)  # (B, T_obs, D)
        batch = memory.size(0)
        # One learned query per future frame reads from the history.
        fut_q = self.queries.expand(batch, -1, -1)       # (B, T_fut, D)
        attended, _ = self.cross_attn(fut_q, memory, memory,
                                      need_weights=False)
        return self.head(self.norm(attended))            # (B, T_fut, C)
# ---------------------------------------------------------------------------
# 2. FUTR-style forecast (Future Transformer, Gong et al. CVPR 2022)
#    Same encoder + parallel query decoder. We add a small Transformer
#    decoder so it's not literally identical to TransformerForecast.
# ---------------------------------------------------------------------------
class FUTRForecast(nn.Module):
    """Transformer encoder over the observed window followed by a parallel
    (non-autoregressive) Transformer decoder driven by learned future
    queries, in the spirit of FUTR.

    Input:  dict {mod: (B, T_obs, F_mod)}
    Output: (B, T_fut, num_classes) per-future-frame logits.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_enc: int = 2, n_dec: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        # NOTE: submodule creation order matches the original for seeded
        # initialization reproducibility.
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_enc)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=n_dec)
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Build decoder memory from the observed frames.
        embedded = self.embed(x) + self.pos
        memory = self.encoder(embedded)                      # (B, T_obs, D)
        # All future queries decode in parallel against the memory.
        fut_q = self.queries.expand(memory.size(0), -1, -1)  # (B, T_fut, D)
        decoded = self.decoder(fut_q, memory)
        return self.head(decoded)                            # (B, T_fut, C)
# ---------------------------------------------------------------------------
# 3. DeepConvLSTM-style forecast (Ordóñez & Roggen 2016)
# ---------------------------------------------------------------------------
class DeepConvLSTMForecast(nn.Module):
    """DeepConvLSTM backbone adapted to frame-level forecasting.

    Modalities are concatenated on the feature axis, passed through a
    4-layer temporal conv stack and an LSTM; the LSTM's final hidden
    state is mapped to all T_fut * num_classes logits in one shot.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, conv_filters: int = 64,
                 lstm_hidden: int = 128, n_lstm_layers: int = 2,
                 dropout: float = 0.1):
        """
        Args:
            modality_dims: mapping modality name -> feature dim F_mod.
            num_classes: number of output classes per future frame.
            t_obs: observed window length (frames).
            t_fut: forecast horizon (frames).
            conv_filters: channels of each temporal conv layer.
            lstm_hidden: LSTM hidden size.
            n_lstm_layers: number of stacked LSTM layers.
            dropout: dropout rate used in the conv stack and between
                stacked LSTM layers.
        """
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.mods = list(modality_dims.keys())
        in_ch = sum(modality_dims.values())
        # Same 4-layer conv stack as the original DeepConvLSTM.
        # FIX: the 4th layer previously used a hard-coded Dropout(0.2),
        # silently ignoring the configured `dropout`; apply it uniformly.
        layers = []
        ch = in_ch
        for _ in range(4):
            layers.append(nn.Sequential(
                nn.Conv1d(ch, conv_filters, kernel_size=5, padding=2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                nn.Dropout(dropout),
            ))
            ch = conv_filters
        self.convs = nn.ModuleList(layers)
        self.lstm = nn.LSTM(
            conv_filters, lstm_hidden, num_layers=n_lstm_layers,
            batch_first=True,
            # nn.LSTM warns if dropout is set with a single layer.
            dropout=dropout if n_lstm_layers > 1 else 0,
        )
        self.head = nn.Linear(lstm_hidden, t_fut * num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forecast logits for T_fut future frames.

        Args:
            x: dict mapping modality name -> (B, T_obs, F_mod) tensor.
        Returns:
            (B, T_fut, num_classes) logits.
        """
        h = torch.cat([x[m] for m in self.mods], dim=-1)  # (B, T_obs, F_total)
        h = h.permute(0, 2, 1)                            # (B, F_total, T_obs)
        for conv in self.convs:
            h = conv(h)
        h = h.permute(0, 2, 1)                            # (B, T_obs, conv_filters)
        # Only the final hidden state of the top LSTM layer is used.
        _, (h_n, _) = self.lstm(h)
        feat = h_n[-1]                                    # (B, lstm_hidden)
        return self.head(feat).view(-1, self.t_fut, self.num_classes)
# ---------------------------------------------------------------------------
# 4. RU-LSTM (Furnari et al. RAL 2019, "Rolling-Unrolling LSTM for action
#    anticipation"). Two-phase LSTM: a "rolling" phase encodes the past;
#    an "unrolling" phase decodes future steps from a learned start token
#    repeated T_fut times (simplified, non-autoregressive variant).
# ---------------------------------------------------------------------------
class RULSTMForecast(nn.Module):
    """Rolling-unrolling LSTM forecaster.

    The rolling LSTM summarizes the observed window; its final (h, c)
    state seeds the unrolling LSTM, which is stepped T_fut times on a
    learned start token to emit one logit vector per future frame.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_lstm_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        # NOTE: submodule creation order matches the original for seeded
        # initialization reproducibility.
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.rolling = nn.LSTM(
            d_model, d_model, num_layers=n_lstm_layers, batch_first=True,
            dropout=dropout if n_lstm_layers > 1 else 0,
        )
        self.unrolling = nn.LSTM(
            d_model, d_model, num_layers=n_lstm_layers, batch_first=True,
            dropout=dropout if n_lstm_layers > 1 else 0,
        )
        self.fut_init = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.fut_init, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        encoded = self.embed(x)                  # (B, T_obs, D)
        # Rolling phase: keep only the final (h_n, c_n) state.
        _, rolled_state = self.rolling(encoded)
        batch = encoded.size(0)
        # Unrolling phase: a single learned token, repeated T_fut times,
        # drives the decoder from the rolled state.
        start = self.fut_init.expand(batch, self.t_fut, -1)
        unrolled, _ = self.unrolling(start, rolled_state)
        return self.head(unrolled)               # (B, T_fut, C)
# --------------------------------------------------------------------------- class AVTForecast(nn.Module): def __init__(self, modality_dims: Dict[str, int], num_classes: int, t_obs: int, t_fut: int, d_model: int = 128, n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1): super().__init__() self.t_obs = t_obs self.t_fut = t_fut self.num_classes = num_classes self.embed = _PerModalityProj(modality_dims, d_model) seq_len = t_obs + t_fut self.pos = nn.Parameter(torch.zeros(1, seq_len, d_model)) nn.init.trunc_normal_(self.pos, std=0.02) layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model, dropout=dropout, batch_first=True, activation="gelu", ) self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers) self.fut_tokens = nn.Parameter(torch.zeros(1, t_fut, d_model)) nn.init.trunc_normal_(self.fut_tokens, std=0.02) self.head = nn.Linear(d_model, num_classes) # Causal mask over concatenated [past | future] sequence mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() self.register_buffer("causal_mask", mask) def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: h_past = self.embed(x) # (B, T_obs, D) B = h_past.size(0) h_fut = self.fut_tokens.expand(B, -1, -1) # (B, T_fut, D) seq = torch.cat([h_past, h_fut], dim=1) + self.pos out = self.encoder(seq, mask=self.causal_mask) out_fut = out[:, self.t_obs:, :] return self.head(out_fut) # (B, T_fut, C) # --------------------------------------------------------------------------- # Builder # --------------------------------------------------------------------------- def build_forecast_model(name: str, modality_dims: Dict[str, int], num_classes: int, t_obs: int, t_fut: int, d_model: int = 128, dropout: float = 0.1) -> nn.Module: name = name.lower() if name in ("daf", "transformer"): return TransformerForecast(modality_dims, num_classes, t_obs=t_obs, t_fut=t_fut, d_model=d_model, dropout=dropout) if name == "futr": return FUTRForecast(modality_dims, 
num_classes, t_obs=t_obs, t_fut=t_fut, d_model=d_model, dropout=dropout) if name == "deepconvlstm": return DeepConvLSTMForecast(modality_dims, num_classes, t_obs=t_obs, t_fut=t_fut, dropout=dropout) if name in ("rulstm", "ru-lstm", "ru_lstm"): return RULSTMForecast(modality_dims, num_classes, t_obs=t_obs, t_fut=t_fut, d_model=d_model, dropout=dropout) if name == "avt": return AVTForecast(modality_dims, num_classes, t_obs=t_obs, t_fut=t_fut, d_model=d_model, dropout=dropout) raise ValueError(f"Unknown forecast model: {name!r}")