"""
Kronos model singleton + Monte-Carlo prediction logic.
On import this module:
1. Clones shiyu-coder/Kronos from GitHub if not already present at KRONOS_DIR.
2. Adds KRONOS_DIR to sys.path so `from model import ...` works.
3. Does NOT load the model weights yet (lazy, first-request).
"""
import logging
import os
import subprocess
import sys
import threading
from typing import Tuple
import numpy as np
import pandas as pd
import torch
logger = logging.getLogger(__name__)
# ── Paths / IDs ──────────────────────────────────────────────────────────────
# Local checkout location for the Kronos source tree (overridable via env var).
KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos")
# Hugging Face repo ids for the model and its tokenizer, loaded lazily by get_predictor().
MODEL_ID = "NeoQuasar/Kronos-base"
TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base"
# Samples requested per predict() call during Monte-Carlo sampling; clamped to >= 1.
MC_BATCH_SIZE = max(1, int(os.environ.get("MC_BATCH_SIZE", "8")))
# ── Bootstrap Kronos source ──────────────────────────────────────────────────
def _ensure_kronos_source() -> None:
    """
    Make the Kronos source importable.

    Clones the upstream repo into KRONOS_DIR on first run, then prepends
    KRONOS_DIR to sys.path so `from model import ...` resolves. Idempotent:
    safe to call on every import.

    Raises:
        subprocess.CalledProcessError: if the git clone fails (check=True).
    """
    if not os.path.isdir(KRONOS_DIR):
        # Fix: the original log message contained a mojibake ellipsis ("β¦").
        logger.info("Cloning Kronos source to %s …", KRONOS_DIR)
        # Shallow clone: only the latest commit is needed to import the code.
        subprocess.run(
            [
                "git", "clone", "--depth", "1",
                "https://github.com/shiyu-coder/Kronos",
                KRONOS_DIR,
            ],
            check=True,
        )
    if KRONOS_DIR not in sys.path:
        # Prepend so the vendored `model` package wins over same-named modules.
        sys.path.insert(0, KRONOS_DIR)
_ensure_kronos_source()  # import-time side effect: clone (if needed) + sys.path setup
from model import Kronos, KronosPredictor, KronosTokenizer  # noqa: E402 (after sys.path setup)
# ── Global singleton + inference lock ────────────────────────────────────────
# RotaryEmbedding keeps mutable instance-level cache (seq_len_cached / cos_cached).
# Concurrent threads sharing the same model instance will race on that cache,
# causing cos=None crashes. Serialise all predict() calls with this lock.
_predictor: KronosPredictor | None = None  # lazily built by get_predictor()
_infer_lock = threading.Lock()  # guards every predictor.predict() call
def get_predictor() -> KronosPredictor:
    """
    Lazily build and return the process-wide KronosPredictor singleton.

    Model weights are downloaded/loaded on the first call only. The
    check-lock-check pattern prevents two concurrent first requests from each
    loading the multi-GB model (the original lazy init raced). Reusing
    _infer_lock here is deadlock-free: inference only runs after this function
    has returned, so the lock is never held across both phases at once.
    """
    global _predictor
    if _predictor is None:
        with _infer_lock:
            # Re-check under the lock: another thread may have won the race.
            if _predictor is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info("Loading Kronos model on %s …", device)
                tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
                model = Kronos.from_pretrained(MODEL_ID)
                _predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
                logger.info("Kronos predictor ready.")
    return _predictor
def _split_batched_output(
pred_output,
expected_count: int,
pred_len: int,
) -> list[pd.DataFrame]:
"""
Normalize predictor output into `expected_count` DataFrame samples.
Supports single-sample DataFrame and common batched return shapes.
"""
if isinstance(pred_output, pd.DataFrame):
if expected_count == 1:
return [pred_output]
if isinstance(pred_output.index, pd.MultiIndex):
grouped = [g.droplevel(0) for _, g in pred_output.groupby(level=0, sort=False)]
if len(grouped) == expected_count:
return grouped
if len(pred_output) == expected_count * pred_len:
return [
pred_output.iloc[i * pred_len:(i + 1) * pred_len].copy()
for i in range(expected_count)
]
if isinstance(pred_output, (list, tuple)):
if len(pred_output) == expected_count and all(
isinstance(item, pd.DataFrame) for item in pred_output
):
return list(pred_output)
if expected_count == 1 and len(pred_output) == 1 and isinstance(pred_output[0], pd.DataFrame):
return [pred_output[0]]
raise ValueError("Unsupported predict() output format for batched sampling")
# ── Monte-Carlo prediction ────────────────────────────────────────────────────
def run_mc_prediction(
    x_df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    sample_count: int,
) -> Tuple[pd.DataFrame, dict, np.ndarray, np.ndarray, float, float]:
    """
    Draw `sample_count` Monte-Carlo sample paths from the Kronos predictor
    and aggregate them into summary statistics.

    Args:
        x_df: historical bars (must contain open/high/low/close/volume).
        x_timestamp: timestamps of the historical bars.
        y_timestamp: timestamps of the bars to predict.
        pred_len: number of future bars per sample path.
        sample_count: number of independent MC sample paths to draw.

    Returns:
        pred_mean: mean path, DataFrame grouped on the prediction index.
        ci: {field: {"low": ndarray, "high": ndarray}}, 95% percentile band,
            each ndarray of shape (pred_len,).
        trading_low: q2.5 of the predicted "low" field, shape (pred_len,).
        trading_high: q97.5 of the predicted "high" field, shape (pred_len,).
        direction_prob: in [0, 1] — fraction of all predicted closes above
            the last historical close (horizon-level bullish probability).
        last_close: closing price of the last historical bar.
    """
    predictor = get_predictor()
    paths: list[pd.DataFrame] = []
    batched_ok = True  # flips off if predict() can't return per-sample frames
    todo = sample_count
    while todo > 0:
        n = min(todo, MC_BATCH_SIZE) if batched_ok else 1
        # The shared model instance is not re-entrant (mutable rotary cache),
        # so every predict() call is serialised behind the module lock.
        with _infer_lock:
            raw = predictor.predict(
                df=x_df,
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=0.8,
                top_p=0.9,
                sample_count=n,
                verbose=False,
            )
        try:
            paths.extend(_split_batched_output(raw, n, pred_len))
        except ValueError:
            if n == 1:
                raise
            # This predictor cannot hand back individual samples for a
            # batched call: redo the remaining work one sample at a time.
            batched_ok = False
            continue
        todo -= n
    pred_mean = pd.concat(paths).groupby(level=0).mean()
    fields = ["open", "high", "low", "close", "volume"]
    # One (sample_count, pred_len) matrix per OHLCV field.
    stacked = {f: np.stack([p[f].values for p in paths]) for f in fields}
    alpha = 2.5  # two-sided 95% interval
    ci = {
        f: {
            "low": np.percentile(stacked[f], alpha, axis=0),
            "high": np.percentile(stacked[f], 100 - alpha, axis=0),
        }
        for f in fields
    }
    trading_low = ci["low"]["low"]     # q2.5 of the predicted daily low
    trading_high = ci["high"]["high"]  # q97.5 of the predicted daily high
    last_close = float(x_df["close"].iloc[-1])
    close_paths = stacked["close"]
    # Every future point across every path votes on horizon direction.
    direction_prob = int((close_paths > last_close).sum()) / int(close_paths.size)
    return pred_mean, ci, trading_low, trading_high, direction_prob, last_close