Spaces:
Running
Running
fengwm
Update README: switch data source to AkShare, unify request field to symbol, optimize cache mechanism, add performance logging (commit 2a8a0f5)
"""
| Kronos model singleton + Monte-Carlo prediction logic. | |
| On import this module: | |
| 1. Clones shiyu-coder/Kronos from GitHub if not already present at KRONOS_DIR. | |
| 2. Adds KRONOS_DIR to sys.path so `from model import ...` works. | |
| 3. Does NOT load the model weights yet (lazy, first-request). | |
| """ | |
| import logging | |
| import os | |
| import subprocess | |
| import sys | |
| import threading | |
| from typing import Tuple | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| logger = logging.getLogger(__name__) | |
| # โโ Paths / IDs โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
| KRONOS_DIR = os.environ.get("KRONOS_DIR", "/app/Kronos") | |
| MODEL_ID = "NeoQuasar/Kronos-base" | |
| TOKENIZER_ID = "NeoQuasar/Kronos-Tokenizer-base" | |
| MC_BATCH_SIZE = max(1, int(os.environ.get("MC_BATCH_SIZE", "8"))) | |
| # โโ Bootstrap Kronos source โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
def _ensure_kronos_source() -> None:
    """Make the Kronos sources importable, cloning them on first run.

    Performs a shallow (depth-1) git clone of shiyu-coder/Kronos into
    KRONOS_DIR when that directory is absent, then prepends KRONOS_DIR
    to sys.path so that `from model import ...` resolves.
    """
    if not os.path.isdir(KRONOS_DIR):
        logger.info("Cloning Kronos source to %s …", KRONOS_DIR)
        clone_cmd = [
            "git",
            "clone",
            "--depth",
            "1",
            "https://github.com/shiyu-coder/Kronos",
            KRONOS_DIR,
        ]
        # check=True: abort import if the clone fails rather than limp on.
        subprocess.run(clone_cmd, check=True)
    if KRONOS_DIR not in sys.path:
        sys.path.insert(0, KRONOS_DIR)
# Clone (if needed) and expose the Kronos sources BEFORE importing from them.
_ensure_kronos_source()
from model import Kronos, KronosPredictor, KronosTokenizer  # noqa: E402 (after sys.path setup)
# ── Global singleton + inference lock ────────────────────────────────────────
# RotaryEmbedding keeps mutable instance-level cache (seq_len_cached / cos_cached).
# Concurrent threads sharing the same model instance will race on that cache,
# causing cos=None crashes. Serialise all predict() calls with this lock.
_predictor: KronosPredictor | None = None  # lazily created on first request
_infer_lock = threading.Lock()  # held around every predictor.predict() call
def get_predictor() -> KronosPredictor:
    """
    Lazily build and return the process-wide KronosPredictor singleton.

    Loads tokenizer and model weights on first call only, on CUDA when
    available and CPU otherwise.

    Fix: initialisation is now guarded by _infer_lock (double-checked
    locking) — previously two threads hitting the first request at the
    same time could both observe `_predictor is None` and load the model
    weights twice, wasting memory and racing on the global assignment.
    """
    global _predictor
    if _predictor is None:
        with _infer_lock:
            # Re-check under the lock: another thread may have finished
            # initialising while we were waiting to acquire it.
            if _predictor is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info("Loading Kronos model on %s …", device)
                tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_ID)
                model = Kronos.from_pretrained(MODEL_ID)
                _predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
                logger.info("Kronos predictor ready.")
    return _predictor
