| from __future__ import annotations |
|
|
| import warnings |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
|
|
| from models.base_model import BaseForecastModel |
|
|
|
|
class TotoModel(BaseForecastModel):
    """Forecasting backend wrapping Datadog's Toto-Open-Base-1.0 checkpoint.

    The model is loaded lazily, prefers CUDA when available, and transparently
    falls back to CPU whenever a CUDA-specific RuntimeError occurs during
    either loading or forecasting.
    """

    def __init__(self, device: str | None = None) -> None:
        """Create an unloaded model wrapper.

        Args:
            device: "cuda", "cpu", or None to auto-detect at load time.
        """
        self.device = device
        self.model = None       # Toto checkpoint; populated by load_model()
        self.forecaster = None  # TotoForecaster; populated by load_model()

    def load_model(self) -> None:
        """Download/load the pretrained Toto weights and build a forecaster.

        Raises:
            ModuleNotFoundError: if 'pkg_resources' (required transitively by
                the toto package) is unavailable, with install guidance.
            RuntimeError: if loading fails for a non-CUDA reason, or fails
                again after the CPU fallback.
        """
        try:
            # Probe for pkg_resources up front so users get an actionable
            # message instead of a deep import error from the toto package.
            import pkg_resources  # noqa: F401
        except ModuleNotFoundError as exc:
            raise ModuleNotFoundError(
                "Missing dependency 'pkg_resources'. Install a compatible setuptools "
                "version (e.g. `<81`): `pip install 'setuptools<81'`."
            ) from exc

        from toto.inference.forecaster import TotoForecaster
        from toto.model.toto import Toto

        if self.device is None:
            try:
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            except Exception:
                # A broken CUDA install can raise even during the
                # availability check; treat that as "no CUDA".
                self.device = "cpu"

        try:
            self.model = Toto.from_pretrained("Datadog/Toto-Open-Base-1.0").to(self.device)
        except RuntimeError as exc:
            if self.device == "cuda" and self._is_cuda_runtime_error(exc):
                warnings.warn(
                    "CUDA failed while loading Toto. Falling back to CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.device = "cpu"
                self.model = Toto.from_pretrained("Datadog/Toto-Open-Base-1.0").to(self.device)
            else:
                raise

        # The checkpoint may expose its backbone under `.model`; hand
        # TotoForecaster the inner module when present, else the checkpoint.
        core_model = getattr(self.model, "model", self.model)
        self.forecaster = TotoForecaster(core_model)

    @staticmethod
    def _is_cuda_runtime_error(exc: RuntimeError) -> bool:
        """Heuristically classify a RuntimeError as CUDA-related by message."""
        message = str(exc).lower()
        cuda_signals = (
            "cuda error",
            "cublas",
            "cudnn",
            "device-side assert",
            "invalid configuration argument",
            "illegal memory access",
        )
        return any(signal in message for signal in cuda_signals)

    def _reload_on_cpu(self) -> None:
        """Drop the current model, free CUDA memory, and reload on CPU."""
        self.device = "cpu"
        self.model = None
        self.forecaster = None
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Best-effort cleanup only; a wedged CUDA context must not
            # prevent the CPU reload.
            pass
        self.load_model()

    @staticmethod
    def _largest_divisor_leq(value: int, cap: int) -> int:
        """Return the largest divisor of `value` that is <= `cap` (min 1)."""
        cap = max(1, min(value, cap))
        for candidate in range(cap, 0, -1):
            if value % candidate == 0:
                return candidate
        return 1

    def _choose_samples_per_batch(self, context_df: pd.DataFrame, num_samples: int) -> int:
        """Pick a sample-batch size that divides `num_samples` evenly.

        CUDA caps shrink with longer contexts / more variates to bound peak
        memory; CPU uses a single generous cap.
        """
        context_len = len(context_df)
        num_variates = context_df.shape[1]

        if self.device == "cuda":
            if context_len > 8192 or num_variates >= 6:
                cap = 8
            elif context_len > 4096 or num_variates >= 4:
                cap = 16
            else:
                cap = 32
        else:
            cap = 64

        return self._largest_divisor_leq(num_samples, cap)

    def _to_numpy(self, value: Any) -> np.ndarray:
        """Convert a tensor/array-like forecast value to a numpy array.

        Raises:
            ValueError: if `value` is None.
        """
        if value is None:
            raise ValueError("Received empty forecast value.")
        if isinstance(value, np.ndarray):
            return value
        if hasattr(value, "detach"):
            # torch.Tensor path: detach from autograd, move off GPU first.
            return value.detach().cpu().numpy()
        if hasattr(value, "cpu") and hasattr(value, "numpy"):
            return value.cpu().numpy()
        return np.asarray(value)

    def _normalize_mean_shape(self, arr: np.ndarray, n_vars: int) -> np.ndarray:
        """Coerce a forecast array to shape (prediction_length, n_vars).

        Accepts squeezed model outputs in several layouts and reorients them
        so rows are time steps and columns are variates.

        Raises:
            ValueError: if the array cannot be aligned with `n_vars`.
        """
        arr = np.squeeze(np.asarray(arr))

        if arr.ndim == 0:
            # Fully squeezed scalar: one step of a univariate forecast.
            return arr.reshape(1, 1)

        if arr.ndim == 1:
            # Ambiguous axis after squeeze: when the length matches a
            # multivariate target it is the variate axis of a single-step
            # forecast; otherwise treat it as the time axis (univariate).
            if n_vars > 1 and arr.shape[0] == n_vars:
                return arr.reshape(1, n_vars)
            return arr.reshape(-1, 1)

        if arr.ndim != 2:
            raise ValueError(f"Unexpected prediction shape: {arr.shape}")

        if arr.shape[1] == n_vars:
            return arr
        if arr.shape[0] == n_vars:
            # Variates came out on axis 0 — transpose to (steps, variates).
            return arr.T

        if n_vars == 1:
            return arr.reshape(-1, 1)

        raise ValueError(f"Could not align forecast shape {arr.shape} for {n_vars} variables.")

    def _build_model_input(self, context_df: pd.DataFrame) -> Any:
        """Pack a context DataFrame into Toto's MaskedTimeseries input.

        Tensors follow a (batch=1, variates, time) layout; DataFrame columns
        become variates.
        """
        from toto.data.util.dataset import MaskedTimeseries

        values_2d = torch.tensor(context_df.values.T, dtype=torch.float32, device=self.device)
        num_variates, _ = values_2d.shape
        values = values_2d.unsqueeze(0)

        # Epoch seconds per timestep. NOTE(review): assumes the index views
        # as int64 nanoseconds (i.e. a DatetimeIndex) — confirm callers never
        # pass a non-datetime index.
        ts_1d = (
            torch.tensor(context_df.index.view("int64"), dtype=torch.int64, device=self.device)
            // 1_000_000_000
        )

        # Median spacing is robust to occasional gaps; degenerate contexts
        # (fewer than two points, or non-increasing stamps) default to 1s.
        if len(ts_1d) >= 2:
            diffs = ts_1d[1:] - ts_1d[:-1]
            interval = int(torch.median(diffs).item())
            if interval <= 0:
                interval = 1
        else:
            interval = 1

        # Broadcast the shared timestamp row and interval across variates.
        timestamp_seconds = ts_1d.unsqueeze(0).unsqueeze(0).repeat(1, num_variates, 1)
        time_interval_seconds = torch.full((1, num_variates), interval, dtype=torch.int64, device=self.device)
        id_mask = torch.zeros_like(values, dtype=torch.int64)

        return MaskedTimeseries(
            series=values,
            # All observed, no padding: an all-True mask.
            padding_mask=torch.full_like(values, True, dtype=torch.bool),
            id_mask=id_mask,
            timestamp_seconds=timestamp_seconds,
            time_interval_seconds=time_interval_seconds,
        )

    def predict(
        self,
        context_data: Any,
        prediction_length: int,
        num_samples: int = 256,
        quantile_levels: tuple[float, ...] = (0.1, 0.5, 0.9),
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Forecast `prediction_length` steps from the provided context.

        Args:
            context_data: a pandas DataFrame of history, or a dict with keys
                'context' (DataFrame) and optional 'frequency'/'target_cols'.
            prediction_length: number of future steps to predict.
            num_samples: Monte Carlo sample count for the predictive
                distribution.
            quantile_levels: quantiles to extract when the forecast object
                supports them.
            **kwargs: passed through to TotoForecaster.forecast.

        Returns:
            Dict with 'mean' (steps x variates), 'quantiles', 'samples',
            'frequency', and device/batching diagnostics.

        Raises:
            ValueError: when no DataFrame context is supplied.
            RuntimeError: for unrecoverable forecasting failures, or when the
                forecast object exposes no median/mean/samples.
        """
        if self.forecaster is None:
            self.load_model()

        if isinstance(context_data, dict):
            context_df = context_data.get("context")
            frequency = context_data.get("frequency")
            target_cols = context_data.get("target_cols") or list(context_df.columns)
        else:
            context_df = context_data
            frequency = None
            target_cols = list(context_df.columns)

        if context_df is None or not isinstance(context_df, pd.DataFrame):
            raise ValueError("Toto expects context_data to include a pandas DataFrame under 'context'.")

        inputs = self._build_model_input(context_df)
        samples_per_batch = self._choose_samples_per_batch(context_df, num_samples)

        try:
            forecast = self.forecaster.forecast(
                inputs,
                prediction_length=prediction_length,
                num_samples=num_samples,
                samples_per_batch=samples_per_batch,
                **kwargs,
            )
        except RuntimeError as exc:
            if self.device == "cuda" and self._is_cuda_runtime_error(exc):
                warnings.warn(
                    "CUDA failed during Toto forecasting. Retrying on CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self._reload_on_cpu()
                # Rebuild device-bound inputs AND re-derive the batch size:
                # the CPU path permits a larger cap than the failed CUDA run,
                # and the diagnostic below must reflect the run that succeeded.
                inputs = self._build_model_input(context_df)
                samples_per_batch = self._choose_samples_per_batch(context_df, num_samples)
                forecast = self.forecaster.forecast(
                    inputs,
                    prediction_length=prediction_length,
                    num_samples=num_samples,
                    samples_per_batch=samples_per_batch,
                    **kwargs,
                )
            else:
                raise

        # Point forecast preference: median, then mean, then a median over
        # raw samples. Each may be exposed as an attribute or a callable.
        median_value = getattr(forecast, "median", None)
        if callable(median_value):
            median_value = median_value()
        if median_value is None and hasattr(forecast, "mean"):
            maybe_mean = getattr(forecast, "mean")
            median_value = maybe_mean() if callable(maybe_mean) else maybe_mean

        if median_value is None:
            samples = getattr(forecast, "samples", None)
            if samples is None:
                raise RuntimeError("Toto forecast object did not include median/mean/samples.")
            samples_np = self._to_numpy(samples)
            # NOTE(review): assumes the sample dimension is axis 0 — confirm
            # against TotoForecaster's documented sample layout.
            median_np = np.median(samples_np, axis=0)
        else:
            median_np = self._to_numpy(median_value)

        mean_2d = self._normalize_mean_shape(median_np, n_vars=len(target_cols))

        quantiles: dict[float, np.ndarray] = {}
        if hasattr(forecast, "quantile"):
            for q in quantile_levels:
                q_values = forecast.quantile(q)
                quantiles[q] = self._normalize_mean_shape(self._to_numpy(q_values), len(target_cols))

        samples_np = None
        if hasattr(forecast, "samples"):
            samples_np = self._to_numpy(forecast.samples)

        return {
            "mean": mean_2d,
            "quantiles": quantiles,
            "samples": samples_np,
            "frequency": frequency,
            "device_used": self.device,
            "samples_per_batch_used": samples_per_batch,
            "context_length_used": len(context_df),
            "num_variates_used": context_df.shape[1],
        }

    def get_model_info(self) -> dict[str, Any]:
        """Return static metadata describing the wrapped Toto model."""
        return {
            "name": "Toto",
            "full_name": "Toto-Open-Base-1.0",
            "description": "Time Series Optimized Transformer for Observability by Datadog",
            "supports_multivariate": True,
            "supports_covariates": False,
            "max_context_length": 4096,
            "parameters": "151M",
            "device": self.device or "auto",
        }
|
|