| """Frame-level future forecasting models. |
| |
| Three baselines (all sharing the same forecast head signature): |
| - TransformerForecast (our DAF-style) |
| - FUTRForecast (Transformer encoder + parallel query decoder) |
| - DeepConvLSTMForecast (Ordoñez & Roggen 2016 wearable HAR backbone) |
| |
| All take a dict {mod: (B, T_obs, F_mod)} and output (B, T_fut, num_classes). |
| """ |
| from __future__ import annotations |
| from typing import Dict, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
class _PerModalityProj(nn.Module):
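    """Project each modality to d_model, add a learned per-modality
    embedding, and average into one (B, T_obs, d_model) sequence."""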
    def __init__(self, modality_dims: Dict[str, int], d_model: int):
        super().__init__()
        self.proj = nn.ModuleDict({
            m: nn.Linear(d, d_model) for m, d in modality_dims.items()
        })
        self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model))
        nn.init.trunc_normal_(self.mod_emb, std=0.02)
        self.mods = list(modality_dims.keys())

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        out = None
        for i, m in enumerate(self.mods):
            h = self.proj[m](x[m]) + self.mod_emb[i]
            out = h if out is None else out + h
        return out / len(self.mods)


class TransformerForecast(nn.Module):
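    """Our DAF-style baseline: a Transformer encoder over the observed
    window, with learned future queries cross-attending to its output."""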
    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        h = self.embed(x) + self.pos
        h = self.encoder(h)
        q = self.queries.expand(h.size(0), -1, -1)
        out, _ = self.cross_attn(q, h, h, need_weights=False)
        out = self.norm(out)
        return self.head(out)


class FUTRForecast(nn.Module):
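    """FUTR-style baseline: Transformer encoder plus a non-autoregressive
    decoder in which learned queries predict all future frames in parallel."""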
    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_enc: int = 2, n_dec: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_enc)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=n_dec)
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        memory = self.encoder(self.embed(x) + self.pos)
        q = self.queries.expand(memory.size(0), -1, -1)
        out = self.decoder(q, memory)
        return self.head(out)


class DeepConvLSTMForecast(nn.Module):
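    """DeepConvLSTM baseline (Ordoñez & Roggen 2016): channel-wise modality
    concatenation, four 1-D conv blocks, then an LSTM whose final hidden
    state is mapped to logits for all T_fut future frames at once."""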
    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, conv_filters: int = 64,
                 lstm_hidden: int = 128, n_lstm_layers: int = 2,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.mods = list(modality_dims.keys())
        in_ch = sum(modality_dims.values())

        layers = []
        ch = in_ch
        for i in range(4):
            layers.append(nn.Sequential(
                nn.Conv1d(ch, conv_filters, kernel_size=5, padding=2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                nn.Dropout(dropout if i < 3 else 0.2),
            ))
            ch = conv_filters
        self.convs = nn.ModuleList(layers)
        self.lstm = nn.LSTM(
            conv_filters, lstm_hidden, num_layers=n_lstm_layers,
            batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
        )
        self.head = nn.Linear(lstm_hidden, t_fut * num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        h = torch.cat([x[m] for m in self.mods], dim=-1)
        h = h.permute(0, 2, 1)
        for c in self.convs:
            h = c(h)
        h = h.permute(0, 2, 1)
        _, (h_n, _) = self.lstm(h)
        feat = h_n[-1]
        logits = self.head(feat).view(-1, self.t_fut, self.num_classes)
        return logits


class RULSTMForecast(nn.Module):
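    """RULSTM-style baseline (after Furnari & Farinella 2019): a rolling
    LSTM summarises the observed window and its final state initialises an
    unrolling LSTM that steps through the future horizon."""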
    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_lstm_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.rolling = nn.LSTM(
            d_model, d_model, num_layers=n_lstm_layers,
            batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
        )
        self.unrolling = nn.LSTM(
            d_model, d_model, num_layers=n_lstm_layers,
            batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
        )
        self.fut_init = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.fut_init, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        h_past = self.embed(x)
        _, (h_n, c_n) = self.rolling(h_past)
        B = h_past.size(0)

        fut_input = self.fut_init.expand(B, self.t_fut, -1)
        out, _ = self.unrolling(fut_input, (h_n, c_n))
        return self.head(out)


class AVTForecast(nn.Module):
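    """AVT-style baseline (after Girdhar & Grauman 2021): past embeddings
    and learned future tokens form one sequence for a causally masked
    Transformer encoder; logits are read from the future positions."""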
    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)
        seq_len = t_obs + t_fut
        self.pos = nn.Parameter(torch.zeros(1, seq_len, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.fut_tokens = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.fut_tokens, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        h_past = self.embed(x)
        B = h_past.size(0)
        h_fut = self.fut_tokens.expand(B, -1, -1)
        seq = torch.cat([h_past, h_fut], dim=1) + self.pos
        out = self.encoder(seq, mask=self.causal_mask)
        out_fut = out[:, self.t_obs:, :]
        return self.head(out_fut)


def build_forecast_model(name: str, modality_dims: Dict[str, int],
                         num_classes: int, t_obs: int, t_fut: int,
                         d_model: int = 128, dropout: float = 0.1) -> nn.Module:
    name = name.lower()
    if name in ("daf", "transformer"):
        return TransformerForecast(modality_dims, num_classes,
                                   t_obs=t_obs, t_fut=t_fut,
                                   d_model=d_model, dropout=dropout)
    if name == "futr":
        return FUTRForecast(modality_dims, num_classes,
                            t_obs=t_obs, t_fut=t_fut,
                            d_model=d_model, dropout=dropout)
    if name == "deepconvlstm":
        return DeepConvLSTMForecast(modality_dims, num_classes,
                                    t_obs=t_obs, t_fut=t_fut,
                                    dropout=dropout)
    if name in ("rulstm", "ru-lstm", "ru_lstm"):
        return RULSTMForecast(modality_dims, num_classes,
                              t_obs=t_obs, t_fut=t_fut,
                              d_model=d_model, dropout=dropout)
    if name == "avt":
        return AVTForecast(modality_dims, num_classes,
                           t_obs=t_obs, t_fut=t_fut,
                           d_model=d_model, dropout=dropout)
    raise ValueError(f"Unknown forecast model: {name!r}")
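

# Minimal shape smoke test -- a sketch, not part of the training pipeline.
# The modality names and sizes below ("imu": 6, "pose": 34) are illustrative
# placeholders; only the (B, T_obs, F_mod) -> (B, T_fut, num_classes)
# contract from the module docstring is assumed.
if __name__ == "__main__":
    dims = {"imu": 6, "pose": 34}
    x = {m: torch.randn(2, 8, d) for m, d in dims.items()}
    with torch.no_grad():
        for name in ("daf", "futr", "deepconvlstm", "rulstm", "avt"):
            model = build_forecast_model(name, dims, num_classes=5,
                                         t_obs=8, t_fut=4).eval()
            y = model(x)
            assert y.shape == (2, 4, 5), (name, tuple(y.shape))
            print(f"{name:>12} -> {tuple(y.shape)}")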