# kronos-api / predictor.py
# Author: fengwm — commit 2a8a0f5
# (Update README; switch data source to AkShare; unify the request field to
#  `symbol`; optimize the caching mechanism; add performance logging.)
"""
Kronos model singleton + Monte-Carlo prediction logic.
On import this module:
1. Clones shiyu-coder/Kronos from GitHub if not already present at KRONOS_DIR.
2. Adds KRONOS_DIR to sys.path so `from model import ...` works.
3. Does NOT load the model weights yet (lazy, first-request).
"""
import logging
import os
import subprocess
import sys
import threading
from typing import Tuple
import numpy as np
import pandas as pd
import torch
logger = logging.getLogger(__name__)
# โ”€โ”€ Paths / IDs โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos")
MODEL_ID = "NeoQuasar/Kronos-base"
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
MC_BATCH_SIZE = max(1, int(os.environ.get("MC_BATCH_SIZE", "8")))
# โ”€โ”€ Bootstrap Kronos source โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def _ensure_kronos_source() -> None:
    """
    Make the Kronos source tree importable.

    Shallow-clones the upstream repository into KRONOS_DIR when the directory
    is missing (raises subprocess.CalledProcessError if git fails), then
    prepends KRONOS_DIR to sys.path so `from model import ...` resolves.
    Idempotent: safe to call more than once.
    """
    if not os.path.isdir(KRONOS_DIR):
        logger.info("Cloning Kronos source to %s …", KRONOS_DIR)
        clone_cmd = [
            "git", "clone", "--depth", "1",
            "https://github.com/shiyu-coder/Kronos",
            KRONOS_DIR,
        ]
        subprocess.run(clone_cmd, check=True)
    if KRONOS_DIR not in sys.path:
        sys.path.insert(0, KRONOS_DIR)
# Bootstrap at import time so the project-local import below succeeds.
_ensure_kronos_source()
from model import Kronos, KronosPredictor, KronosTokenizer  # noqa: E402 (after sys.path setup)
# ── Global singleton + inference lock ────────────────────────────────────────
# RotaryEmbedding keeps mutable instance-level cache (seq_len_cached / cos_cached).
# Concurrent threads sharing the same model instance will race on that cache,
# causing cos=None crashes. Serialise all predict() calls with this lock.
_predictor: KronosPredictor | None = None  # lazily built by get_predictor()
_infer_lock = threading.Lock()
def get_predictor() -> KronosPredictor:
    """
    Lazily load the Kronos model + tokenizer and return the process-wide
    singleton KronosPredictor.

    Fix: the original lazy init was not thread-safe — two concurrent first
    requests could both observe ``_predictor is None`` and load the weights
    twice (wasting memory and time, and briefly racing on the global).  The
    load is now guarded by ``_infer_lock`` with a re-check inside the critical
    section (double-checked locking).  Reusing the inference lock is safe:
    no inference can run before the predictor exists.

    Returns:
        The shared KronosPredictor instance (CUDA if available, else CPU).
    """
    global _predictor
    if _predictor is None:
        with _infer_lock:
            # Re-check under the lock: another thread may have finished
            # loading while we were waiting to acquire it.
            if _predictor is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info("Loading Kronos model on %s …", device)
                tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
                model = Kronos.from_pretrained(MODEL_ID)
                _predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
                logger.info("Kronos predictor ready.")
    return _predictor
