"""Ridge multi-horizon CGM forecaster, packaged for the HF Hub. One repo holds four feature ablations (``cgm``, ``insulin``, ``carbs``, ``all``) as separate ``model_.safetensors`` files. The active ablation is selected at load time via the ``ablation=`` kwarg passed through ``AutoConfig`` or ``AutoModel`` ``from_pretrained``. Usage:: from transformers import AutoConfig, AutoModel cfg = AutoConfig.from_pretrained( "anonymous-4FAD/Ridge", trust_remote_code=True, ablation="cgm") model = AutoModel.from_pretrained( "anonymous-4FAD/Ridge", trust_remote_code=True, config=cfg) preds = model.predict(timestamps_ns, cgm, insulin, carbs) # (B, 12) """ from __future__ import annotations import math import os from typing import Optional import numpy as np import torch import torch.nn as nn from huggingface_hub import hf_hub_download from safetensors.torch import load_file from transformers import PretrainedConfig, PreTrainedModel _HUB_DOWNLOAD_KWARGS = ( "cache_dir", "force_download", "local_files_only", "proxies", "revision", "subfolder", "token", ) class RidgeMultiHorizonConfig(PretrainedConfig): """Config for the multi-horizon Ridge forecaster. The same repo serves four ablations (``cgm``, ``insulin``, ``carbs``, ``all``); the currently active one is ``self.ablation``. """ model_type = "ridge_multihorizon" def __init__( self, ablation: str = "all", ablations: Optional[list] = None, history_length: int = 24, horizon_length: int = 12, feature_names_by_ablation: Optional[dict] = None, n_features_by_ablation: Optional[dict] = None, target_names: Optional[list] = None, **kwargs, ): if ablations is None: ablations = ["cgm", "insulin", "carbs", "all"] if ablation not in ablations: raise ValueError( f"ablation must be one of {ablations}, got {ablation!r}" ) self.ablation = ablation self.ablations = list(ablations) self.history_length = int(history_length) self.horizon_length = int(horizon_length) self.feature_names_by_ablation = feature_names_by_ablation or {} self.n_features_by_ablation = n_features_by_ablation or {} self.target_names = list(target_names or []) super().__init__(**kwargs) @property def n_features(self) -> int: if self.n_features_by_ablation: return int(self.n_features_by_ablation[self.ablation]) return len(self.feature_names_by_ablation[self.ablation]) @property def feature_names(self) -> list: return list(self.feature_names_by_ablation[self.ablation]) class RidgeMultiHorizonModel(PreTrainedModel): """Multi-output Ridge regressor over standardized tabular features. Holds only buffers (``scaler_mean``, ``scaler_scale``, ``coef``, ``intercept``); there are no trainable parameters. """ config_class = RidgeMultiHorizonConfig main_input_name = "features" _tied_weights_keys: dict = None _no_split_modules: list = [] def __init__(self, config: RidgeMultiHorizonConfig): super().__init__(config) n_feat = config.n_features n_horiz = config.horizon_length self.register_buffer("scaler_mean", torch.zeros(n_feat)) self.register_buffer("scaler_scale", torch.ones(n_feat)) self.register_buffer("coef", torch.zeros(n_horiz, n_feat)) self.register_buffer("intercept", torch.zeros(n_horiz)) def _init_weights(self, module): # No trainable parameters; values come from safetensors. pass def forward(self, features: torch.Tensor) -> torch.Tensor: x = (features.to(self.coef.dtype) - self.scaler_mean) / self.scaler_scale return x @ self.coef.T + self.intercept @classmethod def from_pretrained( cls, pretrained_model_name_or_path, *model_args, config=None, ablation: Optional[str] = None, **kwargs, ): # Drop transformers-internal markers we don't need to act on. kwargs.pop("trust_remote_code", None) kwargs.pop("_from_auto", None) kwargs.pop("_commit_hash", None) hub_kwargs = {k: kwargs.pop(k) for k in _HUB_DOWNLOAD_KWARGS if k in kwargs} if config is None: config_kwargs = dict(hub_kwargs) if ablation is not None: config_kwargs["ablation"] = ablation config = RidgeMultiHorizonConfig.from_pretrained( pretrained_model_name_or_path, **config_kwargs ) elif ablation is not None: config.ablation = ablation model = cls(config) weights_filename = f"model_{config.ablation}.safetensors" if os.path.isdir(str(pretrained_model_name_or_path)): weights_path = os.path.join( str(pretrained_model_name_or_path), weights_filename) if not os.path.isfile(weights_path): raise FileNotFoundError( f"Expected {weights_filename} in {pretrained_model_name_or_path}" ) else: weights_path = hf_hub_download( repo_id=str(pretrained_model_name_or_path), filename=weights_filename, **hub_kwargs, ) state = load_file(weights_path) missing, unexpected = model.load_state_dict(state, strict=False) if missing: raise RuntimeError( f"{weights_filename} is missing buffers required by the model: {missing}" ) if unexpected: # Not fatal, but worth surfacing in case a checkpoint has stale keys. print( f"RidgeMultiHorizonModel: ignoring unexpected keys in " f"{weights_filename}: {unexpected}" ) model.eval() return model @torch.no_grad() def predict(self, timestamps, cgm, insulin, carbs) -> np.ndarray: """Run inference for a benchmark.py-style batch. Args: timestamps: int64 ns timestamps, shape ``(B, T_in)``. cgm: float CGM values, shape ``(B, T_in)``. insulin: float insulin values, shape ``(B, T_in)`` (used only if the active ablation requires Insulin features). carbs: float carb values, shape ``(B, T_in)`` (used only if the active ablation requires Carbs features). Returns: ``(B, horizon_length)`` numpy array of predicted CGM values. """ features = _build_tabular_features( timestamps=np.asarray(timestamps), cgm=np.asarray(cgm, dtype=np.float64), insulin=np.asarray(insulin, dtype=np.float64), carbs=np.asarray(carbs, dtype=np.float64), feature_names=self.config.feature_names, history_length=self.config.history_length, ) device = self.coef.device x = torch.as_tensor(features, dtype=self.coef.dtype, device=device) out = self.forward(x) return out.detach().cpu().numpy() def _build_tabular_features( *, timestamps: np.ndarray, cgm: np.ndarray, insulin: np.ndarray, carbs: np.ndarray, feature_names: list, history_length: int, ) -> np.ndarray: """Assemble a (B, F) feature matrix in the order given by ``feature_names``. Convention: ``CGM_t`` means the i-th *most recent* sample within the last ``history_length`` steps, i.e. ``CGM_t0`` = oldest in the window, ``CGM_t`` = newest. Same convention applies to ``Insulin_t`` / ``Carbs_t``. ``hour_sin`` / ``hour_cos`` are derived from the most recent input timestamp (UTC hour-of-day). """ if cgm.shape[-1] < history_length: raise ValueError( f"Need at least {history_length} CGM samples, got {cgm.shape[-1]}" ) cgm_h = cgm[..., -history_length:] insulin_h = insulin[..., -history_length:] carbs_h = carbs[..., -history_length:] # Hour-of-day from the most recent input timestamp (ns since epoch). last_ts = np.asarray(timestamps)[..., -1].astype(np.int64) hours = (last_ts // 3_600_000_000_000) % 24 hour_sin = np.sin(2.0 * math.pi * hours / 24.0) hour_cos = np.cos(2.0 * math.pi * hours / 24.0) columns = [] for name in feature_names: if name.startswith("CGM_t"): i = int(name.split("_t", 1)[1]) columns.append(cgm_h[..., i]) elif name.startswith("Insulin_t"): i = int(name.split("_t", 1)[1]) columns.append(insulin_h[..., i]) elif name.startswith("Carbs_t"): i = int(name.split("_t", 1)[1]) columns.append(carbs_h[..., i]) elif name == "hour_sin": columns.append(hour_sin) elif name == "hour_cos": columns.append(hour_cos) else: raise ValueError(f"Unknown feature column: {name!r}") return np.stack(columns, axis=-1).astype(np.float32)