# Initial HF Space app (commit 84f224f, author: walidhadri).
from __future__ import annotations
import warnings
from typing import Any
import numpy as np
import pandas as pd
import torch
from models.base_model import BaseForecastModel
class Chronos2Model(BaseForecastModel):
    """Forecast adapter around ``chronos.Chronos2Pipeline`` (Chronos-2).

    The pipeline is loaded lazily on the first :meth:`predict` call, or
    explicitly through :meth:`load_model`. Context is supplied as a wide
    DataFrame (timestamps in the index, one column per series); it is converted
    to the long ``id``/``timestamp``/``target`` layout passed to
    ``predict_df`` and the forecasts are pivoted back into wide arrays.
    """

    def __init__(self, model_id: str = "amazon/chronos-2", device_map: str | None = None) -> None:
        """Store configuration without loading any weights.

        Args:
            model_id: Identifier passed to ``Chronos2Pipeline.from_pretrained``.
            device_map: Device to load onto (e.g. ``"cuda"``, ``"cuda:0"``,
                ``"cpu"``). ``None`` means "CUDA if available, else CPU".
        """
        self.model_id = model_id
        self.device_map = device_map
        self.pipeline = None  # populated by load_model()

    @staticmethod
    def _is_cuda_runtime_error(exc: RuntimeError) -> bool:
        """Return True when ``exc``'s message looks like a CUDA/GPU failure.

        This is a case-insensitive substring match against known CUDA, cuBLAS
        and cuDNN error fragments.
        """
        message = str(exc).lower()
        cuda_signals = (
            "cuda error",
            "cublas",
            "cudnn",
            "device-side assert",
            "invalid configuration argument",
            "illegal memory access",
            "out of memory",
        )
        return any(signal in message for signal in cuda_signals)

    def _resolve_device_map(self) -> str:
        """Return the configured device, or auto-detect CUDA vs. CPU."""
        if self.device_map:
            return self.device_map
        try:
            return "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:  # noqa: BLE001 - a broken CUDA setup must not block CPU use
            return "cpu"

    def load_model(self) -> None:
        """Load the Chronos-2 pipeline, falling back to CPU on CUDA failures.

        On success ``self.device_map`` records the device actually used.

        Raises:
            ModuleNotFoundError: If ``chronos-forecasting`` is not installed.
            RuntimeError: If loading fails for any reason other than a
                recognised CUDA error while targeting a CUDA device.
        """
        try:
            from chronos import Chronos2Pipeline
        except ModuleNotFoundError as exc:
            raise ModuleNotFoundError(
                "Missing dependency 'chronos-forecasting'. Install it with: "
                "`pip install chronos-forecasting` "
                "(if pip resolver conflicts with toto-ts, try `pip install --no-deps chronos-forecasting`)."
            ) from exc

        chosen_device = self._resolve_device_map()
        try:
            self.pipeline = Chronos2Pipeline.from_pretrained(self.model_id, device_map=chosen_device)
        except RuntimeError as exc:
            # Match both the bare "cuda" device and indexed forms like "cuda:0".
            targeting_cuda = str(chosen_device).lower().startswith("cuda")
            if not (targeting_cuda and self._is_cuda_runtime_error(exc)):
                raise
            warnings.warn(
                "CUDA failed while loading Chronos-2. Falling back to CPU.",
                RuntimeWarning,
                stacklevel=2,
            )
            self.pipeline = Chronos2Pipeline.from_pretrained(self.model_id, device_map="cpu")
            self.device_map = "cpu"
        else:
            self.device_map = chosen_device

    @staticmethod
    def _to_long_context(context_df: pd.DataFrame) -> pd.DataFrame:
        """Convert a wide context frame into the long layout used by ``predict_df``.

        Timestamps are taken from ``context_df``'s index; every column becomes
        its own series id.

        Returns:
            DataFrame with columns ``id``, ``timestamp``, ``target``, ordered by
            series (original column order) then ascending timestamp. Rows whose
            timestamp or value cannot be parsed are dropped.

        Raises:
            ValueError: If the frame is empty, no timestamp parses, or no
                numeric value survives coercion.
        """
        if context_df.empty:
            raise ValueError("Chronos-2 received an empty context dataframe.")

        # Work from the index directly rather than reset_index()/rename()/melt():
        # that route broke when a series was named "target" or "timestamp", or
        # when an unnamed index coexisted with a data column called "index".
        timestamps = pd.to_datetime(context_df.index, errors="coerce")
        wide = context_df.set_axis(timestamps, axis=0)
        wide = wide.loc[wide.index.notna()].sort_index(kind="stable")
        if wide.empty:
            raise ValueError("Chronos-2 could not parse timestamps from context data.")

        # Positional access (iloc) keeps this correct even with duplicate labels.
        pieces = [
            pd.DataFrame(
                {
                    "id": [label] * len(wide),
                    "timestamp": wide.index,
                    "target": pd.to_numeric(wide.iloc[:, position], errors="coerce").reset_index(drop=True),
                }
            )
            for position, label in enumerate(wide.columns)
        ]
        long_df = pd.concat(pieces, ignore_index=True).dropna(subset=["target"])
        if long_df.empty:
            raise ValueError("Chronos-2 context data has no numeric values after preprocessing.")
        return long_df[["id", "timestamp", "target"]].reset_index(drop=True)

    @staticmethod
    def _find_point_column(pred_df: pd.DataFrame) -> Any:
        """Pick the column holding the point forecast.

        Explicit ``predictions``/``prediction``/median columns win; otherwise the
        quantile column (label parseable as a float in [0, 1]) closest to 0.5
        is used.

        Raises:
            ValueError: If no suitable column exists.
        """
        for preferred in ("predictions", "prediction", "0.5", 0.5):
            if preferred in pred_df.columns:
                return preferred
        quantile_cols = Chronos2Model._extract_quantile_columns(pred_df)
        if quantile_cols:
            closest_level = min(quantile_cols, key=lambda level: abs(level - 0.5))
            return quantile_cols[closest_level]
        raise ValueError("Chronos-2 output did not include a point forecast column.")

    @staticmethod
    def _extract_quantile_columns(pred_df: pd.DataFrame) -> dict[float, Any]:
        """Map quantile level -> column label for labels that parse as floats in [0, 1]."""
        quantile_cols: dict[float, Any] = {}
        for col in pred_df.columns:
            try:
                q = float(str(col))
            except (TypeError, ValueError):
                continue
            if 0.0 <= q <= 1.0:  # also rejects "nan"
                quantile_cols[q] = col
        return quantile_cols

    @staticmethod
    def _pivot_forecast(
        pred_df: pd.DataFrame,
        value_column: Any,
        target_cols: list[str],
        prediction_length: int,
    ) -> tuple[np.ndarray, pd.DatetimeIndex]:
        """Pivot one long forecast column into a ``(prediction_length, len(target_cols))`` array.

        Returns:
            The float value matrix (columns ordered like ``target_cols``; a
            target absent from the output becomes NaN) and its timestamp index.

        Raises:
            ValueError: If required columns are missing, nothing survives
                cleaning, or fewer than ``prediction_length`` timestamps remain.
        """
        timestamp_col = "timestamp"
        id_col = "id"
        required = {id_col, timestamp_col, value_column}
        missing = required.difference(pred_df.columns)
        if missing:
            # key=str: labels may mix str and float (e.g. quantile 0.5).
            raise ValueError(f"Chronos-2 output missing columns: {sorted(missing, key=str)}")

        prepared = pred_df[[id_col, timestamp_col, value_column]].copy()
        prepared[timestamp_col] = pd.to_datetime(prepared[timestamp_col], errors="coerce")
        prepared[value_column] = pd.to_numeric(prepared[value_column], errors="coerce")
        prepared = prepared.dropna(subset=[timestamp_col, value_column]).sort_values([id_col, timestamp_col])

        pivot = prepared.pivot_table(index=timestamp_col, columns=id_col, values=value_column, aggfunc="last")
        pivot = pivot.sort_index().reindex(columns=target_cols)
        if pivot.empty:
            raise ValueError("Chronos-2 returned empty forecasts after pivot.")

        pivot = pivot.tail(prediction_length)
        if len(pivot) < prediction_length:
            raise ValueError(
                f"Chronos-2 returned {len(pivot)} rows, expected at least {prediction_length}."
            )
        return pivot.to_numpy(dtype=float), pd.DatetimeIndex(pivot.index)

    def predict(
        self,
        context_data: Any,
        prediction_length: int,
        num_samples: int = 256,
        quantile_levels: tuple[float, ...] = (0.1, 0.5, 0.9),
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Forecast ``prediction_length`` steps for each target column.

        Args:
            context_data: Either a wide DataFrame (timestamps in the index) or a
                dict with keys ``"context"`` (that DataFrame), optional
                ``"frequency"`` and optional ``"target_cols"`` (defaults to all
                context columns; a single string is accepted).
            prediction_length: Number of future steps; must be >= 1.
            num_samples: Unused; Chronos-2 returns quantiles directly. Kept for
                interface parity with other models.
            quantile_levels: Quantile levels requested from ``predict_df``.
            **kwargs: Forwarded to ``Chronos2Pipeline.predict_df``.

        Returns:
            Dict with ``"mean"`` (array of shape ``(prediction_length,
            len(target_cols))``), ``"quantiles"`` (level -> array of the same
            shape), ``"forecast_index"``, and run metadata.

        Raises:
            ValueError: On invalid input or unusable model output.
        """
        _ = num_samples
        if prediction_length < 1:
            raise ValueError(f"Chronos-2 requires prediction_length >= 1, got {prediction_length}.")

        if isinstance(context_data, dict):
            context_df = context_data.get("context")
            frequency = context_data.get("frequency")
            requested_targets = context_data.get("target_cols")
        else:
            context_df = context_data
            frequency = None
            requested_targets = None

        # Validate the type before touching .columns (previously an AttributeError).
        if not isinstance(context_df, pd.DataFrame):
            raise ValueError("Chronos-2 expects context_data to include a pandas DataFrame under 'context'.")
        if isinstance(requested_targets, str):
            requested_targets = [requested_targets]
        target_cols = list(requested_targets) if requested_targets else list(context_df.columns)
        if not target_cols:
            raise ValueError("Chronos-2 requires at least one target column.")
        absent = [col for col in target_cols if col not in context_df.columns]
        if absent:
            raise ValueError(f"Chronos-2 target columns not found in context: {absent}")

        if self.pipeline is None:
            self.load_model()

        long_context = self._to_long_context(context_df)
        prediction_df = self.pipeline.predict_df(
            long_context,
            prediction_length=prediction_length,
            quantile_levels=list(quantile_levels),
            id_column="id",
            timestamp_column="timestamp",
            target="target",
            **kwargs,
        )
        if prediction_df is None or len(prediction_df) == 0:
            raise ValueError("Chronos-2 returned an empty prediction dataframe.")

        point_col = self._find_point_column(prediction_df)
        mean_values, forecast_index = self._pivot_forecast(
            prediction_df,
            value_column=point_col,
            target_cols=target_cols,
            prediction_length=prediction_length,
        )

        # Quantiles are best-effort: a malformed level is reported, not fatal.
        quantiles: dict[float, np.ndarray] = {}
        for level, column in self._extract_quantile_columns(prediction_df).items():
            try:
                level_values, _ = self._pivot_forecast(
                    prediction_df,
                    value_column=column,
                    target_cols=target_cols,
                    prediction_length=prediction_length,
                )
            except Exception as exc:  # noqa: BLE001
                warnings.warn(
                    f"Chronos-2 quantile {level} could not be extracted: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            quantiles[level] = level_values

        return {
            "mean": mean_values,
            "quantiles": quantiles,
            "samples": None,
            "forecast_index": forecast_index,
            "frequency": frequency,
            "device_used": self.device_map or "auto",
            "samples_per_batch_used": "n/a",
            "context_length_used": len(context_df),
            "num_variates_used": context_df.shape[1],
        }

    def get_model_info(self) -> dict[str, Any]:
        """Return static metadata about this adapter and its configured device."""
        return {
            "name": "Chronos-2",
            # Report the checkpoint actually configured, not a hard-coded id.
            "full_name": self.model_id,
            "description": "Universal foundation model by Amazon for univariate/multivariate forecasting.",
            "supports_multivariate": True,
            "supports_covariates": True,
            "max_context_length": "Auto",
            "parameters": "120M",
            "device": self.device_map or "auto",
        }