def _split_batched_output(
pred_output,
expected_count: int,
pred_len: int,
) -> list[pd.DataFrame]:
"""
Normalize predictor output into `expected_count` DataFrame samples.
Supports single-sample DataFrame and common batched return shapes.
"""
if isinstance(pred_output, pd.DataFrame):
if expected_count == 1:
return [pred_output]
if isinstance(pred_output.index, pd.MultiIndex):
grouped = [g.droplevel(0) for _, g in pred_output.groupby(level=0, sort=False)]
if len(grouped) == expected_count:
return grouped
if len(pred_output) == expected_count * pred_len:
return [
pred_output.iloc[i * pred_len:(i + 1) * pred_len].copy()
for i in range(expected_count)
]
if isinstance(pred_output, (list, tuple)):
if len(pred_output) == expected_count and all(
isinstance(item, pd.DataFrame) for item in pred_output
):
return list(pred_output)
if expected_count == 1 and len(pred_output) == 1 and isinstance(pred_output[0], pd.DataFrame):
return [pred_output[0]]
raise ValueError("Unsupported predict() output format for batched sampling")
# โ”€โ”€ Monte-Carlo prediction โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
def run_mc_prediction(
    x_df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    sample_count: int,
) -> Tuple[pd.DataFrame, dict, np.ndarray, np.ndarray, float, float]:
    """
    Draw `sample_count` Monte-Carlo samples from the Kronos predictor and
    aggregate them into summary statistics.

    Samples are requested in batches of up to MC_BATCH_SIZE per predict()
    call; if the predictor cannot return per-sample outputs for
    sample_count > 1, we fall back to one sample per call.

    Args:
        x_df         : historical bars used as model context (must contain a
                       "close" column; presumably OHLCV — confirm with caller).
        x_timestamp  : timestamps aligned with x_df rows.
        y_timestamp  : future timestamps to predict.
        pred_len     : number of future bars per sample.
        sample_count : number of independent MC samples to draw.

    Returns:
        pred_mean     : DataFrame (index=y_timestamp, cols=OHLCVA), mean trajectory
        ci            : dict[field]["low"/"high"] → ndarray(pred_len,), 95% CI
        trading_low   : ndarray(pred_len,), q2.5 of predicted_low
        trading_high  : ndarray(pred_len,), q97.5 of predicted_high
        direction_prob: float ∈ [0,1], horizon-level bullish probability
        last_close    : float, closing price of the last historical bar
    """
    predictor = get_predictor()
    samples: list[pd.DataFrame] = []
    supports_batched_sampling = True
    remaining = sample_count
    while remaining > 0:
        batch_n = min(remaining, MC_BATCH_SIZE if supports_batched_sampling else 1)
        # Serialise inference: the rotary-embedding cache is not thread-safe
        # (see the note above _infer_lock).
        with _infer_lock:
            pred_output = predictor.predict(
                df=x_df,
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=0.8,  # sampling temperature
                top_p=0.9,  # nucleus-sampling threshold
                sample_count=batch_n,
                verbose=False,
            )
        try:
            batch_samples = _split_batched_output(pred_output, batch_n, pred_len)
        except ValueError:
            if batch_n > 1:
                # Fallback for predictor implementations that do not support
                # returning per-sample outputs for sample_count>1.
                # NOTE: the just-computed batch output is discarded; the loop
                # retries one sample at a time from here on.
                supports_batched_sampling = False
                continue
            raise
        samples.extend(batch_samples)
        remaining -= batch_n
    # Element-wise mean across all samples at each future timestamp.
    pred_mean = pd.concat(samples).groupby(level=0).mean()
    stacked = {
        field: np.stack([s[field].values for s in samples])  # (sample_count, pred_len)
        for field in ["open", "high", "low", "close", "volume"]
    }
    alpha = 2.5  # → 95 % CI
    ci = {
        field: {
            "low": np.percentile(stacked[field], alpha, axis=0),
            "high": np.percentile(stacked[field], 100 - alpha, axis=0),
        }
        for field in stacked
    }
    trading_low = ci["low"]["low"]  # q2.5 of the predicted daily low
    trading_high = ci["high"]["high"]  # q97.5 of the predicted daily high
    last_close = float(x_df["close"].iloc[-1])
    close_paths = stacked["close"]  # (sample_count, pred_len)
    # Use all future points to estimate horizon bullish probability.
    bull_count = int((close_paths > last_close).sum())
    total_points = int(close_paths.size)
    direction_prob = bull_count / total_points
    return pred_mean, ci, trading_low, trading_high, direction_prob, last_close