# walidhadri — Initial HF Space app (commit 84f224f)
from __future__ import annotations
import warnings
from typing import Any
import numpy as np
import pandas as pd
import torch
from models.base_model import BaseForecastModel
class TotoModel(BaseForecastModel):
    """Forecast backend wrapping Datadog's Toto-Open-Base-1.0 checkpoint.

    The heavy model is loaded lazily (on the first ``predict`` call or via an
    explicit ``load_model``).  CUDA runtime failures — at load time or at
    inference time — trigger a one-shot fallback to CPU instead of crashing.
    """

    def __init__(self, device: str | None = None) -> None:
        """Create an unloaded wrapper.

        Args:
            device: Torch device string ("cuda" / "cpu").  ``None`` means
                auto-detect in :meth:`load_model`.
        """
        self.device = device
        self.model = None       # set by load_model()
        self.forecaster = None  # set by load_model()

    def load_model(self) -> None:
        """Import toto, resolve the device, and load the pretrained weights.

        Raises:
            ModuleNotFoundError: if ``pkg_resources`` is missing (toto's
                dependency chain still imports it; setuptools>=81 removed it).
            RuntimeError: for non-CUDA load failures (CUDA failures fall
                back to CPU with a warning instead).
        """
        try:
            import pkg_resources  # noqa: F401
        except ModuleNotFoundError as exc:
            raise ModuleNotFoundError(
                "Missing dependency 'pkg_resources'. Install a compatible setuptools "
                "version (e.g. `<81`): `pip install 'setuptools<81'`."
            ) from exc

        # Deferred project imports keep wrapper construction cheap.
        from toto.inference.forecaster import TotoForecaster
        from toto.model.toto import Toto

        if self.device is None:
            try:
                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            except Exception:  # noqa: BLE001 - any CUDA-probe failure means CPU
                self.device = "cpu"

        try:
            self.model = Toto.from_pretrained("Datadog/Toto-Open-Base-1.0").to(self.device)
        except RuntimeError as exc:
            # Only CUDA-flavoured RuntimeErrors are retried on CPU; anything
            # else (e.g. a corrupt checkpoint) is a genuine failure.
            if self.device == "cuda" and self._is_cuda_runtime_error(exc):
                warnings.warn(
                    "CUDA failed while loading Toto. Falling back to CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.device = "cpu"
                self.model = Toto.from_pretrained("Datadog/Toto-Open-Base-1.0").to(self.device)
            else:
                raise

        # Some checkpoint wrappers expose the transformer under `.model`;
        # unwrap when present so the forecaster sees the core network.
        core_model = getattr(self.model, "model", self.model)
        self.forecaster = TotoForecaster(core_model)

    @staticmethod
    def _is_cuda_runtime_error(exc: RuntimeError) -> bool:
        """Heuristically decide whether *exc* is a CUDA runtime failure."""
        message = str(exc).lower()
        cuda_signals = (
            "cuda error",
            "cublas",
            "cudnn",
            "device-side assert",
            "invalid configuration argument",
            "illegal memory access",
        )
        return any(signal in message for signal in cuda_signals)

    def _reload_on_cpu(self) -> None:
        """Drop the current (GPU) model and reload everything on CPU."""
        self.device = "cpu"
        self.model = None
        self.forecaster = None
        try:
            # Best effort: release GPU memory held by the failed model.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:  # noqa: BLE001
            pass
        self.load_model()

    @staticmethod
    def _largest_divisor_leq(value: int, cap: int) -> int:
        """Return the largest divisor of ``value`` that is <= ``cap`` (min 1)."""
        cap = max(1, min(value, cap))
        for candidate in range(cap, 0, -1):
            if value % candidate == 0:
                return candidate
        return 1

    def _choose_samples_per_batch(self, context_df: pd.DataFrame, num_samples: int) -> int:
        """Pick a sample micro-batch size that evenly divides ``num_samples``.

        On CUDA the cap shrinks as the context grows longer or wider to
        bound peak GPU memory; on CPU a larger cap (64) is used.
        """
        context_len = len(context_df)
        num_variates = context_df.shape[1]
        if self.device == "cuda":
            if context_len > 8192 or num_variates >= 6:
                cap = 8
            elif context_len > 4096 or num_variates >= 4:
                cap = 16
            else:
                cap = 32
        else:
            cap = 64
        return self._largest_divisor_leq(num_samples, cap)

    def _to_numpy(self, value: Any) -> np.ndarray:
        """Convert a tensor / array-like forecast value to a numpy array.

        Raises:
            ValueError: if ``value`` is ``None``.
        """
        if value is None:
            raise ValueError("Received empty forecast value.")
        if isinstance(value, np.ndarray):
            return value
        if hasattr(value, "detach"):
            # torch.Tensor (possibly on GPU / requiring grad).
            return value.detach().cpu().numpy()
        if hasattr(value, "cpu") and hasattr(value, "numpy"):
            return value.cpu().numpy()
        return np.asarray(value)

    def _normalize_mean_shape(self, arr: np.ndarray, n_vars: int) -> np.ndarray:
        """Coerce a forecast array to shape ``(prediction_length, n_vars)``.

        Handles squeezed scalars (horizon 1, single variable), 1-D vectors
        (a single-variable horizon OR a multi-variable horizon-1 forecast),
        and transposed 2-D layouts.

        Raises:
            ValueError: when the array cannot be unambiguously aligned.
        """
        arr = np.squeeze(np.asarray(arr))
        if arr.ndim == 0:
            # Fix: horizon 1 with a single variable squeezes to a scalar;
            # previously this fell through to the shape error.
            return arr.reshape(1, 1)
        if arr.ndim == 1:
            # Fix: a multi-variable horizon-1 forecast squeezes to (n_vars,)
            # and must become one row, not one column.
            if n_vars > 1 and arr.shape[0] == n_vars:
                return arr.reshape(1, -1)
            return arr.reshape(-1, 1)
        if arr.ndim != 2:
            raise ValueError(f"Unexpected prediction shape: {arr.shape}")
        if arr.shape[1] == n_vars:
            return arr
        if arr.shape[0] == n_vars:
            return arr.T
        if n_vars == 1:
            return arr.reshape(-1, 1)
        raise ValueError(f"Could not align forecast shape {arr.shape} for {n_vars} variables.")

    def _build_model_input(self, context_df: pd.DataFrame) -> Any:
        """Pack a context DataFrame into toto's ``MaskedTimeseries`` input.

        Toto expects batched tensors shaped ``[batch, variates, seq_len]``.
        Assumes ``context_df.index`` is a nanosecond-resolution datetime
        index, so ``view("int64")`` yields epoch-nanoseconds — TODO confirm
        against callers.
        """
        from toto.data.util.dataset import MaskedTimeseries

        values_2d = torch.tensor(context_df.values.T, dtype=torch.float32, device=self.device)
        num_variates, _ = values_2d.shape
        values = values_2d.unsqueeze(0)

        # Epoch seconds per timestamp (int64 nanoseconds // 1e9).
        ts_1d = (
            torch.tensor(context_df.index.view("int64"), dtype=torch.int64, device=self.device)
            // 1_000_000_000
        )
        if len(ts_1d) >= 2:
            # Median spacing is robust to occasional gaps in the index.
            diffs = ts_1d[1:] - ts_1d[:-1]
            interval = int(torch.median(diffs).item())
            if interval <= 0:
                interval = 1
        else:
            interval = 1

        timestamp_seconds = ts_1d.unsqueeze(0).unsqueeze(0).repeat(1, num_variates, 1)
        time_interval_seconds = torch.full((1, num_variates), interval, dtype=torch.int64, device=self.device)
        id_mask = torch.zeros_like(values, dtype=torch.int64)
        return MaskedTimeseries(
            series=values,
            # All-True padding mask: every context position is observed.
            padding_mask=torch.ones_like(values, dtype=torch.bool),
            id_mask=id_mask,
            timestamp_seconds=timestamp_seconds,
            time_interval_seconds=time_interval_seconds,
        )

    def predict(
        self,
        context_data: Any,
        prediction_length: int,
        num_samples: int = 256,
        quantile_levels: tuple[float, ...] = (0.1, 0.5, 0.9),
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Generate probabilistic forecasts for the supplied context.

        Args:
            context_data: Either a context DataFrame, or a dict with keys
                ``context`` (DataFrame), optional ``frequency``, and
                optional ``target_cols``.
            prediction_length: Forecast horizon (number of steps).
            num_samples: Monte-Carlo sample count.
            quantile_levels: Quantiles to extract when supported.
            **kwargs: Forwarded to ``TotoForecaster.forecast``.

        Returns:
            Dict with ``mean`` shaped ``(horizon, n_vars)``, ``quantiles``,
            raw ``samples`` (or None), and bookkeeping metadata.

        Raises:
            ValueError: if no context DataFrame can be found.
            RuntimeError: for non-CUDA forecasting failures, or when the
                forecast object exposes no median/mean/samples.
        """
        if self.forecaster is None:
            self.load_model()

        if isinstance(context_data, dict):
            context_df = context_data.get("context")
            frequency = context_data.get("frequency")
            requested_cols = context_data.get("target_cols")
        else:
            context_df = context_data
            frequency = None
            requested_cols = None
        # Fix: validate the DataFrame *before* touching .columns, so a dict
        # without a 'context' entry raises the intended ValueError instead
        # of an AttributeError.
        if context_df is None or not isinstance(context_df, pd.DataFrame):
            raise ValueError("Toto expects context_data to include a pandas DataFrame under 'context'.")
        target_cols = requested_cols or list(context_df.columns)

        inputs = self._build_model_input(context_df)
        samples_per_batch = self._choose_samples_per_batch(context_df, num_samples)
        try:
            forecast = self.forecaster.forecast(
                inputs,
                prediction_length=prediction_length,
                num_samples=num_samples,
                samples_per_batch=samples_per_batch,
                **kwargs,
            )
        except RuntimeError as exc:
            if self.device == "cuda" and self._is_cuda_runtime_error(exc):
                warnings.warn(
                    "CUDA failed during Toto forecasting. Retrying on CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self._reload_on_cpu()
                inputs = self._build_model_input(context_df)
                # Fix: re-derive the micro-batch size for the CPU heuristic
                # (cap 64) instead of reusing the smaller CUDA-tuned value.
                samples_per_batch = self._choose_samples_per_batch(context_df, num_samples)
                forecast = self.forecaster.forecast(
                    inputs,
                    prediction_length=prediction_length,
                    num_samples=num_samples,
                    samples_per_batch=samples_per_batch,
                    **kwargs,
                )
            else:
                raise

        # Point forecast: prefer median, fall back to mean, then to the
        # median over raw samples.  Attributes may be properties or methods.
        median_value = getattr(forecast, "median", None)
        if callable(median_value):
            median_value = median_value()
        if median_value is None and hasattr(forecast, "mean"):
            maybe_mean = getattr(forecast, "mean")
            median_value = maybe_mean() if callable(maybe_mean) else maybe_mean
        if median_value is None:
            samples = getattr(forecast, "samples", None)
            if samples is None:
                raise RuntimeError("Toto forecast object did not include median/mean/samples.")
            samples_array = self._to_numpy(samples)
            # NOTE(review): assumes axis 0 is the sample axis — confirm
            # against toto's Forecast layout.
            median_np = np.median(samples_array, axis=0)
        else:
            median_np = self._to_numpy(median_value)
        mean_2d = self._normalize_mean_shape(median_np, n_vars=len(target_cols))

        quantiles: dict[float, np.ndarray] = {}
        if hasattr(forecast, "quantile"):
            for q in quantile_levels:
                q_values = forecast.quantile(q)
                quantiles[q] = self._normalize_mean_shape(self._to_numpy(q_values), len(target_cols))

        # Fix: guard against a present-but-None samples attribute, which
        # previously crashed in _to_numpy with "Received empty forecast value."
        samples_np = None
        raw_samples = getattr(forecast, "samples", None)
        if raw_samples is not None:
            samples_np = self._to_numpy(raw_samples)

        return {
            "mean": mean_2d,
            "quantiles": quantiles,
            "samples": samples_np,
            "frequency": frequency,
            "device_used": self.device,
            "samples_per_batch_used": samples_per_batch,
            "context_length_used": len(context_df),
            "num_variates_used": context_df.shape[1],
        }

    def get_model_info(self) -> dict[str, Any]:
        """Return static metadata describing this model backend."""
        return {
            "name": "Toto",
            "full_name": "Toto-Open-Base-1.0",
            "description": "Time Series Optimized Transformer for Observability by Datadog",
            "supports_multivariate": True,
            "supports_covariates": False,
            "max_context_length": 4096,
            "parameters": "151M",
            "device": self.device or "auto",
        }