# PULSE-code / experiments / nets / models_forecast.py
# Uploaded by velvet-pine-22 via huggingface_hub (revision b4b2877, verified).
"""Frame-level future forecasting models.
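Five baselines (all sharing the same forward signature):
- TransformerForecast (our DAF-style)
- FUTRForecast (Transformer encoder + parallel query decoder)
- DeepConvLSTMForecast (Ordóñez & Roggen 2016 wearable HAR backbone)
- RULSTMForecast (rolling/unrolling LSTM, Furnari & Farinella)
- AVTForecast (causal Transformer over [past | future], Girdhar & Grauman)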
All take a dict {mod: (B, T_obs, F_mod)} and output (B, T_fut, num_classes).
"""
from __future__ import annotations
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Shared per-modality projection: each modality -> hidden dim d_model
# ---------------------------------------------------------------------------
class _PerModalityProj(nn.Module):
def __init__(self, modality_dims: Dict[str, int], d_model: int):
super().__init__()
self.proj = nn.ModuleDict({
m: nn.Linear(d, d_model) for m, d in modality_dims.items()
})
self.mod_emb = nn.Parameter(torch.zeros(len(modality_dims), d_model))
nn.init.trunc_normal_(self.mod_emb, std=0.02)
self.mods = list(modality_dims.keys())
def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
# Concatenate per-modality projections along time? Or sum?
# We sum modality-projected features per time step (with modality
# embedding broadcast). Equivalent to early-fusion at the d_model
# space and is what a "modality-aware Transformer" typically uses.
out = None
for i, m in enumerate(self.mods):
h = self.proj[m](x[m]) + self.mod_emb[i]
out = h if out is None else out + h
return out / len(self.mods) # (B, T_obs, d_model)
# ---------------------------------------------------------------------------
# 1. Transformer (DAF-style) forecast model
# ---------------------------------------------------------------------------
class TransformerForecast(nn.Module):
    """DAF-style forecaster: Transformer encoder plus one cross-attention read-out.

    The fused past sequence gets a learned positional table and is encoded
    by a stack of self-attention layers. ``t_fut`` learned query vectors
    then cross-attend once to the encoded past. The result is layer-normed
    and mapped to class logits.

    Args:
        modality_dims: modality name -> per-frame feature size.
        num_classes: number of classes predicted per future frame.
        t_obs: observed length. The positional table has exactly ``t_obs``
            rows, so inputs must satisfy ``T_obs == t_obs``.
        t_fut: number of future frames (= number of learned queries).
        d_model: hidden size shared by all layers.
        n_heads: attention heads (must divide ``d_model``).
        n_layers: number of encoder layers.
        dropout: dropout used in the encoder and the cross-attention.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs, self.t_fut, self.num_classes = t_obs, t_fut, num_classes

        # Past-frame embedding + learned absolute positions.
        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)

        # The layer instance serves as the template for n_layers stacked layers.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=4 * d_model,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=n_layers,
        )

        # One learned query per future frame, read out by cross-attention.
        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map ``{mod: (B, T_obs, F_mod)}`` to logits ``(B, T_fut, num_classes)``."""
        past = self.encoder(self.embed(x) + self.pos)           # (B, T_obs, D)
        batch_queries = self.queries.expand(past.size(0), -1, -1)  # (B, T_fut, D)
        attended = self.cross_attn(
            batch_queries, past, past, need_weights=False
        )[0]
        return self.head(self.norm(attended))                   # (B, T_fut, C)
# ---------------------------------------------------------------------------
# 2. FUTR-style forecast (Future Transformer, Gong et al. CVPR 2022)
# Same encoder + parallel query decoder. We add a small Transformer
# decoder so it's not literally identical to TransformerForecast.
# ---------------------------------------------------------------------------
class FUTRForecast(nn.Module):
    """FUTR-style forecaster: Transformer encoder + parallel query decoder.

    The fused past (plus learned positions) is encoded into a memory
    sequence. ``t_fut`` learned queries are decoded against it by an
    ``nn.TransformerDecoder``. No target mask is passed, so the queries
    attend to each other freely and all future frames come out in one
    parallel pass.

    Args:
        modality_dims: modality name -> per-frame feature size.
        num_classes: number of classes predicted per future frame.
        t_obs: observed length; inputs must have ``T_obs == t_obs`` because
            the positional table has exactly that many rows.
        t_fut: number of future frames (= number of learned queries).
        d_model: hidden size.
        n_heads: attention heads (must divide ``d_model``).
        n_enc: number of encoder layers.
        n_dec: number of decoder layers.
        dropout: dropout inside encoder and decoder layers.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_enc: int = 2, n_dec: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        self.t_obs, self.t_fut, self.num_classes = t_obs, t_fut, num_classes

        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, t_obs, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)

        # Encoder and decoder layers share the same hyper-parameters.
        layer_cfg = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_cfg), num_layers=n_enc
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_cfg), num_layers=n_dec
        )

        self.queries = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.queries, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map ``{mod: (B, T_obs, F_mod)}`` to logits ``(B, T_fut, num_classes)``."""
        tokens = self.embed(x) + self.pos
        memory = self.encoder(tokens)                           # (B, T_obs, D)
        targets = self.queries.expand(memory.size(0), -1, -1)   # (B, T_fut, D)
        decoded = self.decoder(targets, memory)
        return self.head(decoded)                               # (B, T_fut, C)
# ---------------------------------------------------------------------------
# 3. DeepConvLSTM-style forecast
# ---------------------------------------------------------------------------
class DeepConvLSTMForecast(nn.Module):
    """DeepConvLSTM-style forecaster: conv stack -> LSTM -> one-shot linear head.

    All modalities are concatenated along the feature axis and passed
    through four temporal convolution stages, then a stacked LSTM. The last
    layer's final hidden state is mapped by one linear layer to the logits
    of all ``t_fut`` future frames at once.

    Args:
        modality_dims: modality name -> per-frame feature size. The first
            conv's input channels are their sum.
        num_classes: number of classes predicted per future frame.
        t_obs: stored for interface parity. The layers do not depend on it:
            kernel 5 with padding 2 preserves the time length, and the LSTM
            accepts any length.
        t_fut: number of future frames predicted.
        conv_filters: output channels of every conv stage.
        lstm_hidden: LSTM hidden size.
        n_lstm_layers: number of stacked LSTM layers.
        dropout: dropout after conv stages 1-3 and between LSTM layers.
        final_conv_dropout: dropout after the 4th conv stage. Defaults to
            0.2, the value previously hard-coded, so existing configurations
            behave the same. It can now be tuned or disabled independently.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, conv_filters: int = 64,
                 lstm_hidden: int = 128, n_lstm_layers: int = 2,
                 dropout: float = 0.1, final_conv_dropout: float = 0.2):
        super().__init__()
        self.t_obs = t_obs
        self.t_fut = t_fut
        self.num_classes = num_classes
        # forward() concatenates modalities in this order, which must match
        # the channel layout the first conv was sized for.
        self.mods = list(modality_dims.keys())
        in_ch = sum(modality_dims.values())
        # Four Conv1d(kernel 5) -> BatchNorm -> ReLU -> Dropout stages,
        # modeled on the DeepConvLSTM conv stack; padding=2 keeps the time
        # length unchanged.
        n_stages = 4
        layers = []
        ch = in_ch
        for i in range(n_stages):
            p = dropout if i < n_stages - 1 else final_conv_dropout
            layers.append(nn.Sequential(
                nn.Conv1d(ch, conv_filters, kernel_size=5, padding=2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                nn.Dropout(p),
            ))
            ch = conv_filters
        self.convs = nn.ModuleList(layers)
        self.lstm = nn.LSTM(
            conv_filters, lstm_hidden, num_layers=n_lstm_layers,
            batch_first=True, dropout=dropout if n_lstm_layers > 1 else 0,
        )
        # Predicts every future frame's logits from one vector, reshaped below.
        self.head = nn.Linear(lstm_hidden, t_fut * num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map ``{mod: (B, T, F_mod)}`` to logits ``(B, t_fut, num_classes)``."""
        h = torch.cat([x[m] for m in self.mods], dim=-1)  # (B, T_obs, F_total)
        h = h.permute(0, 2, 1)  # (B, F_total, T_obs): Conv1d is channels-first
        for c in self.convs:
            h = c(h)
        h = h.permute(0, 2, 1)  # (B, T_obs, conv_filters)
        _, (h_n, _) = self.lstm(h)
        feat = h_n[-1]  # final hidden state of the last LSTM layer: (B, lstm_hidden)
        # An explicit batch size keeps the reshape unambiguous (unlike -1).
        return self.head(feat).view(feat.size(0), self.t_fut, self.num_classes)
# ---------------------------------------------------------------------------
# 4. RU-LSTM (Furnari & Farinella, ICCV 2019 / TPAMI 2020, "Rolling-Unrolling
#    LSTMs for action anticipation"). Two-phase LSTM: a "rolling" LSTM encodes
#    the past, and an "unrolling" LSTM, initialized from the rolling state,
#    decodes the T_fut future steps.
# ---------------------------------------------------------------------------
class RULSTMForecast(nn.Module):
    """Rolling/unrolling LSTM forecaster.

    The "rolling" LSTM consumes the fused past sequence. Its final
    ``(h, c)`` state initializes the "unrolling" LSTM, which receives the
    same learned token at each of the ``t_fut`` steps. Its own predictions
    are not fed back, so decoding is not autoregressive. A linear head maps
    every unrolled step to class logits.

    Args:
        modality_dims: modality name -> per-frame feature size.
        num_classes: number of classes predicted per future frame.
        t_obs: stored for interface parity; not used by the layers.
        t_fut: number of unrolled future steps.
        d_model: hidden size of both LSTMs (and the fused embedding).
        n_lstm_layers: layers per LSTM. Both LSTMs share it, which is what
            lets the rolling state be handed over directly.
        dropout: inter-layer LSTM dropout (only applied when
            ``n_lstm_layers > 1``).
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_lstm_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs, self.t_fut, self.num_classes = t_obs, t_fut, num_classes
        self.embed = _PerModalityProj(modality_dims, d_model)

        inter_layer_dropout = dropout if n_lstm_layers > 1 else 0

        def make_lstm() -> nn.LSTM:
            # Identical shapes keep the (h_n, c_n) hand-over valid.
            return nn.LSTM(
                d_model, d_model, num_layers=n_lstm_layers,
                batch_first=True, dropout=inter_layer_dropout,
            )

        self.rolling = make_lstm()
        self.unrolling = make_lstm()

        # Learned input token, repeated for every future step.
        self.fut_init = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.fut_init, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map ``{mod: (B, T_obs, F_mod)}`` to logits ``(B, T_fut, num_classes)``."""
        past = self.embed(x)                                    # (B, T_obs, D)
        _, rolled_state = self.rolling(past)                    # (h_n, c_n)
        steps = self.fut_init.expand(past.size(0), self.t_fut, -1)
        unrolled, _ = self.unrolling(steps, rolled_state)       # (B, T_fut, D)
        return self.head(unrolled)                              # (B, T_fut, C)
# ---------------------------------------------------------------------------
# 5. AVT (Girdhar & Grauman ICCV 2021, "Anticipative Video Transformer").
# Causal Transformer over the concatenation of past + future tokens.
# ---------------------------------------------------------------------------
class AVTForecast(nn.Module):
    """AVT-style forecaster: causal Transformer over ``[past | future]`` tokens.

    ``t_fut`` learned future tokens are appended to the fused past sequence,
    and learned positions are added over the whole ``t_obs + t_fut``
    sequence. The sequence is encoded under a causal mask: a boolean upper
    triangle above the diagonal, so a position never attends to later ones.
    Only the outputs at the future positions are classified.

    Args:
        modality_dims: modality name -> per-frame feature size.
        num_classes: number of classes predicted per future frame.
        t_obs: observed length; inputs must have ``T_obs == t_obs`` so the
            concatenated sequence matches the positional table.
        t_fut: number of future tokens / predicted frames.
        d_model: hidden size.
        n_heads: attention heads (must divide ``d_model``).
        n_layers: number of encoder layers.
        dropout: dropout inside encoder layers.
    """

    def __init__(self, modality_dims: Dict[str, int], num_classes: int,
                 t_obs: int, t_fut: int, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.t_obs, self.t_fut, self.num_classes = t_obs, t_fut, num_classes
        total_len = t_obs + t_fut

        self.embed = _PerModalityProj(modality_dims, d_model)
        self.pos = nn.Parameter(torch.zeros(1, total_len, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=4 * d_model,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=n_layers,
        )

        self.fut_tokens = nn.Parameter(torch.zeros(1, t_fut, d_model))
        nn.init.trunc_normal_(self.fut_tokens, std=0.02)
        self.head = nn.Linear(d_model, num_classes)

        # True entries (strictly above the diagonal) are disallowed attention.
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(total_len, total_len), diagonal=1).bool(),
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Map ``{mod: (B, T_obs, F_mod)}`` to logits ``(B, T_fut, num_classes)``."""
        past = self.embed(x)                                    # (B, T_obs, D)
        future = self.fut_tokens.expand(past.size(0), -1, -1)   # (B, T_fut, D)
        tokens = torch.cat((past, future), dim=1) + self.pos
        encoded = self.encoder(tokens, mask=self.causal_mask)
        return self.head(encoded[:, self.t_obs:])               # (B, T_fut, C)
# ---------------------------------------------------------------------------
# Builder
# ---------------------------------------------------------------------------
def build_forecast_model(name: str, modality_dims: Dict[str, int],
                         num_classes: int, t_obs: int, t_fut: int,
                         d_model: int = 128, dropout: float = 0.1) -> nn.Module:
    """Instantiate a forecast model by (case-insensitive) name.

    Accepted names: ``daf``/``transformer``, ``futr``, ``deepconvlstm``,
    ``rulstm``/``ru-lstm``/``ru_lstm``, ``avt``. Every model receives
    ``t_obs``, ``t_fut`` and ``dropout``. All but DeepConvLSTM also receive
    ``d_model``; DeepConvLSTM sizes its layers from its own defaults.

    Raises:
        ValueError: for an unrecognized name.
    """
    name = name.lower()
    shared = dict(t_obs=t_obs, t_fut=t_fut, dropout=dropout)

    if name == "deepconvlstm":
        # No d_model parameter on this backbone.
        return DeepConvLSTMForecast(modality_dims, num_classes, **shared)

    by_name = {
        "daf": TransformerForecast,
        "transformer": TransformerForecast,
        "futr": FUTRForecast,
        "rulstm": RULSTMForecast,
        "ru-lstm": RULSTMForecast,
        "ru_lstm": RULSTMForecast,
        "avt": AVTForecast,
    }
    model_cls = by_name.get(name)
    if model_cls is None:
        raise ValueError(f"Unknown forecast model: {name!r}")
    return model_cls(modality_dims, num_classes, d_model=d_model, **shared)