File size: 9,157 Bytes
c4135cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Ridge multi-horizon CGM forecaster, packaged for the HF Hub.

One repo holds four feature ablations (``cgm``, ``insulin``, ``carbs``, ``all``)
as separate ``model_<ablation>.safetensors`` files. The active ablation is
selected at load time via the ``ablation=`` kwarg passed through ``AutoConfig``
or ``AutoModel`` ``from_pretrained``.

Usage::

    from transformers import AutoConfig, AutoModel
    cfg = AutoConfig.from_pretrained(
        "anonymous-4FAD/Ridge", trust_remote_code=True, ablation="cgm")
    model = AutoModel.from_pretrained(
        "anonymous-4FAD/Ridge", trust_remote_code=True, config=cfg)
    preds = model.predict(timestamps_ns, cgm, insulin, carbs)  # (B, 12)
"""

from __future__ import annotations

import copy
import logging
import math
import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import PretrainedConfig, PreTrainedModel


# ``from_pretrained`` keyword arguments that are peeled off the caller's
# kwargs and forwarded to the Hub lookups (config load and
# ``hf_hub_download``) rather than being treated as model options.
_HUB_DOWNLOAD_KWARGS = (
    "cache_dir", "force_download", "local_files_only", "proxies",
    "revision", "subfolder", "token",
)


class RidgeMultiHorizonConfig(PretrainedConfig):
    """Configuration for the multi-horizon Ridge forecaster.

    One repo ships a weights file per feature ablation. ``ablations`` lists
    the available ablations (default: ``cgm``, ``insulin``, ``carbs``,
    ``all``) and ``ablation`` selects the active one. Per-ablation feature
    metadata is kept in ``feature_names_by_ablation`` and
    ``n_features_by_ablation``; the ``n_features`` / ``feature_names``
    properties resolve it for the active ablation.

    Raises:
        ValueError: if ``ablation`` is not one of ``ablations``.
    """

    model_type = "ridge_multihorizon"

    def __init__(
        self,
        ablation: str = "all",
        ablations: Optional[list] = None,
        history_length: int = 24,
        horizon_length: int = 12,
        feature_names_by_ablation: Optional[dict] = None,
        n_features_by_ablation: Optional[dict] = None,
        target_names: Optional[list] = None,
        **kwargs,
    ):
        allowed = ["cgm", "insulin", "carbs", "all"] if ablations is None else ablations
        if ablation not in allowed:
            raise ValueError(
                f"ablation must be one of {allowed}, got {ablation!r}"
            )

        # Window sizes are coerced to int (they may arrive as other numeric
        # types when loaded from JSON or passed by callers).
        self.ablation = ablation
        self.ablations = list(allowed)
        self.history_length = int(history_length)
        self.horizon_length = int(horizon_length)
        # ``or`` maps None (and empty containers) to fresh empty containers.
        self.feature_names_by_ablation = feature_names_by_ablation or {}
        self.n_features_by_ablation = n_features_by_ablation or {}
        self.target_names = list(target_names or [])
        super().__init__(**kwargs)

    @property
    def n_features(self) -> int:
        """Input feature count for the active ablation.

        Uses ``n_features_by_ablation`` when it is non-empty; otherwise counts
        the entries of ``feature_names_by_ablation`` for the active ablation.
        """
        counts = self.n_features_by_ablation
        if not counts:
            return len(self.feature_names_by_ablation[self.ablation])
        return int(counts[self.ablation])

    @property
    def feature_names(self) -> list:
        """Ordered feature column names for the active ablation (a new list)."""
        names_for_active = self.feature_names_by_ablation[self.ablation]
        return list(names_for_active)


class RidgeMultiHorizonModel(PreTrainedModel):
    """Multi-output Ridge regressor over standardized tabular features.

    Holds only buffers (``scaler_mean``, ``scaler_scale``, ``coef``,
    ``intercept``); there are no trainable parameters. The forward pass is
    ``((features - scaler_mean) / scaler_scale) @ coef.T + intercept``.
    """

    config_class = RidgeMultiHorizonConfig
    main_input_name = "features"
    _tied_weights_keys: dict = None
    _no_split_modules: list = []

    def __init__(self, config: RidgeMultiHorizonConfig):
        super().__init__(config)
        n_feat = config.n_features
        n_horiz = config.horizon_length
        # Placeholder buffers sized for the active ablation; ``from_pretrained``
        # overwrites them from the safetensors checkpoint.
        self.register_buffer("scaler_mean", torch.zeros(n_feat))
        self.register_buffer("scaler_scale", torch.ones(n_feat))
        self.register_buffer("coef", torch.zeros(n_horiz, n_feat))
        self.register_buffer("intercept", torch.zeros(n_horiz))

    def _init_weights(self, module):
        # No trainable parameters; values come from safetensors.
        pass

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Standardize ``features`` and apply the linear map.

        Args:
            features: tensor whose last dimension has ``n_features`` entries;
                it is cast to the dtype of ``coef`` first.

        Returns:
            Tensor whose last dimension has ``horizon_length`` entries.
        """
        x = (features.to(self.coef.dtype) - self.scaler_mean) / self.scaler_scale
        return x @ self.coef.T + self.intercept

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        ablation: Optional[str] = None,
        **kwargs,
    ):
        """Load the weights of one ablation from a local directory or the Hub.

        Args:
            pretrained_model_name_or_path: a local directory containing
                ``model_<ablation>.safetensors`` (inside ``subfolder`` if one
                is given), or a Hub repo id.
            *model_args: accepted for API compatibility; ignored.
            config: optional pre-built config. It is never mutated; when
                ``ablation`` is also given, a copy is overridden instead.
            ablation: optional override of ``config.ablation``.
            **kwargs: keys listed in ``_HUB_DOWNLOAD_KWARGS`` are forwarded to
                the config load and the weights download; other kwargs are
                ignored.

        Returns:
            The model in eval mode with buffers loaded from the checkpoint.

        Raises:
            ValueError: if the resolved ablation is not in ``config.ablations``.
            FileNotFoundError: if a local directory lacks the weights file.
            RuntimeError: if the checkpoint is missing required buffers.
        """
        # Drop transformers-internal markers we don't need to act on.
        kwargs.pop("trust_remote_code", None)
        kwargs.pop("_from_auto", None)
        kwargs.pop("_commit_hash", None)

        hub_kwargs = {k: kwargs.pop(k) for k in _HUB_DOWNLOAD_KWARGS if k in kwargs}

        if config is None:
            config_kwargs = dict(hub_kwargs)
            if ablation is not None:
                config_kwargs["ablation"] = ablation
            config = RidgeMultiHorizonConfig.from_pretrained(
                pretrained_model_name_or_path, **config_kwargs
            )
        elif ablation is not None:
            # Work on a copy so the caller's config object keeps its ablation.
            config = copy.deepcopy(config)
            config.ablation = ablation

        # Re-validate: assigning ``config.ablation`` after construction (as
        # above) skips the check in ``RidgeMultiHorizonConfig.__init__``.
        # NOTE(review): kwarg overrides in ``PretrainedConfig.from_pretrained``
        # are believed to be applied post-construction too — confirm.
        valid_ablations = getattr(config, "ablations", None)
        if valid_ablations and config.ablation not in valid_ablations:
            raise ValueError(
                f"ablation must be one of {valid_ablations}, got {config.ablation!r}"
            )

        model = cls(config)
        weights_filename = f"model_{config.ablation}.safetensors"

        local_dir = str(pretrained_model_name_or_path)
        if os.path.isdir(local_dir):
            # Honor ``subfolder`` locally as well, matching the Hub download
            # and the config load (which both receive it via ``hub_kwargs``).
            weights_path = os.path.join(
                local_dir, hub_kwargs.get("subfolder") or "", weights_filename)
            if not os.path.isfile(weights_path):
                raise FileNotFoundError(
                    f"Expected {weights_filename} in {os.path.dirname(weights_path)}"
                )
        else:
            weights_path = hf_hub_download(
                repo_id=local_dir,
                filename=weights_filename,
                **hub_kwargs,
            )

        state = load_file(weights_path)
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            raise RuntimeError(
                f"{weights_filename} is missing buffers required by the model: {missing}"
            )
        if unexpected:
            # Not fatal, but worth surfacing in case a checkpoint has stale keys.
            logging.getLogger(__name__).warning(
                "RidgeMultiHorizonModel: ignoring unexpected keys in %s: %s",
                weights_filename,
                unexpected,
            )

        model.eval()
        return model

    @torch.no_grad()
    def predict(self, timestamps, cgm, insulin, carbs) -> np.ndarray:
        """Run inference for a benchmark.py-style batch.

        Args:
            timestamps: int64 ns timestamps, shape ``(B, T_in)``.
            cgm: float CGM values, shape ``(B, T_in)``.
            insulin: float insulin values, shape ``(B, T_in)`` (used only if
                the active ablation requires Insulin features).
            carbs: float carb values, shape ``(B, T_in)`` (used only if the
                active ablation requires Carbs features).

        Returns:
            ``(B, horizon_length)`` numpy array of predicted CGM values.
        """
        features = _build_tabular_features(
            timestamps=np.asarray(timestamps),
            cgm=np.asarray(cgm, dtype=np.float64),
            insulin=np.asarray(insulin, dtype=np.float64),
            carbs=np.asarray(carbs, dtype=np.float64),
            feature_names=self.config.feature_names,
            history_length=self.config.history_length,
        )
        # Run on whatever device/dtype the buffers were loaded onto.
        device = self.coef.device
        x = torch.as_tensor(features, dtype=self.coef.dtype, device=device)
        out = self.forward(x)
        return out.detach().cpu().numpy()


