# wunder-rnn-gru-ensemble / src/models/sequence_inference.py
# Commit e68eb1c (msrishav): Add inference code, config, and technical report
"""Online stateful inference wrapper for the causal GRU.
Carries the GRU hidden state across steps so each prediction is O(1) work,
which is what makes the model fit the 1-core / 60-minute submission budget.
Used by both local stepwise evaluation and the packaged solution.py.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from src.models.sequence_models import CausalGRUForecaster
class GRUStatefulPredictionModel:
    """Wrap a CausalGRUForecaster as a competition PredictionModel.

    The wrapper keeps the hidden state returned by ``model.step`` between
    calls and drops it whenever the incoming ``seq_ix`` differs from the one
    seen on the previous call, so each call feeds the model exactly one new
    observation.
    """

    def __init__(self, model: "CausalGRUForecaster", n_features: int = 32):
        """Put ``model`` in eval mode and start with no active sequence.

        Args:
            model: Forecaster exposing ``eval()`` and ``step(state, h)``.
            n_features: Number of input features; only stored on the instance.
        """
        self.model = model.eval()
        self.n_features = n_features
        # Resolved once so the per-step hot path does not walk the parameters.
        self._device = self._device_of(self.model)
        self.current_seq: Optional[int] = None
        self.h: Optional[torch.Tensor] = None

    @staticmethod
    def _device_of(model) -> torch.device:
        """Return the device of the model's first parameter (CPU if none)."""
        parameters = getattr(model, "parameters", None)
        first = next(iter(parameters()), None) if callable(parameters) else None
        return first.device if first is not None else torch.device("cpu")

    def reset(self):
        """Forget the active sequence and its hidden state."""
        self.current_seq = None
        self.h = None

    @torch.no_grad()
    def predict(self, data_point):
        """Advance the model by one observation and optionally return a forecast.

        Args:
            data_point: Object with ``seq_ix``, ``state`` and
                ``need_prediction`` attributes.

        Returns:
            A float32 NumPy array with the model output when
            ``data_point.need_prediction`` is truthy, otherwise ``None``.
        """
        # A new sequence id means the carried hidden state is stale.
        if self.current_seq != data_point.seq_ix:
            self.current_seq = data_point.seq_ix
            self.h = None

        state = torch.from_numpy(np.asarray(data_point.state, dtype=np.float32))
        # ``.to`` returns the same tensor when it already lives on the target
        # device, so the CPU path pays nothing for this.
        state = state.to(self._device)

        # The step always runs, even when no forecast is requested, so the
        # hidden state stays in sync with the observations seen so far.
        pred, self.h = self.model.step(state, self.h)
        if not data_point.need_prediction:
            return None

        # ``Tensor.numpy()`` only works on CPU tensors; ``.cpu()`` is a no-op
        # when the prediction is already there.
        return pred.detach().cpu().numpy().astype(np.float32)
def build_model_from_config(params: dict, model_type: str = "CausalGRUForecaster"):
    """Instantiate a forecaster from a hyper-parameter mapping.

    Args:
        params: Hyper-parameters; missing keys fall back to the defaults
            written below.
        model_type: ``"CausalTCN"`` selects the TCN; every other value builds
            a ``CausalGRUForecaster``.

    Returns:
        The newly constructed (untrained) model.
    """
    # Keyword arguments both architectures accept, with identical defaults.
    shared = {
        "n_features": params.get("n_features", 32),
        "dropout": params.get("dropout", 0.1),
        "head_hidden": params.get("head_hidden"),
    }

    if model_type != "CausalTCN":
        return CausalGRUForecaster(
            d_model=params.get("d_model", 256),
            n_layers=params.get("n_layers", 2),
            rnn_type=params.get("rnn_type", "gru"),
            **shared,
        )

    # Imported lazily so the TCN module is only loaded when it is requested.
    from src.models.tcn import CausalTCN

    return CausalTCN(
        d_model=params.get("d_model", 192),
        n_layers=params.get("n_layers", 6),
        kernel_size=params.get("kernel_size", 3),
        **shared,
    )
def load_gru_checkpoint(path: str | Path, map_location: str = "cpu"):
    """Reconstruct a forecaster (GRU/LSTM/TCN) from a saved checkpoint.

    The checkpoint is read as a mapping: the architecture comes from
    ``config -> model -> type`` / ``params`` (both optional) and the weights
    from ``model_state_dict`` (required).

    Args:
        path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The model with its weights loaded, switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # NOTE(review): weights_only=False performs a full unpickle, which can
    # execute arbitrary code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(str(path), map_location=map_location, weights_only=False)

    model_section = checkpoint.get("config", {}).get("model", {})
    architecture = model_section.get("type", "CausalGRUForecaster")
    hyperparams = model_section.get("params", {})

    net = build_model_from_config(hyperparams, architecture)
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    return net