""" Kronos model singleton + Monte-Carlo prediction logic. On import this module: 1. Clones shiyu-coder/Kronos from GitHub if not already present at KRONOS_DIR. 2. Adds KRONOS_DIR to sys.path so `from model import ...` works. 3. Does NOT load the model weights yet (lazy, first-request). """ import logging import os import subprocess import sys import threading from typing import Tuple import numpy as np import pandas as pd import torch logger = logging.getLogger(__name__) # ── Paths / IDs ───────────────────────────────────────────────────────────── KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos") MODEL_ID = "NeoQuasar/Kronos-base" TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base" MC_BATCH_SIZE = max(1, int(os.environ.get("MC_BATCH_SIZE", "8"))) # ── Bootstrap Kronos source ────────────────────────────────────────────────── def _ensure_kronos_source() -> None: if not os.path.isdir(KRONOS_DIR): logger.info("Cloning Kronos source to %s …", KRONOS_DIR) subprocess.run( [ "git", "clone", "--depth", "1", "https://github.com/shiyu-coder/Kronos", KRONOS_DIR, ], check=True, ) if KRONOS_DIR not in sys.path: sys.path.insert(0, KRONOS_DIR) _ensure_kronos_source() from model import Kronos, KronosPredictor, KronosTokenizer # noqa: E402 (after sys.path setup) # ── Global singleton + inference lock ──────────────────────────────────────── # RotaryEmbedding keeps mutable instance-level cache (seq_len_cached / cos_cached). # Concurrent threads sharing the same model instance will race on that cache, # causing cos=None crashes. Serialise all predict() calls with this lock. 
# Process-wide predictor, built on first use by get_predictor().
_predictor: "KronosPredictor | None" = None
# Guards the one-time model load so concurrent first requests load it once.
_init_lock = threading.Lock()
# Held around every predictor.predict() call so that threads sharing the
# single model instance never run inference at the same time.
_infer_lock = threading.Lock()


def get_predictor() -> "KronosPredictor":
    """Return the shared KronosPredictor, loading it on the first call.

    The device is "cuda" when ``torch.cuda.is_available()`` and "cpu"
    otherwise. Loading is guarded by ``_init_lock`` so that concurrent first
    callers do not each load the tokenizer and model.
    """
    global _predictor
    if _predictor is None:
        with _init_lock:
            # Re-check under the lock: another thread may have finished the
            # load while this one was waiting to acquire it.
            if _predictor is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info("Loading Kronos model on %s …", device)
                tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
                model = Kronos.from_pretrained(MODEL_ID)
                _predictor = KronosPredictor(
                    model, tokenizer, device=device, max_context=512
                )
                logger.info("Kronos predictor ready.")
    return _predictor


def _split_batched_output(
    pred_output,
    expected_count: int,
    pred_len: int,
) -> list[pd.DataFrame]:
    """Normalize predictor output into `expected_count` DataFrame samples.

    Accepted shapes:
      * a DataFrame when ``expected_count == 1`` (returned as-is);
      * a DataFrame with a MultiIndex whose first level yields exactly
        ``expected_count`` groups (that level is dropped from each group);
      * a DataFrame with ``expected_count * pred_len`` rows, cut into
        consecutive ``pred_len``-row slices;
      * a list/tuple of ``expected_count`` DataFrames.

    Raises:
        ValueError: if `pred_output` matches none of the shapes above.
    """
    if isinstance(pred_output, pd.DataFrame):
        if expected_count == 1:
            return [pred_output]

        if isinstance(pred_output.index, pd.MultiIndex):
            grouped = [
                g.droplevel(0)
                for _, g in pred_output.groupby(level=0, sort=False)
            ]
            if len(grouped) == expected_count:
                return grouped

        if len(pred_output) == expected_count * pred_len:
            return [
                pred_output.iloc[i * pred_len:(i + 1) * pred_len].copy()
                for i in range(expected_count)
            ]

    if isinstance(pred_output, (list, tuple)):
        if len(pred_output) == expected_count and all(
            isinstance(item, pd.DataFrame) for item in pred_output
        ):
            return list(pred_output)
        if (
            expected_count == 1
            and len(pred_output) == 1
            and isinstance(pred_output[0], pd.DataFrame)
        ):
            return [pred_output[0]]

    raise ValueError("Unsupported predict() output format for batched sampling")


