| |
| import time, logging |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
|
|
| try: |
| |
| |
| import timesfm as tsm |
| except Exception: |
| tsm = None |
|
|
| try: |
| |
| |
| from huggingface_hub import snapshot_download |
| except Exception: |
| snapshot_download = None |
|
|
| from backends_base import ImagesBackend |
| from config import settings |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| |
| |
# Hugging Face model id; a settings override wins, otherwise the public TimesFM checkpoint.
MODEL_ID = getattr(settings, "LlmHFModelID", None) or "google/timesfm-2.5-200m-pytorch"
# Number of future steps to predict when the request omits "horizon".
DEFAULT_HORIZON = 24
# Frequency label used when the request omits "freq" ("H" = hourly); only echoed
# back in the response dict, never used in any computation here.
DEFAULT_FREQ = "H"
# Master switch: permit CUDA when torch reports it available.
ALLOW_GPU = True
|
|
| |
| |
| |
def _pick_device() -> str:
    """Choose the compute device: CUDA when allowed and present, otherwise CPU."""
    cuda_ok = ALLOW_GPU and torch.cuda.is_available()
    return "cuda" if cuda_ok else "cpu"
|
|
| def _pick_dtype(device: str) -> torch.dtype: |
| |
| if device != "cpu": |
| return torch.float16 |
| return torch.float32 |
|
|
| def _as_1d_float_tensor(series: List[float], device: str, dtype: torch.dtype) -> torch.Tensor: |
| t = torch.tensor(series, dtype=torch.float32) |
| return t.to(device=device, dtype=dtype) |
|
|
| |
| |
| |
| def _naive_forecast(x: torch.Tensor, horizon: int) -> torch.Tensor: |
| """ |
| Very simple fallback: repeat the last observed value for H steps. |
| Ensures the backend returns a forecast even without TimesFM installed. |
| """ |
| last = x[-1] if x.numel() > 0 else torch.tensor(0.0, device=x.device, dtype=x.dtype) |
| return last.repeat(horizon).to(dtype=x.dtype, device=x.device) |
|
|
| |
| |
| |
class TimesFMBackend:
    """
    Minimal forecasting backend around TimesFM with a naive last-value
    fallback when the model (or its loader) is unavailable.

    Input request (dict) shape:

    {
        "series": [float, ...],    # required
        "horizon": 48,             # optional (default 24)
        "freq": "H",               # optional (default "H")
        "normalize": true,         # optional
        "model_id": "google/...",  # optional override
        "use_gpu": true/false      # optional
    }

    Output (dict):
    {
        "id": "tsfcst-...",
        "object": "timeseries.forecast",
        "created": 1234567890,
        "model": "<model_id>",
        "horizon": H,
        "freq": "H",
        "forecast": [float, ...],
        "backend": "timesfm",
        "note": "fallback-naive"   # only when naive path used
    }
    """

    def __init__(self) -> None:
        # Loaded model object, or None -> naive fallback path.
        self._model = None
        self._model_id = MODEL_ID
        self._device = _pick_device()
        self._dtype = _pick_dtype(self._device)
        logger.info(
            "[timesfm] init: model_id=%s device=%s dtype=%s",
            self._model_id, self._device, self._dtype,
        )

    def _set_device(self, device: str) -> None:
        """
        Switch the compute device, moving an already-loaded model along.

        Fixes a stale-device bug: previously a per-request `use_gpu` toggle
        rewrote `self._device` without moving a cached model, so input
        tensors and model weights could land on different devices and the
        model path would crash.
        """
        if device == self._device:
            return
        self._device = device
        self._dtype = _pick_dtype(device)
        if self._model is not None:
            try:
                self._model.to(device)
            except Exception:
                # Best-effort: not every loader path exposes a .to() method.
                logger.warning("[timesfm] could not move model to %s", device)

    def _ensure_model(self, model_id: Optional[str] = None) -> None:
        """
        Lazily load (or reload) the model for *model_id*.

        Any failure leaves `self._model` as None, which routes requests to
        the naive fallback instead of raising.
        """
        # Fast path: already loaded and no different id requested.
        if self._model is not None and (not model_id or model_id == self._model_id):
            return

        want_id = model_id or self._model_id
        # Record the requested id even if loading fails, so the response's
        # "model" field reflects what was asked for.
        self._model_id = want_id

        if tsm is None:
            logger.warning("[timesfm] timesfm package not available; using naive fallback")
            self._model = None
            return

        model = None
        try:
            if hasattr(tsm, "TimesFM") and hasattr(tsm.TimesFM, "from_pretrained"):
                logger.info("[timesfm] loading via TimesFM.from_pretrained('%s')", want_id)
                model = tsm.TimesFM.from_pretrained(want_id)
            else:
                if snapshot_download is None:
                    raise RuntimeError("huggingface_hub not installed; cannot pull weights")
                logger.info("[timesfm] snapshot_download('%s')", want_id)
                local_dir = snapshot_download(repo_id=want_id)
                # Weights are cached locally, but this timesfm version exposes
                # no loader API we recognize -- stay on the naive path.
                logger.warning(
                    "[timesfm] no direct loader available; using naive fallback. weights at %s",
                    local_dir,
                )
                model = None
        except Exception as e:
            logger.warning(
                "[timesfm] failed to load model '%s': %s. Falling back to naive.", want_id, e
            )
            model = None

        self._model = model
        if model is not None:
            try:
                self._model.to(self._device)
            except Exception:
                pass
            logger.info("[timesfm] model ready on %s", self._device)

    async def forecast(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """
        Produce a single, non-streaming forecast (async to match the other backends).

        Raises:
            ValueError: if "series" is not a list of numbers, or "horizon"
                resolves to a non-positive integer.
        """
        model_id = request.get("model") or request.get("model_id") or self._model_id
        series = request.get("series")
        # NOTE: falsy horizons (0, None, "") intentionally fall back to the default.
        horizon = int(request.get("horizon") or DEFAULT_HORIZON)
        freq = request.get("freq") or DEFAULT_FREQ
        normalize = bool(request.get("normalize") or False)

        use_gpu = request.get("use_gpu")
        if use_gpu is not None:
            self._set_device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

        if not isinstance(series, (list, tuple)) or not all(
            isinstance(v, (int, float)) for v in series
        ):
            raise ValueError("request['series'] must be a list of numbers")
        if horizon < 1:
            # Previously a negative horizon surfaced as an opaque torch error
            # deep inside tensor.repeat(); fail loudly at the boundary instead.
            raise ValueError("request['horizon'] must be a positive integer")

        self._ensure_model(model_id)

        x = _as_1d_float_tensor(list(series), self._device, self._dtype)

        # Optional z-score normalization; keep (mu, sigma) to invert afterwards.
        mu: Optional[torch.Tensor] = None
        sigma: Optional[torch.Tensor] = None
        if normalize and x.numel() > 1:
            mu = x.mean()
            # clamp avoids division by ~0 on a flat series.
            sigma = x.std(unbiased=False).clamp_min(1e-6)
            x_norm = (x - mu) / sigma
        else:
            x_norm = x

        note = None
        if self._model is None:
            y_hat = _naive_forecast(x_norm, horizon)
            note = "fallback-naive"
        else:
            try:
                if hasattr(self._model, "forecast"):
                    # NOTE(review): assumes model.forecast accepts a (1, T) batch
                    # and a `horizon` kwarg, returning a tensor or sequence --
                    # confirm against the installed timesfm version's API.
                    y_hat = self._model.forecast(x_norm.unsqueeze(0), horizon=horizon)
                    if isinstance(y_hat, (list, tuple)):
                        y_hat = torch.tensor(y_hat, device=x_norm.device, dtype=x_norm.dtype)
                    if isinstance(y_hat, torch.Tensor) and y_hat.dim() == 2:
                        y_hat = y_hat[0]  # drop the batch dimension
                    elif not isinstance(y_hat, torch.Tensor):
                        # e.g. a numpy array -- coerce onto our device/dtype.
                        y_hat = torch.tensor(y_hat, device=x_norm.device, dtype=x_norm.dtype)
                else:
                    y_hat = _naive_forecast(x_norm, horizon)
                    note = "fallback-naive"
            except Exception as e:
                logger.warning(
                    "[timesfm] forecast failed on model path: %s. Using naive fallback.", e
                )
                y_hat = _naive_forecast(x_norm, horizon)
                note = "fallback-naive"

        # Undo normalization so the forecast is reported in the original scale.
        if normalize and mu is not None and sigma is not None:
            y_hat = y_hat * sigma + mu

        forecast = y_hat.detach().float().cpu().tolist()

        now = int(time.time())
        resp = {
            "id": f"tsfcst-{now}",
            "object": "timeseries.forecast",
            "created": now,
            "model": self._model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": forecast,
            "backend": "timesfm",
        }
        if note:
            resp["note"] = note
        return resp
|
|