def _build_tabular_features(
    *,
    timestamps: np.ndarray,
    cgm: np.ndarray,
    insulin: np.ndarray,
    carbs: np.ndarray,
    feature_names: list,
    history_length: int,
) -> np.ndarray:
    """Assemble a ``(..., F)`` float32 feature matrix ordered by ``feature_names``.

    Column conventions:

    * ``CGM_t<i>`` / ``Insulin_t<i>`` / ``Carbs_t<i>`` select position ``i``
      within the last ``history_length`` samples of that series, counted from
      the oldest: ``<name>_t0`` is the oldest sample in the window and
      ``<name>_t<history_length-1>`` is the newest.
    * ``hour_sin`` / ``hour_cos`` encode the whole UTC hour-of-day of the most
      recent timestamp, interpreted as nanoseconds since the epoch.

    CGM is always validated. ``insulin``, ``carbs`` and ``timestamps`` are
    only read when ``feature_names`` references them, so callers may pass
    placeholders for inputs the active ablation does not use.

    Raises:
        ValueError: if a referenced series has fewer than ``history_length``
            samples, a lag index lies outside ``[0, history_length)``, or a
            column name is not recognised.
    """

    def window(series: np.ndarray, label: str) -> np.ndarray:
        # Return the last ``history_length`` samples of ``series``.
        n_samples = series.shape[-1] if series.ndim else 0
        if n_samples < history_length:
            raise ValueError(
                f"Need at least {history_length} {label} samples, got {n_samples}"
            )
        # Explicit start index: ``x[..., -0:]`` would return the whole series.
        return series[..., n_samples - history_length:]

    # Lagged-series columns: name prefix -> (label used in errors, raw series).
    lagged = {
        "CGM_t": ("CGM", cgm),
        "Insulin_t": ("Insulin", insulin),
        "Carbs_t": ("Carbs", carbs),
    }
    # CGM is required up front; other windows are built on first use.
    windows = {"CGM_t": window(cgm, "CGM")}
    batch_shape = windows["CGM_t"].shape[:-1]
    hour_angle = None

    columns = []
    for name in feature_names:
        prefix = next((p for p in lagged if name.startswith(p)), None)
        if prefix is not None:
            if prefix not in windows:
                label, series = lagged[prefix]
                windows[prefix] = window(series, label)
            lag = int(name[len(prefix):])
            # Reject out-of-window lags (negative ones would wrap silently).
            if not 0 <= lag < history_length:
                raise ValueError(
                    f"Feature {name!r} indexes outside the "
                    f"{history_length}-step window"
                )
            columns.append(windows[prefix][..., lag])
        elif name in ("hour_sin", "hour_cos"):
            if hour_angle is None:
                # Whole hours of the most recent timestamp (ns since epoch).
                last_ts = np.asarray(timestamps)[..., -1].astype(np.int64)
                hours = (last_ts // 3_600_000_000_000) % 24
                hour_angle = 2.0 * math.pi * hours / 24.0
            columns.append(
                np.sin(hour_angle) if name == "hour_sin" else np.cos(hour_angle)
            )
        else:
            raise ValueError(f"Unknown feature column: {name!r}")

    if not columns:
        # ``np.stack`` rejects an empty list; return a (..., 0) matrix instead.
        return np.zeros(batch_shape + (0,), dtype=np.float32)
    return np.stack(columns, axis=-1).astype(np.float32)