| def _split_batched_output( | |
| pred_output, | |
| expected_count: int, | |
| pred_len: int, | |
| ) -> list[pd.DataFrame]: | |
| """ | |
| Normalize predictor output into `expected_count` DataFrame samples. | |
| Supports single-sample DataFrame and common batched return shapes. | |
| """ | |
| if isinstance(pred_output, pd.DataFrame): | |
| if expected_count == 1: | |
| return [pred_output] | |
| if isinstance(pred_output.index, pd.MultiIndex): | |
| grouped = [g.droplevel(0) for _, g in pred_output.groupby(level=0, sort=False)] | |
| if len(grouped) == expected_count: | |
| return grouped | |
| if len(pred_output) == expected_count * pred_len: | |
| return [ | |
| pred_output.iloc[i * pred_len:(i + 1) * pred_len].copy() | |
| for i in range(expected_count) | |
| ] | |
| if isinstance(pred_output, (list, tuple)): | |
| if len(pred_output) == expected_count and all( | |
| isinstance(item, pd.DataFrame) for item in pred_output | |
| ): | |
| return list(pred_output) | |
| if expected_count == 1 and len(pred_output) == 1 and isinstance(pred_output[0], pd.DataFrame): | |
| return [pred_output[0]] | |
| raise ValueError("Unsupported predict() output format for batched sampling") | |
| # โโ Monte-Carlo prediction โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ | |
def run_mc_prediction(
    x_df: pd.DataFrame,
    x_timestamp: pd.Series,
    y_timestamp: pd.Series,
    pred_len: int,
    sample_count: int,
) -> Tuple[pd.DataFrame, dict, np.ndarray, np.ndarray, float, float]:
    """
    Draw `sample_count` Monte-Carlo forecast samples and aggregate them
    into mean / confidence-interval / direction statistics.

    Samples are requested in batches of up to MC_BATCH_SIZE per predict()
    call; if the predictor cannot return per-sample outputs for a batched
    request, the loop falls back to one sample per call.

    Args:
        x_df         : historical bars (must contain at least the columns
                       open/high/low/close/volume; "close" is read directly).
        x_timestamp  : timestamps of the historical bars.
        y_timestamp  : timestamps of the bars to forecast.
        pred_len     : number of future bars per sample.
        sample_count : total number of independent MC samples to draw.

    Returns:
        pred_mean      : DataFrame (index = forecast timestamps, one column
                         per predicted field), mean trajectory across samples
        ci             : dict[field]["low"/"high"] -> ndarray(pred_len,), 95% CI
        trading_low    : ndarray(pred_len,), q2.5 of predicted_low
        trading_high   : ndarray(pred_len,), q97.5 of predicted_high
        direction_prob : float in [0, 1], horizon-level bullish probability
        last_close     : float, closing price of the last historical bar
    """
    predictor = get_predictor()
    samples: list[pd.DataFrame] = []
    supports_batched_sampling = True
    remaining = sample_count
    while remaining > 0:
        # Largest batch the predictor is currently believed to support.
        batch_n = min(remaining, MC_BATCH_SIZE if supports_batched_sampling else 1)
        # Serialise inference: the shared model's rotary-embedding cache is
        # not safe under concurrent predict() calls (see module-level note).
        with _infer_lock:
            pred_output = predictor.predict(
                df=x_df,
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=0.8,
                top_p=0.9,
                sample_count=batch_n,
                verbose=False,
            )
        try:
            batch_samples = _split_batched_output(pred_output, batch_n, pred_len)
        except ValueError:
            if batch_n > 1:
                # Fallback for predictor implementations that do not support
                # returning per-sample outputs for sample_count>1; retry this
                # chunk one sample at a time (nothing consumed from `remaining`).
                supports_batched_sampling = False
                continue
            raise  # single-sample output already unparseable — genuine error
        samples.extend(batch_samples)
        remaining -= batch_n
    # Element-wise mean across samples, rows aligned by forecast timestamp.
    pred_mean = pd.concat(samples).groupby(level=0).mean()
    stacked = {
        field: np.stack([s[field].values for s in samples])  # (sample_count, pred_len)
        for field in ["open", "high", "low", "close", "volume"]
    }
    alpha = 2.5  # tail mass per side -> central 95 % interval
    ci = {
        field: {
            "low": np.percentile(stacked[field], alpha, axis=0),
            "high": np.percentile(stacked[field], 100 - alpha, axis=0),
        }
        for field in stacked
    }
    trading_low = ci["low"]["low"]  # q2.5 of the predicted daily low
    trading_high = ci["high"]["high"]  # q97.5 of the predicted daily high
    last_close = float(x_df["close"].iloc[-1])
    close_paths = stacked["close"]  # (sample_count, pred_len)
    # Use all future points (every sample x every horizon step) to estimate
    # the horizon-level bullish probability relative to the last close.
    bull_count = int((close_paths > last_close).sum())
    total_points = int(close_paths.size)
    direction_prob = bull_count / total_points
    return pred_mean, ci, trading_low, trading_high, direction_prob, last_close