# timesfm_backend.py
import time, logging
from typing import Any, Dict, List, Optional, Tuple
import torch
try:
    # If you install an official TimesFM package later, we’ll try to use it.
    # (e.g., `pip install timesfm` if/when available)
    import timesfm as tsm  # type: ignore
except Exception:
    tsm = None  # graceful fallback

try:
    # Optional: pull weights from HF if you want local inference
    # pip install huggingface_hub
    from huggingface_hub import snapshot_download
except Exception:
    snapshot_download = None  # optional
from backends_base import ImagesBackend # to mirror structure; not used here
from config import settings
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------------------
# Config
# --------------------------------------------------------------------------------------
MODEL_ID = getattr(settings, "LlmHFModelID", None) or "google/timesfm-2.5-200m-pytorch"
DEFAULT_HORIZON = 24 # sensible default if caller omits
DEFAULT_FREQ = "H" # hour
ALLOW_GPU = True
# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def _pick_device() -> str:
    if ALLOW_GPU and torch.cuda.is_available():
        return "cuda"
    return "cpu"


def _pick_dtype(device: str) -> torch.dtype:
    # FP16 on CUDA, FP32 on CPU by default (safe and simple)
    if device != "cpu":
        return torch.float16
    return torch.float32


def _as_1d_float_tensor(series: List[float], device: str, dtype: torch.dtype) -> torch.Tensor:
    t = torch.tensor(series, dtype=torch.float32)  # keep input parse stable
    return t.to(device=device, dtype=dtype)
# --------------------------------------------------------------------------------------
# Fallback forecaster (naive)
# --------------------------------------------------------------------------------------
def _naive_forecast(x: torch.Tensor, horizon: int) -> torch.Tensor:
    """
    Very simple fallback: repeat the last observed value for H steps.
    Ensures the backend returns a forecast even without TimesFM installed.
    """
    last = x[-1] if x.numel() > 0 else torch.tensor(0.0, device=x.device, dtype=x.dtype)
    return last.repeat(horizon).to(dtype=x.dtype, device=x.device)
# --------------------------------------------------------------------------------------
# Backend
# --------------------------------------------------------------------------------------
class TimesFMBackend:
    """
    Minimal forecasting backend. Input request (dict) shape:
      {
        "series": [float, ...],    # required
        "horizon": 48,             # optional (default 24)
        "freq": "H",               # optional (default "H")
        "normalize": true,         # optional
        "model_id": "google/...",  # optional override
        "use_gpu": true/false      # optional
      }
    Output (dict):
      {
        "id": "tsfcst-...",
        "object": "timeseries.forecast",
        "created": 1234567890,
        "model": "<model_id>",
        "horizon": H,
        "freq": "H",
        "forecast": [float, ...],
        "backend": "timesfm",
        "note": "fallback-naive"   # only when naive path used
      }
    """

    def __init__(self) -> None:
        self._model = None
        self._model_id = MODEL_ID
        self._device = _pick_device()
        self._dtype = _pick_dtype(self._device)
        logger.info(f"[timesfm] init: model_id={self._model_id} device={self._device} dtype={self._dtype}")
    # ---------- model load (best-effort) ----------
    def _ensure_model(self, model_id: Optional[str] = None) -> None:
        if self._model is not None and (not model_id or model_id == self._model_id):
            return
        want_id = model_id or self._model_id
        self._model_id = want_id
        if tsm is None:
            logger.warning("[timesfm] timesfm package not available; using naive fallback")
            self._model = None
            return
        # If the library provides a from_pretrained, use it; else attempt an HF snapshot and custom load.
        model = None
        try:
            if hasattr(tsm, "TimesFM") and hasattr(tsm.TimesFM, "from_pretrained"):
                logger.info(f"[timesfm] loading via TimesFM.from_pretrained('{want_id}')")
                model = tsm.TimesFM.from_pretrained(want_id)  # type: ignore[attr-defined]
            else:
                # Manual path: download the weights and let the user wire up loading code for their saved format.
                if snapshot_download is None:
                    raise RuntimeError("huggingface_hub not installed; cannot pull weights")
                logger.info(f"[timesfm] snapshot_download('{want_id}')")
                local_dir = snapshot_download(repo_id=want_id)
                # TODO: Replace with an actual loader for the repo's weight format if needed.
                # Placeholder: the snapshot is downloaded, but no loader is wired up yet, so we fall back to naive.
                logger.warning(f"[timesfm] no direct loader available; using naive fallback. weights at {local_dir}")
                model = None
        except Exception as e:
            logger.warning(f"[timesfm] failed to load model '{want_id}': {e}. Falling back to naive.")
            model = None
        self._model = model
        if model is not None:
            try:
                self._model.to(self._device)  # type: ignore[operator]
            except Exception:
                pass
            logger.info("[timesfm] model ready on %s", self._device)
    # ---------- public API ----------
    async def forecast(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """
        Async to match your other backends. Returns a single, non-streaming result dict.
        """
        # parse inputs
        model_id = request.get("model") or request.get("model_id") or self._model_id
        series = request.get("series")
        horizon = int(request.get("horizon") or DEFAULT_HORIZON)
        freq = request.get("freq") or DEFAULT_FREQ
        normalize = bool(request.get("normalize") or False)
        use_gpu = request.get("use_gpu")
        if use_gpu is not None:
            self._device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
            self._dtype = _pick_dtype(self._device)
        if not isinstance(series, (list, tuple)) or not all(isinstance(v, (int, float)) for v in series):
            raise ValueError("request['series'] must be a list of numbers")
        # ensure model (or fallback)
        self._ensure_model(model_id)
        # tensorize
        x = _as_1d_float_tensor(list(series), self._device, self._dtype)
        # optional normalization (z-score)
        mu: Optional[torch.Tensor] = None
        sigma: Optional[torch.Tensor] = None
        if normalize and x.numel() > 1:
            mu = x.mean()
            sigma = x.std(unbiased=False).clamp_min(1e-6)
            x_norm = (x - mu) / sigma
        else:
            x_norm = x
        # run forecast
        note = None
        if self._model is None:
            y_hat = _naive_forecast(x_norm, horizon)
            note = "fallback-naive"
        else:
            try:
                # Preferred path if the library supports it:
                if hasattr(self._model, "forecast"):
                    y_hat = self._model.forecast(x_norm.unsqueeze(0), horizon=horizon)  # type: ignore[attr-defined]
                    # Shape handling: [B, H] -> 1D
                    if isinstance(y_hat, (list, tuple)):
                        y_hat = torch.tensor(y_hat, device=x_norm.device, dtype=x_norm.dtype)
                    if isinstance(y_hat, torch.Tensor) and y_hat.dim() == 2:
                        y_hat = y_hat[0]
                    elif not isinstance(y_hat, torch.Tensor):
                        y_hat = torch.tensor(y_hat, device=x_norm.device, dtype=x_norm.dtype)
                else:
                    # If no forecast method, fall back
                    y_hat = _naive_forecast(x_norm, horizon)
                    note = "fallback-naive"
            except Exception as e:
                logger.warning(f"[timesfm] forecast failed on model path: {e}. Using naive fallback.")
                y_hat = _naive_forecast(x_norm, horizon)
                note = "fallback-naive"
        # denormalize
        if normalize and mu is not None and sigma is not None:
            y_hat = y_hat * sigma + mu
        # move to cpu list
        forecast = y_hat.detach().float().cpu().tolist()
        rid = f"tsfcst-{int(time.time())}"
        now = int(time.time())
        resp = {
            "id": rid,
            "object": "timeseries.forecast",
            "created": now,
            "model": self._model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": forecast,
            "backend": "timesfm",
        }
        if note:
            resp["note"] = note
        return resp
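

# --------------------------------------------------------------------------------------
# Example usage (illustrative sketch)
# --------------------------------------------------------------------------------------
# A minimal, hedged sketch of how this backend might be exercised directly. It assumes
# the module imports cleanly stand-alone (i.e. `config.settings` and `backends_base`
# resolve); the sample series, horizon, and frequency below are made-up values for
# illustration only, not part of the original module.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = TimesFMBackend()
        request = {
            "series": [10.0, 11.2, 12.1, 11.8, 12.5, 13.0],  # toy hourly observations
            "horizon": 4,
            "freq": "H",
            "normalize": True,
        }
        result = await backend.forecast(request)
        # Without the TimesFM package installed this prints the naive (last-value) forecast.
        print(result["forecast"], result.get("note"))

    asyncio.run(_demo())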