# Ridge / model.py
# anonymous-4FAD's picture
# Upload 8 files
# c4135cc verified
"""Ridge multi-horizon CGM forecaster, packaged for the HF Hub.
One repo holds four feature ablations (``cgm``, ``insulin``, ``carbs``, ``all``)
as separate ``model_<ablation>.safetensors`` files. The active ablation is
selected at load time via the ``ablation=`` kwarg passed to
``AutoConfig.from_pretrained`` or ``AutoModel.from_pretrained``.
Usage::

    from transformers import AutoConfig, AutoModel
    cfg = AutoConfig.from_pretrained(
        "anonymous-4FAD/Ridge", trust_remote_code=True, ablation="cgm")
    model = AutoModel.from_pretrained(
        "anonymous-4FAD/Ridge", trust_remote_code=True, config=cfg)
    preds = model.predict(timestamps_ns, cgm, insulin, carbs)  # (B, 12)
"""
from __future__ import annotations
import math
import os
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import PretrainedConfig, PreTrainedModel
_HUB_DOWNLOAD_KWARGS = (
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"revision",
"subfolder",
"token",
)
class RidgeMultiHorizonConfig(PretrainedConfig):
    """Config for the multi-horizon Ridge forecaster.

    The same repo serves several feature ablations (by default ``cgm``,
    ``insulin``, ``carbs``, ``all``); the currently active one is
    ``self.ablation`` and must be one of ``self.ablations``.

    Args:
        ablation: Name of the active ablation.
        ablations: Every ablation the repo provides. Defaults to
            ``["cgm", "insulin", "carbs", "all"]``.
        history_length: Number of trailing input steps used for features.
        horizon_length: Number of outputs predicted per sample.
        feature_names_by_ablation: ``{ablation: [feature names, in order]}``.
        n_features_by_ablation: Optional ``{ablation: feature count}``. When it
            has an entry for the active ablation, that count takes precedence
            over ``len(feature_names_by_ablation[ablation])``.
        target_names: Names of the predicted targets.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``ablation`` is not in ``ablations``.
    """

    model_type = "ridge_multihorizon"

    def __init__(
        self,
        ablation: str = "all",
        ablations: Optional[list] = None,
        history_length: int = 24,
        horizon_length: int = 12,
        feature_names_by_ablation: Optional[dict] = None,
        n_features_by_ablation: Optional[dict] = None,
        target_names: Optional[list] = None,
        **kwargs,
    ):
        if ablations is None:
            ablations = ["cgm", "insulin", "carbs", "all"]
        if ablation not in ablations:
            raise ValueError(
                f"ablation must be one of {ablations}, got {ablation!r}"
            )
        self.ablation = ablation
        self.ablations = list(ablations)
        self.history_length = int(history_length)
        self.horizon_length = int(horizon_length)
        # Copy so that later mutation of the caller's dicts cannot silently
        # change this config.
        self.feature_names_by_ablation = dict(feature_names_by_ablation or {})
        self.n_features_by_ablation = dict(n_features_by_ablation or {})
        self.target_names = list(target_names or [])
        super().__init__(**kwargs)

    def _active_ablation(self) -> str:
        """Return ``self.ablation`` after re-checking it against ``self.ablations``.

        ``ablation`` is a plain attribute, so it can be reassigned after
        construction without going through the check in ``__init__``. Without
        this re-check an unknown name would only surface as an opaque
        ``KeyError`` from the per-ablation dicts.

        Raises:
            ValueError: If the current ``ablation`` is not in ``ablations``.
        """
        if self.ablation not in self.ablations:
            raise ValueError(
                f"ablation must be one of {self.ablations}, got {self.ablation!r}"
            )
        return self.ablation

    @property
    def n_features(self) -> int:
        """Number of input features for the active ablation."""
        ablation = self._active_ablation()
        if ablation in self.n_features_by_ablation:
            return int(self.n_features_by_ablation[ablation])
        return len(self.feature_names_by_ablation[ablation])

    @property
    def feature_names(self) -> list:
        """Ordered feature names for the active ablation (a fresh list)."""
        return list(self.feature_names_by_ablation[self._active_ablation()])
