| """Causal sequence models for the Wunder stepwise next-state task. |
| |
| Design constraints derived from the data analysis (see memory / report): |
| - Features are pre-whitened (~N(0,1), no heavy tails) so no internal |
| normalization is required; the model predicts the next *level* directly. |
| - Inference is online, one row at a time, on a single CPU core within a |
| 60-minute budget. A unidirectional GRU carries hidden state across steps, |
| giving O(1) work per prediction. |
| - Training uses full 1000-step sequences with back-prop-through-time and a |
| loss masked to the scored steps (100..998). Because a GRU is a pure |
| recurrence, the full-sequence forward and the stepwise stateful forward |
| produce identical outputs — so what we train is exactly what we score. |
| |
| `forward(x, h0)` returns per-step predictions of the *next* state: |
| preds[:, t] predicts x[:, t + 1]. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class CausalGRUForecaster(nn.Module):
    """Unidirectional recurrent model that predicts the next state at every step.

    Pipeline: linear input projection -> LayerNorm -> stacked GRU (or LSTM)
    -> MLP head (LayerNorm, Linear, GELU, Dropout, Linear).

    Args:
        n_features: input/output width (32).
        d_model: recurrent hidden size and embedding width.
        n_layers: stacked recurrent layers.
        dropout: dropout between recurrent layers (only applied when
            n_layers > 1) and inside the head.
        head_hidden: width of the MLP head hidden layer (defaults to d_model
            when None or 0).
        rnn_type: "lstm" (case-insensitive) selects ``nn.LSTM``; any other
            value selects ``nn.GRU``.

    Note:
        With ``rnn_type="lstm"`` the recurrent state passed to and returned
        from ``forward`` / ``step`` is the ``(h, c)`` tuple used by
        ``nn.LSTM`` rather than a single tensor.
    """

    def __init__(
        self,
        n_features: int = 32,
        d_model: int = 256,
        n_layers: int = 2,
        dropout: float = 0.1,
        head_hidden: Optional[int] = None,
        rnn_type: str = "gru",
    ):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.n_layers = int(n_layers)
        self.rnn_type = str(rnn_type).lower()
        head_hidden = int(head_hidden) if head_hidden else self.d_model

        # Build every layer from the int-coerced sizes above, not from the raw
        # arguments: sizes that arrive as floats/strings (e.g. from a JSON or
        # YAML config) would otherwise make the torch constructors raise.
        self.input_proj = nn.Linear(self.n_features, self.d_model)
        self.input_norm = nn.LayerNorm(self.d_model)

        # The attribute keeps the name ``gru`` even when an LSTM is selected so
        # that parameter names in existing checkpoints stay stable.
        rnn_cls = nn.LSTM if self.rnn_type == "lstm" else nn.GRU
        self.gru = rnn_cls(
            input_size=self.d_model,
            hidden_size=self.d_model,
            num_layers=self.n_layers,
            batch_first=True,
            # Inter-layer dropout is meaningless (and warns) for one layer.
            dropout=dropout if self.n_layers > 1 else 0.0,
        )

        self.head = nn.Sequential(
            nn.LayerNorm(self.d_model),
            nn.Linear(self.d_model, head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, self.n_features),
        )

    def forward(
        self,
        x: torch.Tensor,
        h0: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model over a whole sequence.

        Args:
            x: (batch, seq_len, n_features)
            h0: optional initial hidden state (n_layers, batch, d_model);
                for ``rnn_type="lstm"`` an ``(h0, c0)`` tuple of such tensors.
        Returns:
            preds: (batch, seq_len, n_features) where preds[:, t] predicts x[:, t+1]
            h_n: final hidden state (n_layers, batch, d_model); for
                ``rnn_type="lstm"`` the ``(h_n, c_n)`` tuple.
        """
        z = self.input_norm(self.input_proj(x))
        out, h_n = self.gru(z, h0)
        preds = self.head(out)
        return preds, h_n

    @torch.no_grad()
    def step(
        self,
        state: torch.Tensor,
        h: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single online step.

        Gradient tracking is disabled, but the module's train/eval mode is not
        changed: call ``.eval()`` beforehand so dropout is inactive and the
        stepwise outputs match the full-sequence ``forward``.

        Args:
            state: (n_features,) or (1, n_features) current state.
            h: previous hidden state (n_layers, 1, d_model) or None; an
                ``(h, c)`` tuple for ``rnn_type="lstm"``.
        Returns:
            pred: (n_features,) prediction of the next state (taken from
                batch row 0, last time step).
            h_new: updated hidden state, same structure as ``h``.
        """
        # Normalise the input to (batch=1, seq_len=1, n_features).
        if state.dim() == 1:
            state = state.view(1, 1, -1)
        elif state.dim() == 2:
            state = state.unsqueeze(1)
        preds, h_new = self.forward(state, h)
        return preds[0, -1], h_new
|
|
|
|
def build_gru(config: dict) -> CausalGRUForecaster:
    """Build a CausalGRUForecaster from a plain dict of hyper-parameters.

    Every recognised key is optional; a missing key falls back to the value
    listed below. Keys other than these are ignored.
    """
    # Hyper-parameter name -> fallback used when the key is absent.
    fallbacks = {
        "n_features": 32,
        "d_model": 256,
        "n_layers": 2,
        "dropout": 0.1,
        "head_hidden": None,
        "rnn_type": "gru",
    }
    kwargs = {key: config.get(key, default) for key, default in fallbacks.items()}
    return CausalGRUForecaster(**kwargs)
|
|
|
|
def count_parameters(model: nn.Module) -> int:
    """Return the total number of trainable scalar parameters in *model*.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|