# ── Monte-Carlo prediction ────────────────────────────────────────────────────
def run_mc_prediction(
    x_df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    sample_count: int,
) -> Tuple[pd.DataFrame, dict, np.ndarray, np.ndarray, float, float]:
    """Draw `sample_count` predicted paths and summarise them.

    Samples are requested from the predictor in batches of up to
    MC_BATCH_SIZE. If a batched result cannot be split into per-sample
    frames, the function falls back to one sample per predict() call for the
    rest of the run. Every predict() call is made while holding
    ``_infer_lock``.

    Args:
        x_df: historical bars; must contain a "close" column.
        x_timestamp: timestamps passed through to predictor.predict().
        y_timestamp: future timestamps passed through to predictor.predict().
        pred_len: number of future steps to predict.
        sample_count: number of Monte-Carlo samples to draw (>= 1).

    Returns:
        pred_mean      : DataFrame, per-index mean over all samples
                         (mean trajectory)
        ci             : dict[field]["low"/"high"] -> ndarray, the 2.5th /
                         97.5th percentile across samples (95% CI) for each
                         of open, high, low, close, volume
        trading_low    : ndarray, 2.5th percentile of the predicted low
        trading_high   : ndarray, 97.5th percentile of the predicted high
        direction_prob : float in [0, 1], share of all predicted close
                         points (every sample, every step) above last_close
        last_close     : float, close of the last row of `x_df`

    Raises:
        ValueError: if `sample_count` is below 1, or if a single-sample
            predict() result has an unsupported format.
    """
    # Fail fast: with no samples pd.concat([]) below would raise an
    # unrelated "No objects to concatenate" error after loading the model.
    if sample_count < 1:
        raise ValueError(f"sample_count must be >= 1, got {sample_count}")

    predictor = get_predictor()

    samples: list[pd.DataFrame] = []
    supports_batched_sampling = True
    remaining = sample_count
    while remaining > 0:
        batch_n = min(remaining, MC_BATCH_SIZE if supports_batched_sampling else 1)
        with _infer_lock:
            pred_output = predictor.predict(
                df=x_df,
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=0.8,
                top_p=0.9,
                sample_count=batch_n,
                verbose=False,
            )

        try:
            batch_samples = _split_batched_output(pred_output, batch_n, pred_len)
        except ValueError:
            if batch_n > 1:
                # Fallback for predictor implementations that do not support
                # returning per-sample outputs for sample_count>1: discard
                # this result and retry one sample at a time.
                supports_batched_sampling = False
                continue
            raise

        samples.extend(batch_samples)
        remaining -= batch_n

    pred_mean = pd.concat(samples).groupby(level=0).mean()

    stacked = {
        field: np.stack([s[field].values for s in samples])  # one row per sample
        for field in ["open", "high", "low", "close", "volume"]
    }

    alpha = 2.5  # → 95 % CI
    ci = {
        field: {
            "low": np.percentile(stacked[field], alpha, axis=0),
            "high": np.percentile(stacked[field], 100 - alpha, axis=0),
        }
        for field in stacked
    }

    trading_low = ci["low"]["low"]  # q2.5 of the predicted low
    trading_high = ci["high"]["high"]  # q97.5 of the predicted high

    last_close = float(x_df["close"].iloc[-1])
    close_paths = stacked["close"]  # one row per sample

    # Use all future points to estimate horizon bullish probability.
    bull_count = int((close_paths > last_close).sum())
    total_points = int(close_paths.size)
    direction_prob = bull_count / total_points

    return pred_mean, ci, trading_low, trading_high, direction_prob, last_close