| """Online stateful inference wrapper for the causal GRU. |
| |
| Carries the GRU hidden state across steps so each prediction is O(1) work, |
| which is what makes the model fit the 1-core / 60-minute submission budget. |
| Used by both local stepwise evaluation and the packaged solution.py. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
|
|
| from src.models.sequence_models import CausalGRUForecaster |
|
|
|
|
class GRUStatefulPredictionModel:
    """Adapt a CausalGRUForecaster to the competition PredictionModel interface.

    The recurrent hidden state returned by ``model.step`` is kept between
    calls to :meth:`predict` and is dropped whenever the incoming data point
    belongs to a different sequence than the previous one.
    """

    def __init__(self, model: CausalGRUForecaster, n_features: int = 32):
        """Store the model (switched to eval mode) and start with no state.

        Args:
            model: Forecaster exposing ``eval()`` and ``step(state, h)``.
            n_features: Feature count; stored on the instance but not read
                by any method of this class.
        """
        self.model = model.eval()
        self.n_features = n_features
        # Sequence id of the most recent data point (None before the first call).
        self.current_seq: Optional[int] = None
        # Hidden state carried across steps (None means "start fresh").
        self.h: Optional[torch.Tensor] = None

    def reset(self):
        """Forget the current sequence id and the carried hidden state."""
        self.current_seq, self.h = None, None

    @torch.no_grad()
    def predict(self, data_point):
        """Advance the model by one step and optionally return its prediction.

        Args:
            data_point: Object providing ``seq_ix``, ``state`` and
                ``need_prediction`` attributes.

        Returns:
            A float32 NumPy array with the prediction when
            ``data_point.need_prediction`` is truthy, otherwise ``None``.
        """
        seq_ix = data_point.seq_ix
        if self.current_seq != seq_ix:
            # A new sequence begins: discard state left from the previous one.
            self.current_seq, self.h = seq_ix, None

        features = np.asarray(data_point.state, dtype=np.float32)
        # The step always runs so the hidden state advances even on data
        # points for which no prediction is requested.
        pred, self.h = self.model.step(torch.from_numpy(features), self.h)

        if data_point.need_prediction:
            return pred.detach().numpy().astype(np.float32)
        return None
|
|
|
|
def build_model_from_config(params: dict, model_type: str = "CausalGRUForecaster"):
    """Instantiate a forecaster from a hyperparameter mapping.

    Args:
        params: Hyperparameters; any missing key falls back to the default
            written out below for the selected architecture.
        model_type: ``"CausalTCN"`` selects the TCN; every other value
            builds a ``CausalGRUForecaster``.

    Returns:
        The newly constructed (untrained) model instance.
    """
    if model_type != "CausalTCN":
        recurrent_kwargs = {
            "n_features": params.get("n_features", 32),
            "d_model": params.get("d_model", 256),
            "n_layers": params.get("n_layers", 2),
            "dropout": params.get("dropout", 0.1),
            "head_hidden": params.get("head_hidden"),
            "rnn_type": params.get("rnn_type", "gru"),
        }
        return CausalGRUForecaster(**recurrent_kwargs)

    # Imported lazily so the TCN module is only loaded when it is requested.
    from src.models.tcn import CausalTCN

    tcn_kwargs = {
        "n_features": params.get("n_features", 32),
        "d_model": params.get("d_model", 192),
        "n_layers": params.get("n_layers", 6),
        "kernel_size": params.get("kernel_size", 3),
        "dropout": params.get("dropout", 0.1),
        "head_hidden": params.get("head_hidden"),
    }
    return CausalTCN(**tcn_kwargs)
|
|
|
|
def load_gru_checkpoint(path: str | Path, map_location: str = "cpu"):
    """Rebuild a forecaster (GRU/LSTM/TCN) from a checkpoint file.

    The checkpoint is read as a mapping with a required ``"model_state_dict"``
    entry and an optional ``"config" -> "model"`` section whose ``"type"`` and
    ``"params"`` entries choose and configure the architecture.

    Args:
        path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The reconstructed model with weights loaded, set to eval mode.
    """
    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
    # only use this on checkpoints from a trusted source.
    checkpoint = torch.load(str(path), map_location=map_location, weights_only=False)
    model_section = checkpoint.get("config", {}).get("model", {})
    forecaster = build_model_from_config(
        model_section.get("params", {}),
        model_section.get("type", "CausalGRUForecaster"),
    )
    forecaster.load_state_dict(checkpoint["model_state_dict"])
    forecaster.eval()
    return forecaster
|
|