time-series-api / service.py
Taylor1998's picture
Upload 2 files
ea1a58e verified
from __future__ import annotations
import math
import os
from statistics import mean
from typing import Any
from schemas import HealthResponse, PredictRequest, PredictResponse, PredictionItem
class TimesFmService:
    """HF Space service wrapper for TimesFM.

    By default this service attempts real HuggingFace CPU inference. If model
    loading fails and `TIMESFM_ALLOW_BASELINE_FALLBACK` is truthy
    (``true``/``1``/``yes``/``on``), it falls back to the deterministic
    baseline implementation instead of staying in a not-ready state.

    All configuration is read from environment variables once, in ``__init__``.
    """

    # Lower clamp applied to every emitted price so predictions stay positive.
    _MIN_PRICE = 0.00000001
    # Environment values treated as "enabled" for boolean flags.
    _TRUTHY = frozenset({"1", "true", "yes", "on"})

    def __init__(self) -> None:
        self.model_id = "timesfm"
        self.model_name = os.getenv(
            "TIMESFM_MODEL_NAME",
            "google/timesfm-2.5-200m-transformers",
        )
        # Normalise whitespace/case so e.g. "HF_CPU" selects the same backend.
        self.backend = os.getenv("TIMESFM_BACKEND", "hf_cpu").strip().lower() or "hf_cpu"
        self.device = "cpu"
        self.runtime_revision = os.getenv("TIMESFM_RUNTIME_REVISION", "timesfm-hf-patch-align-v1")
        self.max_context_length = int(os.getenv("TIMESFM_MAX_CONTEXT_LENGTH", "512"))
        self.max_horizon_step = int(os.getenv("TIMESFM_MAX_HORIZON_STEP", "288"))
        self.patch_length = int(os.getenv("TIMESFM_PATCH_LENGTH", "32"))
        self.confidence_floor = float(os.getenv("TIMESFM_CONFIDENCE_FLOOR", "0.20"))
        self.confidence_ceiling = float(os.getenv("TIMESFM_CONFIDENCE_CEILING", "0.85"))
        self.min_required_points = int(os.getenv("TIMESFM_MIN_REQUIRED_POINTS", "32"))
        self.allow_baseline_fallback = (
            os.getenv("TIMESFM_ALLOW_BASELINE_FALLBACK", "false").strip().lower()
            in self._TRUTHY
        )
        self.ready = False
        self.load_error = ""
        self._torch = None
        self._model = None
        self._initialize_backend()

    def health(self) -> HealthResponse:
        """Return a health snapshot; ``status`` is "ok" only when the backend is ready."""
        return HealthResponse(
            status="ok" if self.ready else "degraded",
            model=self.model_name,
            model_id=self.model_id,
            backend=self.backend,
            device=self.device,
            ready=self.ready,
            max_context_length=self.max_context_length,
            max_horizon_step=self.max_horizon_step,
            patch_length=self.patch_length,
            runtime_revision=self.runtime_revision,
        )

    def predict(self, payload: PredictRequest) -> PredictResponse:
        """Forecast ``payload.horizons`` from the last ``context_length`` close prices.

        Raises:
            ValueError: if the request fails validation.
            RuntimeError: if the HF backend is selected but not ready, or if the
                context/model output cannot serve the request.
        """
        self._validate_request(payload)
        # Validation guarantees context_length >= 1, so this slice can never
        # degenerate into ``[-0:]`` (which would return the whole history).
        closes = payload.close_prices[-payload.context_length :]
        if self.backend == "hf_cpu":
            if not self.ready:
                raise RuntimeError(self.load_error or "timesfm backend not ready")
            predictions = self._predict_with_hf(closes, payload.horizons)
        else:
            predictions = self._predict_with_baseline(closes, payload.horizons)
        return PredictResponse(model_id=self.model_id, predictions=predictions)

    def _initialize_backend(self) -> None:
        """Prepare the configured backend and set ``ready`` / ``load_error``.

        Raises:
            ValueError: for a backend other than ``baseline_cpu`` or ``hf_cpu``.

        HF load failures are caught: with fallback enabled the service switches
        to ``baseline_cpu`` (``load_error`` is kept for diagnostics); otherwise
        it stays not ready and ``predict`` raises with ``load_error``.
        """
        if self.backend == "baseline_cpu":
            self.ready = True
            return
        if self.backend != "hf_cpu":
            raise ValueError(f"unsupported TIMESFM_BACKEND={self.backend}")
        try:
            self._load_hf_model()
        except Exception as exc:  # deliberate boundary: any load failure degrades
            self.load_error = f"timesfm hf load failed: {exc}"
            if self.allow_baseline_fallback:
                self.backend = "baseline_cpu"
                self.ready = True
            else:
                self.ready = False
        else:
            self.ready = True

    def _load_hf_model(self) -> None:
        """Lazily import torch/transformers and load the checkpoint on CPU in eval mode.

        The torch thread count comes from ``TIMESFM_TORCH_THREADS`` (default 2,
        never below 1).
        """
        import torch
        from transformers import TimesFm2_5ModelForPrediction

        self._torch = torch
        torch.set_num_threads(max(1, int(os.getenv("TIMESFM_TORCH_THREADS", "2"))))
        self._model = TimesFm2_5ModelForPrediction.from_pretrained(
            self.model_name,
            torch_dtype=torch.float32,
        )
        self._model.to("cpu")
        self._model.eval()

    def _predict_with_hf(
        self, close_prices: list[float], horizons: list[int]
    ) -> list[PredictionItem]:
        """Run TimesFM on the patch-aligned context and pick the requested steps.

        ``horizons`` are 1-based step offsets (validated >= 1 by
        ``_validate_request``). An empty ``horizons`` list yields ``[]``,
        matching the baseline backend.

        Raises:
            RuntimeError: if the model is not loaded, the context is too short,
                or the model output is shorter than the furthest requested step.
        """
        if not horizons:
            # Mirrors the baseline; also avoids ``max([])`` raising below.
            return []
        torch, model = self._torch, self._model
        if torch is None or model is None:
            # Explicit check (not ``assert``) so it still fires under ``python -O``.
            raise RuntimeError(self.load_error or "timesfm backend not ready")

        context = self._aligned_hf_context(close_prices)
        past_values = [torch.tensor(context, dtype=torch.float32)]
        freq = torch.tensor([0], dtype=torch.long)
        with torch.inference_mode():
            outputs = model(
                past_values=past_values,
                freq=freq,
                forecast_context_len=len(context),
                return_forecast_on_context=False,
            )
        dense_mean = outputs.mean_predictions[0].detach().cpu().tolist()
        furthest = max(horizons)
        if len(dense_mean) < furthest:
            raise RuntimeError(
                f"TimesFM output horizon {len(dense_mean)} is shorter than requested {furthest}"
            )
        dense_conf = self._timesfm_confidence(outputs, dense_mean)
        return [
            PredictionItem(
                step=step,
                pred_price=round(max(self._MIN_PRICE, float(dense_mean[step - 1])), 8),
                pred_confidence=round(dense_conf[step - 1], 4),
            )
            for step in horizons
        ]

    def _aligned_hf_context(self, close_prices: list[float]) -> list[float]:
        """Return the most recent prices, truncated to a whole number of patches.

        At most ``max_context_length`` points are kept; when ``patch_length`` > 0
        the oldest points are dropped so the length is a multiple of it.

        Raises:
            RuntimeError: if the aligned length is below ``min_required_points``
                or is zero.
        """
        # Index arithmetic instead of ``[-n:]``: a zero n must give an empty
        # slice, not the whole list.
        start = max(0, len(close_prices) - self.max_context_length)
        context = close_prices[start:]
        usable_length = len(context)
        if self.patch_length > 0:
            usable_length = (usable_length // self.patch_length) * self.patch_length
        if usable_length < self.min_required_points:
            raise RuntimeError(
                f"TimesFM usable context length {usable_length} is below "
                f"TIMESFM_MIN_REQUIRED_POINTS={self.min_required_points}"
            )
        if usable_length == 0:
            raise RuntimeError(
                f"TimesFM usable context is empty after aligning to "
                f"TIMESFM_PATCH_LENGTH={self.patch_length}"
            )
        return context[len(context) - usable_length :]

    def _timesfm_confidence(self, outputs: Any, dense_mean: list[float]) -> list[float]:
        """Map the forecast's quantile spread to a per-step confidence.

        Each step's confidence is ``1 / (1 + band / |mean|)``, clamped to
        ``[confidence_floor, confidence_ceiling]``. The floor is used wherever
        no usable quantile data exists for a step.

        Assumes ``full_predictions`` follows the HF TimesFM layout: channel 0
        holds the mean and the following channels hold ascending quantiles
        (0.1 ... 0.9). NOTE(review): confirm this layout for TimesFm2_5.
        """
        floor, ceiling = self.confidence_floor, self.confidence_ceiling
        full_predictions = getattr(outputs, "full_predictions", None)
        if full_predictions is None:
            return [floor] * len(dense_mean)
        quantiles = full_predictions[0].detach().cpu()
        if quantiles.ndim != 2:
            return [floor] * len(dense_mean)
        rows = quantiles.tolist()

        confidence: list[float] = []
        for idx, pred in enumerate(dense_mean):
            row = rows[idx] if idx < len(rows) else None
            if not row:
                confidence.append(floor)
                continue
            # Skip the mean in channel 0 when quantile channels follow it, so
            # the band spans lowest..highest quantile rather than mean..highest.
            lower = float(row[1] if len(row) >= 3 else row[0])
            upper = float(row[-1])
            band = abs(upper - lower) / max(abs(pred), 1e-6)
            raw = 1.0 / (1.0 + band)
            confidence.append(max(floor, min(ceiling, raw)))
        return confidence

    def _validate_request(self, payload: PredictRequest) -> None:
        """Reject requests the service cannot answer.

        Raises:
            ValueError: on a non-positive or oversized ``context_length``, too few
                close prices, or horizon steps outside ``1..max_horizon_step``.
        """
        if payload.context_length < 1:
            raise ValueError("context_length must be at least 1")
        if payload.context_length > self.max_context_length:
            raise ValueError(
                f"context_length {payload.context_length} exceeds "
                f"TIMESFM_MAX_CONTEXT_LENGTH={self.max_context_length}"
            )
        if payload.context_length > len(payload.close_prices):
            raise ValueError("context_length must not exceed len(close_prices)")
        if len(payload.close_prices) < self.min_required_points:
            raise ValueError(
                f"at least {self.min_required_points} close prices are required "
                "for TimesFM stability"
            )
        # Steps are 1-based offsets; 0 or negatives would index from the end
        # of the forecast (HF) or hit log(<=0) (baseline).
        if any(step < 1 for step in payload.horizons):
            raise ValueError("horizons must contain only positive steps (>= 1)")
        if any(step > self.max_horizon_step for step in payload.horizons):
            raise ValueError(
                f"horizons contain values above TIMESFM_MAX_HORIZON_STEP={self.max_horizon_step}"
            )

    def _predict_with_baseline(
        self, close_prices: list[float], horizons: list[int]
    ) -> list[PredictionItem]:
        """Deterministic momentum/regime heuristic used when TimesFM is unavailable.

        It blends two signals. Momentum (last price vs the mean of the last 8)
        is weighted 0.55. Regime bias (mean of the last 8 vs the mean of the
        last 32) is weighted 0.45. The blend is damped per step by
        ``min(1, log(step + 1) / 3.5)``.
        """
        last_price = close_prices[-1]
        short_mean = mean(close_prices[-min(8, len(close_prices)) :])
        long_mean = mean(close_prices[-min(32, len(close_prices)) :])
        momentum = 0.0 if short_mean == 0 else (last_price - short_mean) / short_mean
        regime_bias = 0.0 if long_mean == 0 else (short_mean - long_mean) / long_mean
        # Step-independent part of the expected return, hoisted out of the loop.
        base_return = momentum * 0.55 + regime_bias * 0.45

        predictions: list[PredictionItem] = []
        for step in horizons:
            expected_return = base_return * min(1.0, math.log(step + 1.0) / 3.5)
            pred_price = max(self._MIN_PRICE, last_price * (1.0 + expected_return))
            confidence = self._baseline_confidence(close_prices, step, abs(expected_return))
            predictions.append(
                PredictionItem(
                    step=step,
                    pred_price=round(pred_price, 8),
                    pred_confidence=round(confidence, 4),
                )
            )
        return predictions

    def _baseline_confidence(
        self, close_prices: list[float], step: int, expected_move_abs: float
    ) -> float:
        """Score the baseline prediction by signal-to-noise and horizon length.

        Noise is the mean absolute relative change over the last 32 moves.
        Moves from a non-positive price are skipped. The result is clamped to
        ``[confidence_floor, confidence_ceiling]``. With fewer than 3 prices the
        floor is returned.
        """
        if len(close_prices) < 3:
            return self.confidence_floor
        changes = [
            abs((current - previous) / previous)
            for previous, current in zip(close_prices[:-1], close_prices[1:])
            if previous > 0
        ]
        realized_vol = mean(changes[-min(32, len(changes)) :]) if changes else 0.0
        signal_to_noise = expected_move_abs / (realized_vol + 1e-9)
        horizon_decay = 1.0 / (1.0 + math.log(step + 1.0))
        raw = 0.25 + min(signal_to_noise, 2.0) * 0.25 + horizon_decay * 0.35
        return max(self.confidence_floor, min(self.confidence_ceiling, raw))

    def describe_runtime(self) -> dict[str, Any]:
        """Return a plain-dict summary of the effective runtime configuration."""
        return {
            "model_id": self.model_id,
            "model_name": self.model_name,
            "backend": self.backend,
            "device": self.device,
            "ready": self.ready,
            "load_error": self.load_error,
            "max_context_length": self.max_context_length,
            "max_horizon_step": self.max_horizon_step,
            "min_required_points": self.min_required_points,
            "patch_length": self.patch_length,
            "runtime_revision": self.runtime_revision,
        }