from __future__ import annotations

import math
import os
from statistics import mean
from typing import Any

from schemas import HealthResponse, PredictRequest, PredictResponse, PredictionItem


class TimesFmService:
    """HF Space service wrapper for TimesFM.

    By default this service attempts real HuggingFace CPU inference. If model
    loading fails and `TIMESFM_ALLOW_BASELINE_FALLBACK=true`, it falls back to
    the deterministic baseline implementation.

    All tunables are read once from ``TIMESFM_*`` environment variables when
    the service is constructed.
    """

    def __init__(self) -> None:
        """Read configuration from the environment and initialize the backend.

        Raises:
            ValueError: if ``TIMESFM_BACKEND`` names an unsupported backend, or
                if a numeric environment variable cannot be parsed.
        """
        self.model_id = "timesfm"
        self.model_name = os.getenv(
            "TIMESFM_MODEL_NAME",
            "google/timesfm-2.5-200m-transformers",
        )
        self.backend = os.getenv("TIMESFM_BACKEND", "hf_cpu").strip() or "hf_cpu"
        self.device = "cpu"
        self.runtime_revision = os.getenv("TIMESFM_RUNTIME_REVISION", "timesfm-hf-patch-align-v1")
        self.max_context_length = int(os.getenv("TIMESFM_MAX_CONTEXT_LENGTH", "512"))
        self.max_horizon_step = int(os.getenv("TIMESFM_MAX_HORIZON_STEP", "288"))
        self.patch_length = int(os.getenv("TIMESFM_PATCH_LENGTH", "32"))
        self.confidence_floor = float(os.getenv("TIMESFM_CONFIDENCE_FLOOR", "0.20"))
        self.confidence_ceiling = float(os.getenv("TIMESFM_CONFIDENCE_CEILING", "0.85"))
        self.min_required_points = int(os.getenv("TIMESFM_MIN_REQUIRED_POINTS", "32"))
        # Strip so that a value like "true\n" from a .env file is still honoured.
        self.allow_baseline_fallback = (
            os.getenv("TIMESFM_ALLOW_BASELINE_FALLBACK", "false").strip().lower() == "true"
        )
        self.ready = False
        self.load_error = ""
        # Populated lazily by _load_hf_model so torch/transformers are only
        # imported when the HF backend is actually used.
        self._torch: Any = None
        self._model: Any = None
        self._initialize_backend()

    def health(self) -> HealthResponse:
        """Return a health snapshot; status is "ok" when ready, else "degraded"."""
        return HealthResponse(
            status="ok" if self.ready else "degraded",
            model=self.model_name,
            model_id=self.model_id,
            backend=self.backend,
            device=self.device,
            ready=self.ready,
            max_context_length=self.max_context_length,
            max_horizon_step=self.max_horizon_step,
            patch_length=self.patch_length,
            runtime_revision=self.runtime_revision,
        )

    def predict(self, payload: PredictRequest) -> PredictResponse:
        """Validate *payload* and forecast a price for every requested horizon.

        Only the last ``payload.context_length`` close prices are used as context.

        Raises:
            ValueError: if the request fails validation.
            RuntimeError: if the HF backend is selected but not ready, or if
                HF inference produces unusable output.
        """
        self._validate_request(payload)
        closes = payload.close_prices[-payload.context_length :]
        if self.backend == "hf_cpu":
            if not self.ready:
                raise RuntimeError(self.load_error or "timesfm backend not ready")
            predictions = self._predict_with_hf(closes, payload.horizons)
        else:
            predictions = self._predict_with_baseline(closes, payload.horizons)
        return PredictResponse(model_id=self.model_id, predictions=predictions)

    def _initialize_backend(self) -> None:
        """Prepare the configured backend and set ``ready`` accordingly.

        On HF load failure the error text is kept in ``load_error``. If
        fallback is allowed, the backend switches to ``baseline_cpu`` and the
        service is still ready; otherwise it stays not-ready.

        Raises:
            ValueError: if the backend name is not ``baseline_cpu`` or ``hf_cpu``.
        """
        if self.backend == "baseline_cpu":
            self.ready = True
            return
        if self.backend != "hf_cpu":
            raise ValueError(f"unsupported TIMESFM_BACKEND={self.backend}")
        try:
            self._load_hf_model()
            self.ready = True
        except Exception as exc:
            # Deliberately broad: any load failure (missing package, download
            # error, bad weights) degrades the service instead of crashing it.
            self.load_error = f"timesfm hf load failed: {exc}"
            if self.allow_baseline_fallback:
                self.backend = "baseline_cpu"
                self.ready = True
            else:
                self.ready = False

    def _load_hf_model(self) -> None:
        """Import torch/transformers and load the TimesFM 2.5 model on CPU in eval mode."""
        import torch
        from transformers import TimesFm2_5ModelForPrediction

        self._torch = torch
        # Note: set_num_threads affects the whole process, not just this model.
        torch.set_num_threads(max(1, int(os.getenv("TIMESFM_TORCH_THREADS", "2"))))
        self._model = TimesFm2_5ModelForPrediction.from_pretrained(
            self.model_name,
            torch_dtype=torch.float32,
        )
        self._model.to("cpu")
        self._model.eval()

    def _predict_with_hf(
        self, close_prices: list[float], horizons: list[int]
    ) -> list[PredictionItem]:
        """Run TimesFM once and pick out the forecast for each 1-based horizon step.

        Raises:
            RuntimeError: if the context is too short after patch alignment, if
                the model output is shorter than the largest requested step, or
                if a selected forecast value is not finite.
        """
        assert self._torch is not None
        assert self._model is not None
        if not horizons:
            # Match the baseline path (no horizons -> no items) rather than
            # running the model and then failing on max([]).
            return []
        torch = self._torch
        context = self._aligned_hf_context(close_prices)
        past_values = [torch.tensor(context, dtype=torch.float32)]
        freq = torch.tensor([0], dtype=torch.long)
        with torch.inference_mode():
            outputs = self._model(
                past_values=past_values,
                freq=freq,
                forecast_context_len=len(context),
                return_forecast_on_context=False,
            )
        dense_mean = outputs.mean_predictions[0].detach().cpu().tolist()
        max_step = max(horizons)
        if len(dense_mean) < max_step:
            raise RuntimeError(
                f"TimesFM output horizon {len(dense_mean)} is shorter than requested {max_step}"
            )
        dense_conf = self._timesfm_confidence(outputs, dense_mean)
        predictions: list[PredictionItem] = []
        for step in horizons:
            raw_price = float(dense_mean[step - 1])
            # max(1e-8, nan) evaluates to 1e-8, which would silently fabricate a
            # price; surface the model failure instead.
            if not math.isfinite(raw_price):
                raise RuntimeError(f"TimesFM produced a non-finite forecast at step {step}")
            predictions.append(
                PredictionItem(
                    step=step,
                    pred_price=round(max(0.00000001, raw_price), 8),
                    pred_confidence=round(dense_conf[step - 1], 4),
                )
            )
        return predictions

    def _aligned_hf_context(self, close_prices: list[float]) -> list[float]:
        """Return the most recent prices, trimmed to a whole number of patches.

        The context is capped at ``max_context_length``. When ``patch_length``
        is positive, it is shortened to the largest multiple of
        ``patch_length``, dropping the oldest points.

        Raises:
            RuntimeError: if the aligned length is below ``min_required_points``.
        """
        context = close_prices[-self.max_context_length :]
        usable_length = len(context)
        if self.patch_length > 0:
            usable_length = (usable_length // self.patch_length) * self.patch_length
        if usable_length < self.min_required_points:
            raise RuntimeError(
                f"TimesFM usable context length {usable_length} is below "
                f"TIMESFM_MIN_REQUIRED_POINTS={self.min_required_points}"
            )
        return context[-usable_length:]

    def _timesfm_confidence(self, outputs: Any, dense_mean: list[float]) -> list[float]:
        """Map each step's quantile-band width to a confidence in [floor, ceiling].

        The band is the spread between the first and last column of
        ``outputs.full_predictions[0]``, divided by ``|pred|``; the confidence
        is ``1 / (1 + band)``, then clamped. Steps without usable quantile data
        get ``confidence_floor``.

        NOTE(review): in HF TimesFM outputs, column 0 of ``full_predictions``
        may be the mean rather than the lowest quantile, which would make this
        a half-band. Confirm against the model's output layout.
        """
        full_predictions = getattr(outputs, "full_predictions", None)
        if full_predictions is None:
            return [self.confidence_floor for _ in dense_mean]
        quantiles = full_predictions[0].detach().cpu()
        confidence: list[float] = []
        for idx, pred in enumerate(dense_mean):
            if quantiles.ndim != 2 or idx >= quantiles.shape[0]:
                confidence.append(self.confidence_floor)
                continue
            lower = float(quantiles[idx][0])
            upper = float(quantiles[idx][-1])
            band = abs(upper - lower) / max(abs(pred), 1e-6)
            raw = 1.0 / (1.0 + band)
            # With a NaN band, min(ceiling, nan) returns ceiling, which would
            # report maximum confidence for garbage output; use the floor.
            if not math.isfinite(raw):
                confidence.append(self.confidence_floor)
                continue
            confidence.append(max(self.confidence_floor, min(self.confidence_ceiling, raw)))
        return confidence

    def _validate_request(self, payload: PredictRequest) -> None:
        """Reject requests that cannot be forecast meaningfully.

        Raises:
            ValueError: on a non-positive or oversized ``context_length``, too
                few close prices, horizon steps outside
                ``[1, max_horizon_step]``, or non-finite prices in the context
                window.
        """
        # A context_length of 0 would make close_prices[-0:] the whole list,
        # and a negative one would slice from the front.
        if payload.context_length < 1:
            raise ValueError("context_length must be at least 1")
        if payload.context_length > self.max_context_length:
            raise ValueError(
                f"context_length {payload.context_length} exceeds "
                f"TIMESFM_MAX_CONTEXT_LENGTH={self.max_context_length}"
            )
        if payload.context_length > len(payload.close_prices):
            raise ValueError("context_length must not exceed len(close_prices)")
        if len(payload.close_prices) < self.min_required_points:
            raise ValueError(
                f"at least {self.min_required_points} close prices are required "
                "for TimesFM stability"
            )
        # Steps are 1-based: step 0 or below would index dense_mean from the end.
        if any(step < 1 for step in payload.horizons):
            raise ValueError("horizons must contain only steps >= 1")
        if any(step > self.max_horizon_step for step in payload.horizons):
            raise ValueError(
                f"horizons contain values above TIMESFM_MAX_HORIZON_STEP={self.max_horizon_step}"
            )
        window = payload.close_prices[-payload.context_length :]
        if not all(math.isfinite(price) for price in window):
            raise ValueError("close_prices used as context must be finite numbers")

    def _predict_with_baseline(
        self, close_prices: list[float], horizons: list[int]
    ) -> list[PredictionItem]:
        """Deterministic momentum/regime forecast used when TimesFM is unavailable.

        The expected return blends short-term momentum (last price vs. the
        8-point mean) with regime bias (8-point vs. 32-point mean). It is scaled
        up with ``log(step + 1)`` and saturates once that reaches 3.5.
        """
        last_price = close_prices[-1]
        short_window = close_prices[-min(8, len(close_prices)) :]
        long_window = close_prices[-min(32, len(close_prices)) :]
        short_mean = mean(short_window)
        long_mean = mean(long_window)
        momentum = 0.0 if short_mean == 0 else (last_price - short_mean) / short_mean
        regime_bias = 0.0 if long_mean == 0 else (short_mean - long_mean) / long_mean
        # Independent of the step, so compute it once.
        base_return = momentum * 0.55 + regime_bias * 0.45
        predictions: list[PredictionItem] = []
        for step in horizons:
            damped_step = math.log(step + 1.0)
            expected_return = base_return * min(1.0, damped_step / 3.5)
            pred_price = max(0.00000001, last_price * (1.0 + expected_return))
            confidence = self._baseline_confidence(close_prices, step, abs(expected_return))
            predictions.append(
                PredictionItem(
                    step=step,
                    pred_price=round(pred_price, 8),
                    pred_confidence=round(confidence, 4),
                )
            )
        return predictions

    def _baseline_confidence(
        self, close_prices: list[float], step: int, expected_move_abs: float
    ) -> float:
        """Score the baseline forecast from signal-to-noise and horizon length.

        Noise is the mean absolute relative change over the last 32 price
        moves. Moves from non-positive prices are skipped. The result is
        clamped to ``[confidence_floor, confidence_ceiling]``.
        """
        if len(close_prices) < 3:
            return self.confidence_floor
        changes: list[float] = []
        for previous, current in zip(close_prices[:-1], close_prices[1:]):
            if previous <= 0:
                continue
            changes.append(abs((current - previous) / previous))
        realized_vol = mean(changes[-min(32, len(changes)) :]) if changes else 0.0
        signal_to_noise = expected_move_abs / (realized_vol + 1e-9)
        horizon_decay = 1.0 / (1.0 + math.log(step + 1.0))
        raw = 0.25 + min(signal_to_noise, 2.0) * 0.25 + horizon_decay * 0.35
        return max(self.confidence_floor, min(self.confidence_ceiling, raw))

    def describe_runtime(self) -> dict[str, Any]:
        """Return runtime/config details, including any model load error, as a plain dict."""
        return {
            "model_id": self.model_id,
            "model_name": self.model_name,
            "backend": self.backend,
            "device": self.device,
            "ready": self.ready,
            "load_error": self.load_error,
            "max_context_length": self.max_context_length,
            "max_horizon_step": self.max_horizon_step,
            "min_required_points": self.min_required_points,
            "patch_length": self.patch_length,
        }