GradLLM / timesfm_backend.py
Mungert's picture
Update timesfm_backend.py
c7f8c69 verified
raw
history blame
9.17 kB
import time
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Sequence
import numpy as np
import torch
from backends_base import ChatBackend, ImagesBackend
from config import settings
logger = logging.getLogger(__name__)
# ---------- helpers ----------
def _parse_series(series: Any) -> np.ndarray:
"""
Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'.
Returns: 1D float32 numpy array.
"""
if series is None:
raise ValueError("series is required")
if isinstance(series, dict):
series = series.get("values") or series.get("y")
vals: List[float] = []
if isinstance(series, (list, tuple)):
if series and isinstance(series[0], dict):
for item in series:
if "y" in item:
vals.append(float(item["y"]))
elif "value" in item:
vals.append(float(item["value"]))
else:
vals = [float(x) for x in series]
else:
raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
if not vals:
raise ValueError("series is empty")
return np.asarray(vals, dtype=np.float32)
def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
s = s.strip()
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
try:
obj = json.loads(s)
return obj if isinstance(obj, dict) else None
except Exception:
pass
if "```" in s:
parts = s.split("```")
for i in range(1, len(parts), 2):
block = parts[i]
if block.lstrip().lower().startswith("json"):
block = block.split("\n", 1)[-1]
try:
obj = json.loads(block.strip())
return obj if isinstance(obj, dict) else None
except Exception:
continue
return None
def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
msgs = payload.get("messages")
if not isinstance(msgs, list):
return payload
for m in reversed(msgs):
if not isinstance(m, dict) or m.get("role") != "user":
continue
content = m.get("content")
texts: List[str] = []
if isinstance(content, list):
texts = [
p.get("text")
for p in content
if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)
]
elif isinstance(content, str):
texts = [content]
for t in reversed(texts):
obj = _extract_json_from_text(t)
if isinstance(obj, dict):
return {**payload, **obj}
break
return payload
# ---------- backend ----------
class TimesFMBackend(ChatBackend):
    """
    TimesFM 2.5 backend.
    Input JSON can be in top-level keys, in CloudEvents .data, or embedded in last user message.
    Keys:
    series: one series, or a list of series for batch. A series is
            list[float|int|{y|value}] or a dict wrapping such a list under
            'values'/'y'.
    horizon: int (>0)
    Optional:
    quantiles: bool (default True) -> include quantile forecasts
    max_context, max_horizon: ints to override defaults
    """

    # Limits the model is compiled with unless a request overrides them.
    DEFAULT_MAX_CONTEXT = 1024
    DEFAULT_MAX_HORIZON = 256

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        """Record the model id and target device; the model itself is loaded lazily."""
        # HF id for bookkeeping only
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model = None  # lazy
        # (max_context, max_horizon, use_continuous_quantile_head) of the last
        # successful compile; None until the model has been compiled.
        self._compiled_key: Optional[Tuple[int, int, bool]] = None

    @staticmethod
    def _is_series_container(item: Any) -> bool:
        """Return True when item is itself a whole series rather than one data point."""
        if isinstance(item, (list, tuple)):
            return True
        if isinstance(item, dict):
            # {"values": [...]} / {"y": [...]} wraps a series; {"y": 1.0} is a point.
            return isinstance(item.get("values") or item.get("y"), (list, tuple))
        return False

    def _compile_model(self, model: Any, max_context: int, max_horizon: int, quantile_head: bool) -> None:
        """Compile model with the given limits and remember what was compiled."""
        import timesfm  # 2.5 API

        cfg = timesfm.ForecastConfig(
            max_context=max_context,
            max_horizon=max_horizon,
            normalize_inputs=True,
            use_continuous_quantile_head=quantile_head,
            force_flip_invariance=True,
            infer_is_positive=True,
            fix_quantile_crossing=True,
        )
        model.compile(cfg)
        # Only recorded after compile() returned, so a failed compile is retried.
        self._compiled_key = (max_context, max_horizon, quantile_head)

    def _ensure_model(self) -> None:
        """
        Load and compile the model on first use; later calls are no-ops.

        Raises:
          RuntimeError: if importing, downloading or compiling the model fails.
        """
        if self._model is not None:
            return
        try:
            import timesfm  # 2.5 API

            hf_token = getattr(settings, "HF_TOKEN", None) or os.environ.get("HF_TOKEN")
            cache_dir = getattr(settings, "TIMESFM_CACHE_DIR", None)
            model = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
                self.model_id,
                token=hf_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
            try:
                # .model holds the underlying nn.Module; fall back to instance if absent.
                target = getattr(model, "model", model)
                target.to(self.device)  # type: ignore[arg-type]
            except Exception:
                # Best effort: keep going on whatever device the model is on,
                # but leave a trace instead of failing silently.
                logger.warning(
                    "Could not move TimesFM model to %s; leaving it on its current device",
                    self.device,
                    exc_info=True,
                )
            self._compile_model(model, self.DEFAULT_MAX_CONTEXT, self.DEFAULT_MAX_HORIZON, True)
            self._model = model
            logger.info("TimesFM 2.5 model loaded on %s", self.device)
        except Exception as e:
            logger.exception("TimesFM 2.5 init failed")
            raise RuntimeError(f"timesfm 2.5 init failed: {e}") from e

    def _prepare_inputs(self, payload: Dict[str, Any]) -> Tuple[List[np.ndarray], int, bool, Dict[str, int]]:
        """
        Normalise a request payload into model inputs.

        Returns:
          (inputs, horizon, quantiles, overrides) where inputs is a list of 1D
          float32 arrays (one per series), and overrides holds 'max_context'
          and 'max_horizon'.

        Raises:
          ValueError: if horizon is not a positive integer or a series is
            missing, empty or malformed.
        """
        # unwrap CloudEvents and nested keys
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}
        # merge JSON in last user message
        payload = _merge_openai_message_json(payload)

        horizon = int(payload.get("horizon", 0))
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")
        quantiles = bool(payload.get("quantiles", True))
        mc = int(payload.get("max_context", self.DEFAULT_MAX_CONTEXT))
        mh = int(payload.get("max_horizon", self.DEFAULT_MAX_HORIZON))

        series = payload.get("series")
        inputs: List[np.ndarray]
        # A list of point dicts ([{"y": 1}, {"y": 2}]) is ONE series, not a
        # batch, so the first element must itself be a series to count as batch.
        if isinstance(series, (list, tuple)) and series and self._is_series_container(series[0]):
            # batch input
            inputs = [_parse_series(s) for s in series]
        else:
            # single series -> batch of 1
            inputs = [_parse_series(series)]
        return inputs, horizon, quantiles, {"max_context": mc, "max_horizon": mh}

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run a forecast for the series described by payload.

        Returns:
          Dict with 'model', 'horizon', 'forecast', 'quantiles' and 'backend'.
          For a single input series the batch dimension is removed from
          'forecast' and 'quantiles'. 'quantiles' is None when not requested
          or when the model returned none.

        Raises:
          ValueError: for invalid input (see _prepare_inputs).
          RuntimeError: if model loading or the forecast call fails.
        """
        inputs, horizon, want_quantiles, cfg_overrides = self._prepare_inputs(payload)
        self._ensure_model()

        max_context = cfg_overrides["max_context"]
        max_horizon = cfg_overrides["max_horizon"]
        uses_defaults = (max_context, max_horizon) == (self.DEFAULT_MAX_CONTEXT, self.DEFAULT_MAX_HORIZON)
        # Default limits always keep the quantile head on (as at load time);
        # overridden limits enable it only when quantiles are requested.
        quantile_head = True if uses_defaults else want_quantiles
        wanted_key = (max_context, max_horizon, quantile_head)
        # Recompile only when the wanted config differs from what is compiled.
        # This also restores the defaults after a request that overrode them,
        # so one request's limits do not leak into the next.
        if wanted_key != getattr(self, "_compiled_key", None):
            try:
                self._compile_model(self._model, *wanted_key)
            except Exception:
                # Best effort: fall back to whatever config is currently compiled.
                logger.warning(
                    "TimesFM recompile failed for max_context=%s max_horizon=%s; "
                    "keeping previously compiled config",
                    max_context,
                    max_horizon,
                    exc_info=True,
                )

        try:
            # NOTE(review): this call is synchronous and appears to run inside
            # the event loop — confirm whether callers expect it to be offloaded.
            point, quant = self._model.forecast(horizon=horizon, inputs=inputs)
            point_list = [row.astype(float).tolist() for row in point]  # shape (B, H)
            quant_list = None
            if want_quantiles and quant is not None:
                # shape (B, H, 10): mean, q10..q90
                quant_list = [[step.astype(float).tolist() for step in row] for row in quant]
        except Exception as e:
            logger.exception("TimesFM 2.5 forecast failed")
            raise RuntimeError(f"forecast failed: {e}") from e

        # If single-series input, unwrap batch dim for convenience
        single = len(inputs) == 1
        quantiles_out = None
        if quant_list is not None:
            # Guarded: the model may return no quantiles even when requested.
            quantiles_out = quant_list[0] if single else quant_list
        return {
            "model": self.model_id,
            "horizon": horizon,
            "forecast": point_list[0] if single else point_list,
            "quantiles": quantiles_out,
            "backend": "timesfm-2.5",
        }

    async def stream(self, request: Dict[str, Any]):
        """
        Yield one chat-completion chunk whose content is the forecast as JSON.

        Any exception raised while forecasting is reported as {"error": "..."}
        in the chunk content instead of propagating.
        """
        # One timestamp for both fields so the id and 'created' always agree.
        now = int(time.time())
        rid = f"chatcmpl-timesfm-{now}"
        try:
            result = await self.forecast(dict(request) if isinstance(request, dict) else {})
            content = json.dumps(result, separators=(",", ":"), ensure_ascii=False)
        except Exception as e:
            content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": self.model_id,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
            ],
        }
class StubImagesBackend(ImagesBackend):
    """Images backend stub: image generation is not available here, so a fixed placeholder is returned."""

    # Base64 payload handed back for every request (PNG signature; presumably a 1x1 image — TODO confirm).
    _PLACEHOLDER_B64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
    )

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Log that images are unsupported and return the placeholder; request is ignored."""
        logger.warning("Image generation not supported in TimesFM backend.")
        return self._PLACEHOLDER_B64