class RidgeMultiHorizonModel(PreTrainedModel):
    """Multi-output Ridge regressor over standardized tabular features.

    Holds only buffers (``scaler_mean``, ``scaler_scale``, ``coef``,
    ``intercept``); there are no trainable parameters. Predictions are
    ``((x - scaler_mean) / scaler_scale) @ coef.T + intercept``.
    """

    config_class = RidgeMultiHorizonConfig
    main_input_name = "features"
    _tied_weights_keys: Optional[dict] = None
    _no_split_modules: list = []

    def __init__(self, config: RidgeMultiHorizonConfig):
        super().__init__(config)
        n_feat = config.n_features
        n_horiz = config.horizon_length
        # Buffers rather than nn.Parameter: nothing here is trained. The
        # placeholder values are overwritten from the safetensors file in
        # ``from_pretrained``.
        self.register_buffer("scaler_mean", torch.zeros(n_feat))
        self.register_buffer("scaler_scale", torch.ones(n_feat))
        self.register_buffer("coef", torch.zeros(n_horiz, n_feat))
        self.register_buffer("intercept", torch.zeros(n_horiz))

    def _init_weights(self, module):
        # No trainable parameters; values come from safetensors.
        pass

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Standardize ``features`` and apply the linear map.

        Args:
            features: ``(..., n_features)`` tensor; cast to the buffers' dtype.

        Returns:
            ``(..., horizon_length)`` tensor of predictions.
        """
        x = (features.to(self.coef.dtype) - self.scaler_mean) / self.scaler_scale
        return x @ self.coef.T + self.intercept

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        ablation: Optional[str] = None,
        **kwargs,
    ):
        """Load the config and the weights file for a single ablation.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            *model_args: Accepted for ``PreTrainedModel`` compatibility; unused.
            config: Optional preloaded config. When ``ablation`` is also given,
                a copy is modified, so the caller's object is left untouched.
            ablation: Overrides ``config.ablation``.
            **kwargs: Keys listed in ``_HUB_DOWNLOAD_KWARGS`` are forwarded to
                the config loader and ``hf_hub_download`` (``subfolder`` is
                also honoured for local directories); other keys are ignored.

        Returns:
            The model in eval mode with buffers loaded from
            ``model_<ablation>.safetensors``.

        Raises:
            ValueError: If the resolved ablation is not in ``config.ablations``.
            FileNotFoundError: If a local directory lacks the weights file.
            RuntimeError: If the weights file is missing required buffers.
        """
        import copy
        import logging

        # Drop transformers-internal markers we don't need to act on.
        kwargs.pop("trust_remote_code", None)
        kwargs.pop("_from_auto", None)
        kwargs.pop("_commit_hash", None)
        hub_kwargs = {k: kwargs.pop(k) for k in _HUB_DOWNLOAD_KWARGS if k in kwargs}

        if config is None:
            config_kwargs = dict(hub_kwargs)
            if ablation is not None:
                config_kwargs["ablation"] = ablation
            config = RidgeMultiHorizonConfig.from_pretrained(
                pretrained_model_name_or_path, **config_kwargs
            )
        elif ablation is not None:
            # Don't mutate the caller's config object.
            config = copy.deepcopy(config)
            config.ablation = ablation

        # Plain attribute assignment skips the check in the config's __init__,
        # so re-validate here to fail with a clear message instead of a
        # KeyError deep inside model construction.
        valid_ablations = getattr(config, "ablations", None)
        if valid_ablations is not None and config.ablation not in valid_ablations:
            raise ValueError(
                f"ablation must be one of {valid_ablations}, got {config.ablation!r}"
            )

        model = cls(config)

        weights_filename = f"model_{config.ablation}.safetensors"
        root = str(pretrained_model_name_or_path)
        if os.path.isdir(root):
            # Mirror hf_hub_download's ``subfolder`` handling for local checkouts.
            weights_dir = os.path.join(root, hub_kwargs.get("subfolder") or "")
            weights_path = os.path.join(weights_dir, weights_filename)
            if not os.path.isfile(weights_path):
                raise FileNotFoundError(
                    f"Expected {weights_filename} in {weights_dir}"
                )
        else:
            weights_path = hf_hub_download(
                repo_id=root,
                filename=weights_filename,
                **hub_kwargs,
            )

        state = load_file(weights_path)
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            raise RuntimeError(
                f"{weights_filename} is missing buffers required by the model: {missing}"
            )
        if unexpected:
            # Not fatal, but worth surfacing in case a checkpoint has stale keys.
            logging.getLogger(__name__).warning(
                "RidgeMultiHorizonModel: ignoring unexpected keys in %s: %s",
                weights_filename,
                unexpected,
            )
        model.eval()
        return model

    @torch.no_grad()
    def predict(self, timestamps, cgm, insulin, carbs) -> np.ndarray:
        """Run inference for a benchmark.py-style batch.

        Args:
            timestamps: int64 ns timestamps, shape ``(B, T_in)``.
            cgm: float CGM values, shape ``(B, T_in)``.
            insulin: float insulin values, shape ``(B, T_in)`` (only relevant
                if the active ablation has Insulin features).
            carbs: float carb values, shape ``(B, T_in)`` (only relevant if
                the active ablation has Carbs features).

        Returns:
            ``(B, horizon_length)`` numpy array of predicted CGM values.
        """
        features = _build_tabular_features(
            timestamps=np.asarray(timestamps),
            cgm=np.asarray(cgm, dtype=np.float64),
            insulin=np.asarray(insulin, dtype=np.float64),
            carbs=np.asarray(carbs, dtype=np.float64),
            feature_names=self.config.feature_names,
            history_length=self.config.history_length,
        )
        device = self.coef.device
        x = torch.as_tensor(features, dtype=self.coef.dtype, device=device)
        out = self.forward(x)
        return out.detach().cpu().numpy()
def _build_tabular_features(
*,
timestamps: np.ndarray,
cgm: np.ndarray,
insulin: np.ndarray,
carbs: np.ndarray,
feature_names: list,
history_length: int,
) -> np.ndarray:
"""Assemble a (B, F) feature matrix in the order given by ``feature_names``.
Convention: ``CGM_t<i>`` means the i-th *most recent* sample within the
last ``history_length`` steps, i.e. ``CGM_t0`` = oldest in the window,
``CGM_t<history_length-1>`` = newest. Same convention applies to
``Insulin_t<i>`` / ``Carbs_t<i>``. ``hour_sin`` / ``hour_cos`` are derived
from the most recent input timestamp (UTC hour-of-day).
"""
if cgm.shape[-1] < history_length:
raise ValueError(
f"Need at least {history_length} CGM samples, got {cgm.shape[-1]}"
)
cgm_h = cgm[..., -history_length:]
insulin_h = insulin[..., -history_length:]
carbs_h = carbs[..., -history_length:]
# Hour-of-day from the most recent input timestamp (ns since epoch).
last_ts = np.asarray(timestamps)[..., -1].astype(np.int64)
hours = (last_ts // 3_600_000_000_000) % 24
hour_sin = np.sin(2.0 * math.pi * hours / 24.0)
hour_cos = np.cos(2.0 * math.pi * hours / 24.0)
columns = []
for name in feature_names:
if name.startswith("CGM_t"):
i = int(name.split("_t", 1)[1])
columns.append(cgm_h[..., i])
elif name.startswith("Insulin_t"):
i = int(name.split("_t", 1)[1])
columns.append(insulin_h[..., i])
elif name.startswith("Carbs_t"):
i = int(name.split("_t", 1)[1])
columns.append(carbs_h[..., i])
elif name == "hour_sin":
columns.append(hour_sin)
elif name == "hour_cos":
columns.append(hour_cos)
else:
raise ValueError(f"Unknown feature column: {name!r}")
return np.stack(columns, axis=-1).astype(np.float32)