Spaces:
Running
Running
| """ | |
| Kronos Prediction Engine | |
| Performs autoregressive financial time series prediction with probabilistic forecasts. | |
| """ | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from typing import Dict, Tuple, Optional | |
| from pathlib import Path | |
| import warnings | |
| from model import Kronos, KronosTokenizer, KronosPredictor | |
| from data_fetcher import fetch_hourly_klines, get_data_info | |
# Library warnings are silenced process-wide to keep the console output readable.
warnings.filterwarnings('ignore')

# Process-wide cache of tokenizer/model/predictor bundles, keyed by model ID,
# so constructing several engines does not reload the same weights.
loaded_models = {}

# Per-model settings: display name, matching tokenizer, maximum context length
# and parameter count. Built from a compact row table to keep entries aligned.
MODEL_CONFIG = {
    model_key: {
        'name': display_name,
        'tokenizer': tokenizer_key,
        'context_length': context_len,
        'params': param_count,
    }
    for model_key, display_name, tokenizer_key, context_len, param_count in (
        ('NeoQuasar/Kronos-mini', 'Kronos-Mini', 'NeoQuasar/Kronos-Tokenizer-2k', 2048, '4.1M'),
        ('NeoQuasar/Kronos-small', 'Kronos-Small', 'NeoQuasar/Kronos-Tokenizer-base', 512, '24.7M'),
        ('NeoQuasar/Kronos-base', 'Kronos-Base', 'NeoQuasar/Kronos-Tokenizer-base', 512, '102.3M'),
    )
}
class KronosPredictionEngine:
    """
    Prediction engine for Kronos model.
    Handles model loading, data preparation, and probabilistic forecasting.

    Loaded components are cached in the module-level ``loaded_models`` dict
    keyed by model ID. A cached entry is reused only when it was built for the
    same tokenizer, device and context length; otherwise it is reloaded.
    """

    # Minimum number of historical rows prepare_data() accepts.
    _MIN_LOOKBACK = 50
    # Short histories are padded up to this many rows (capped by the
    # effective context window, so no unused padding is generated).
    _PAD_TARGET = 400

    def __init__(self,
                 tokenizer_id: str = "NeoQuasar/Kronos-Tokenizer-base",
                 model_id: str = "NeoQuasar/Kronos-small",
                 model_path: Optional[str] = None,
                 device: str = "cpu",
                 max_context: int = 512,
                 lookback: int = 400):
        """
        Initialize the prediction engine.

        Args:
            tokenizer_id (str): HuggingFace tokenizer ID. Replaced by the
                configured tokenizer when the model is listed in MODEL_CONFIG.
            model_id (str): HuggingFace model ID. Overridden by model_path.
            model_path (str): Model path (e.g., 'NeoQuasar/Kronos-small').
                Overrides model_id if provided.
            device (str): Device to run on ('cpu', 'cuda', 'mps').
            max_context (int): Maximum context length. Replaced by the
                configured value when the model is listed in MODEL_CONFIG.
            lookback (int): Lookback window for historical data (default: 400).

        Raises:
            RuntimeError: "Out of Memory: ..." when the model does not fit in
                memory. Any other loading error is re-raised unchanged after
                a diagnostic message is printed.
        """
        if model_path:
            model_id = model_path

        # Known models carry their own tokenizer and context length.
        config = MODEL_CONFIG.get(model_id)
        if config is not None:
            tokenizer_id = config['tokenizer']
            max_context = config['context_length']
            model_name = config['name']
        else:
            model_name = model_id

        print("🤖 Preparing Kronos models...")
        print(f" Model: {model_name} ({model_id})")
        print(f" Tokenizer: {tokenizer_id}")

        self.device = device
        self.lookback = lookback
        self.max_context = max_context  # Store for use in prepare_data truncation
        self.pred_len = 24
        self.model_id = model_id
        self.tokenizer_id = tokenizer_id

        cached = loaded_models.get(model_id)
        if cached is not None and self._cache_matches(cached, tokenizer_id, device, max_context):
            print(" ♻️ Using cached model instance...")
            self.tokenizer = cached['tokenizer']
            self.model = cached['model']
            self.predictor = cached['predictor']
            print("✅ Models loaded from cache")
            return

        print(" 📥 Loading model from HuggingFace (this may take a minute)...")
        try:
            tokenizer, model, predictor = self._load_components(
                model_id, tokenizer_id, model_name, device, max_context
            )
        except Exception as e:
            # The OOM path already printed its own guidance.
            if 'Out of Memory' not in str(e):
                print(f"❌ Failed to load models: {str(e)}")
            raise

        # Record the settings the predictor was built with so a later engine
        # asking for a different device/tokenizer/context does not reuse it.
        loaded_models[model_id] = {
            'tokenizer': tokenizer,
            'model': model,
            'predictor': predictor,
            'tokenizer_id': tokenizer_id,
            'device': device,
            'max_context': max_context,
        }
        self.tokenizer = tokenizer
        self.model = model
        self.predictor = predictor
        print(f"✅ Models loaded successfully on {device}")

    @staticmethod
    def _cache_matches(entry: Dict, tokenizer_id: str, device: str, max_context: int) -> bool:
        """Return True when a cached entry was built with exactly these settings."""
        return (entry.get('tokenizer_id') == tokenizer_id
                and entry.get('device') == device
                and entry.get('max_context') == max_context)

    @staticmethod
    def _load_components(model_id: str, tokenizer_id: str, model_name: str,
                         device: str, max_context: int):
        """Load tokenizer and model and wrap them in a KronosPredictor.

        Returns:
            tuple: (tokenizer, model, predictor)

        Raises:
            RuntimeError: "Out of Memory: ..." if loading or device placement
                runs out of memory; other errors propagate unchanged.
        """
        tokenizer = KronosTokenizer.from_pretrained(tokenizer_id)
        try:
            model = Kronos.from_pretrained(model_id)
            # The predictor receives `device`, so device placement (and a
            # possible GPU OOM) presumably happens here; guard it as well.
            predictor = KronosPredictor(
                model,
                tokenizer,
                device=device,
                max_context=max_context
            )
        except (RuntimeError, MemoryError) as e:
            is_oom = isinstance(e, MemoryError) or 'out of memory' in str(e).lower()
            if not is_oom:
                raise
            print(f"❌ Out of Memory Error: The {model_name} model is too large for your system.")
            print(" 💡 Try a smaller model:")
            print(" - NeoQuasar/Kronos-mini (4.1M) - Most memory efficient")
            print(" - NeoQuasar/Kronos-small (24.7M) - Balanced")
            if device == 'cuda':
                print(" 💡 Or switch to CPU mode (slower but uses less GPU memory)")
            raise RuntimeError(
                f"Out of Memory: {model_name} is too large. Try a smaller model (Kronos-mini or Kronos-small) "
                f"or switch to CPU device."
            ) from e
        return tokenizer, model, predictor

    @staticmethod
    def _infer_interval(timestamps: pd.Series) -> pd.Timedelta:
        """Return the smallest positive gap between consecutive timestamps.

        The minimum is used so overnight/weekend gaps do not skew the
        inferred bar frequency. Falls back to one hour when no positive gap
        exists (single row, or duplicate/unsorted timestamps).
        """
        diffs = timestamps.diff().dropna()
        positive = diffs[diffs > pd.Timedelta(0)]
        return positive.min() if len(positive) > 0 else pd.Timedelta(hours=1)

    def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series, pd.DatetimeIndex]:
        """
        Prepare data for prediction.

        Histories shorter than the context window (at most 400 rows) are padded
        by repeating the earliest row, then the last
        min(lookback, max_context) rows are used as model input.

        Args:
            df (pd.DataFrame): Input data with columns: timestamps, open, high, low, close, volume

        Returns:
            Tuple[pd.DataFrame, pd.Series, pd.DatetimeIndex]: (x_df, x_timestamp, y_timestamp)
            where y_timestamp holds pred_len future stamps starting one
            interval after the last input row.

        Raises:
            ValueError: If df has fewer than 50 rows, or the lookback window
                (min(lookback, max_context)) is smaller than 50.
        """
        min_lookback = self._MIN_LOOKBACK
        if len(df) < min_lookback:
            raise ValueError(
                f"Insufficient data: need at least {min_lookback} rows, got {len(df)}"
            )

        # Effective context: never more than requested or than the model attends to.
        window = min(self.lookback, self.max_context)
        if window < min_lookback:
            raise ValueError(
                f"Lookback window too small: min(lookback={self.lookback}, "
                f"max_context={self.max_context}) = {window}, need at least {min_lookback}"
            )

        # Pad short histories only up to what will actually be fed to the model.
        pad_target = min(self._PAD_TARGET, window)
        if len(df) < pad_target:
            print(f"⚠️ Data has {len(df)} rows, padding to {pad_target}...")
            df = self._pad_dataframe(df, pad_target)
            print(f"✅ DataFrame padded to {len(df)} rows")

        # All rows are history: the forecast starts after the last row, so
        # nothing needs to be held out for the prediction horizon.
        truncate_to = min(window, len(df))
        x_df = df[['open', 'high', 'low', 'close', 'volume']].iloc[-truncate_to:].copy()
        x_timestamp = df['timestamps'].iloc[-truncate_to:].copy()

        time_diff = self._infer_interval(df['timestamps'])
        y_timestamp = pd.date_range(
            start=df['timestamps'].iloc[-1] + time_diff,
            periods=self.pred_len,
            freq=time_diff
        )
        return x_df, x_timestamp, y_timestamp

    def _pad_dataframe(self, df: pd.DataFrame, target_rows: int = 400) -> pd.DataFrame:
        """
        Pad DataFrame to target_rows by duplicating the earliest row.

        Padded rows copy every column of the earliest row, and get timestamps
        stepping backwards from it by the inferred bar interval.

        Args:
            df (pd.DataFrame): Original DataFrame (must have a 'timestamps' column)
            target_rows (int): Target number of rows

        Returns:
            pd.DataFrame: Padded DataFrame with a fresh RangeIndex, or df
            unchanged if it already has at least target_rows rows.

        Raises:
            ValueError: If df is empty (there is no row to duplicate).
        """
        if len(df) >= target_rows:
            return df
        if df.empty:
            raise ValueError("Cannot pad an empty DataFrame")

        rows_needed = target_rows - len(df)
        # Min positive gap, so duplicate leading timestamps cannot produce a zero step.
        step = self._infer_interval(df['timestamps'])
        first_ts = df['timestamps'].iloc[0]

        # Selecting the first row repeatedly keeps the original column dtypes.
        padding = df.iloc[[0] * rows_needed].reset_index(drop=True)
        padding['timestamps'] = pd.Series(
            pd.date_range(end=first_ts - step, periods=rows_needed, freq=step)
        )
        return pd.concat([padding, df], ignore_index=True)

    def predict(self,
                df: pd.DataFrame,
                sample_count: int = 30,
                temperature: float = 1.0,
                top_p: float = 0.9) -> Dict:
        """
        Generate probabilistic predictions.

        Args:
            df (pd.DataFrame): Historical OHLCV data
            sample_count (int): Number of sample paths (default: 30, must be >= 1)
            temperature (float): Sampling temperature (default: 1.0)
            top_p (float): Nucleus sampling parameter (default: 0.9)

        Returns:
            Dict: Prediction results including mean, std, percentiles, and all samples

        Raises:
            ValueError: If sample_count < 1 or the data is insufficient.
            RuntimeError: If every sample fails.
        """
        if sample_count < 1:
            raise ValueError(f"sample_count must be >= 1, got {sample_count}")

        print(f"\n🔮 Generating {sample_count} sample paths for {self.pred_len}-hour forecast...")
        x_df, x_timestamp, y_timestamp = self.prepare_data(df)

        # The predictor is given Series, not DatetimeIndex. pd.Series(index)
        # keeps the timezone; `.values` would turn tz-aware stamps into naive
        # UTC and misalign them with the (tz-aware) history timestamps.
        if isinstance(x_timestamp, pd.DatetimeIndex):
            x_timestamp = pd.Series(x_timestamp, name='timestamps')
        if isinstance(y_timestamp, pd.DatetimeIndex):
            y_timestamp = pd.Series(y_timestamp, name='timestamps')

        # Each call with sample_count=1 draws an independent stochastic sample.
        # auto_regressive_inference averages internally when sample_count>1, so
        # calling once with sample_count=N would collapse all variance → std=0.
        # We need independent calls to preserve the distribution for confidence intervals.
        predictions_list = []
        report_every = max(1, sample_count // 5)
        print(" Generating samples: ", end="", flush=True)
        for i in range(sample_count):
            if (i + 1) % report_every == 0:
                print(f"{i+1}...", end="", flush=True)
            try:
                pred_df = self.predictor.predict(
                    df=x_df,
                    x_timestamp=x_timestamp,
                    y_timestamp=y_timestamp,
                    pred_len=self.pred_len,
                    T=temperature,
                    top_p=top_p,
                    sample_count=1,
                    verbose=False
                )
            except Exception as e:
                # Best effort: a failed draw is reported and skipped.
                print(f"\n⚠️ Sample {i+1} failed: {str(e)}, skipping...")
                continue
            predictions_list.append(pred_df)
        print("✅")

        if not predictions_list:
            raise RuntimeError("All predictions failed")
        print(f"✅ Successfully generated {len(predictions_list)} samples")
        return self._aggregate_predictions(predictions_list, y_timestamp)

    def _aggregate_predictions(self,
                               predictions_list: list,
                               y_timestamp: pd.Series) -> Dict:
        """
        Aggregate multiple sample predictions into probabilistic forecast.

        Args:
            predictions_list (list): List of prediction DataFrames with identical
                columns; must include open, high, low, close and volume.
            y_timestamp (pd.Series): Future timestamps

        Returns:
            Dict: 'timestamps' (ISO strings), one stats dict per column
            (mean/std/median/q5/q25/q75/q95 over samples), 'samples'
            (column -> array of shape (n_samples, horizon)) and 'summary_df'.
        """
        # Stack per column: rows = samples, columns = forecast steps.
        samples = {
            col: np.array([pred[col].values for pred in predictions_list])
            for col in predictions_list[0].columns
        }

        results = {
            'timestamps': np.array([ts.isoformat() if hasattr(ts, 'isoformat') else str(ts)
                                    for ts in y_timestamp]),
            'samples': {}
        }
        for col, data in samples.items():
            results[col] = {
                'mean': np.mean(data, axis=0),
                'std': np.std(data, axis=0),
                'median': np.median(data, axis=0),
                'q5': np.percentile(data, 5, axis=0),
                'q25': np.percentile(data, 25, axis=0),
                'q75': np.percentile(data, 75, axis=0),
                'q95': np.percentile(data, 95, axis=0),
            }
            results['samples'][col] = data

        results['summary_df'] = pd.DataFrame({
            'timestamps': results['timestamps'],
            'open_mean': results['open']['mean'],
            'open_std': results['open']['std'],
            'high_mean': results['high']['mean'],
            'high_std': results['high']['std'],
            'low_mean': results['low']['mean'],
            'low_std': results['low']['std'],
            'close_mean': results['close']['mean'],
            'close_std': results['close']['std'],
            'close_q5': results['close']['q5'],
            'close_q25': results['close']['q25'],
            'close_q75': results['close']['q75'],
            'close_q95': results['close']['q95'],
            'volume_mean': results['volume']['mean'],
            'volume_std': results['volume']['std'],
        })
        return results

    def print_forecast(self, results: Dict) -> None:
        """
        Print formatted forecast results (close mean, std and 5-95% band per step).

        Args:
            results (Dict): Prediction results from predict()
        """
        df = results['summary_df']
        print("\n📊 Probabilistic Forecast Summary:")
        print("=" * 100)
        print(f"{'Time':<22} {'Close (Mean)':<12} {'±Std':<10} {'[5%, 95%]':<20}")
        print("-" * 100)
        for _, row in df.iterrows():
            ts = str(row['timestamps'])[:16]
            close = row['close_mean']
            std = row['close_std']
            q5 = row['close_q5']
            q95 = row['close_q95']
            print(f"{ts:<22} ${close:>10.2f} ±{std:>8.2f} [{q5:>8.2f}, {q95:>8.2f}]")
        print("=" * 100)
def get_prediction(symbol: str = None,
                   data_path: str = None,
                   periods: int = 500,
                   sample_count: int = 30,
                   temperature: float = 1.0,
                   top_p: float = 0.9,
                   save_results: bool = True,
                   lookback: int = 400,
                   model_path: Optional[str] = None,
                   device: str = "cpu") -> Dict:
    """
    Main function to get prediction for a given ticker symbol or data file.

    Args:
        symbol (str): Stock ticker (e.g., 'AAPL', 'BTC-USD'). Either symbol or data_path required.
        data_path (str): Path to CSV file with OHLCV data (columns: timestamps, open,
            high, low, close, volume). Either symbol or data_path required.
        periods (int): Number of historical periods to use (default: 500). Ignored if data_path provided.
        sample_count (int): Number of sample paths (default: 30)
        temperature (float): Sampling temperature (default: 1.0)
        top_p (float): Nucleus sampling parameter (default: 0.9)
        save_results (bool): Whether to save results to CSV (default: True)
        lookback (int): Lookback window for historical data (default: 400). Auto-adjusted based on data availability.
        model_path (str): Optional model ID/path forwarded to KronosPredictionEngine
            (default: None → the engine's default model).
        device (str): Device forwarded to KronosPredictionEngine (default: 'cpu').

    Returns:
        Dict: Prediction results with mean, std, and confidence intervals

    Raises:
        ValueError: If neither or both of symbol/data_path are given, or the CSV
            lacks required columns.

    Example:
        >>> results = get_prediction(symbol='AAPL')
        >>> results = get_prediction(data_path='examples/data/XSHG_5min_600977.csv', sample_count=30)
        >>> results = get_prediction(symbol='BTC-USD', sample_count=50, lookback=100)
    """
    if not symbol and not data_path:
        raise ValueError("Either 'symbol' or 'data_path' must be provided")
    if symbol and data_path:
        raise ValueError("Provide only one of 'symbol' or 'data_path', not both")

    print("\n🚀 Kronos Prediction Engine")
    print(f"{'='*60}")

    print("\n1️⃣ Loading historical data...")
    try:
        if data_path:
            df = pd.read_csv(data_path)
            # Fail early with a clear message instead of a KeyError deep in prediction.
            required = ['timestamps', 'open', 'high', 'low', 'close', 'volume']
            missing = [col for col in required if col not in df.columns]
            if missing:
                raise ValueError(f"CSV file is missing required columns: {missing}")
            df['timestamps'] = pd.to_datetime(df['timestamps'])
            df = df.sort_values('timestamps').reset_index(drop=True)
            data_source = f"file: {data_path}"
        else:
            df = fetch_hourly_klines(symbol, periods=periods)
            data_source = f"ticker: {symbol}"
        info = get_data_info(df)
        print(f" ✅ Loaded {info['total_rows']} records from {data_source}")
        print(f" 📅 Date range: {info['start_date']} to {info['end_date']}")
        print(f" 💰 Price range: ${info['price_range_min']:.2f} - ${info['price_range_max']:.2f}")
    except Exception as e:
        print(f" ❌ Failed to load data: {str(e)}")
        raise

    print("\n2️⃣ Initializing Kronos prediction engine...")
    try:
        engine = KronosPredictionEngine(model_path=model_path, device=device, lookback=lookback)
    except Exception as e:
        print(f" ❌ Failed to initialize engine: {str(e)}")
        raise

    print("\n3️⃣ Generating probabilistic forecast...")
    try:
        results = engine.predict(
            df,
            sample_count=sample_count,
            temperature=temperature,
            top_p=top_p
        )
    except Exception as e:
        print(f" ❌ Prediction failed: {str(e)}")
        raise

    print("\n4️⃣ Forecast Summary")
    engine.print_forecast(results)

    if save_results:
        print("\n5️⃣ Saving results...")
        raw_name = symbol if symbol else Path(data_path).stem
        # Path separators in a symbol (e.g. "BTC/USDT") would otherwise create
        # nested paths whose parent directory is never created for the .npz file.
        output_name = raw_name.replace('/', '_').replace('\\', '_')
        output_path = Path('predictions') / f"{output_name}_forecast.csv"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        results['summary_df'].to_csv(output_path, index=False)
        print(f" 💾 Results saved to: {output_path}")
        # Also save full sample paths
        samples_path = output_path.parent / f"{output_name}_samples.npz"
        np.savez(samples_path, **results['samples'])
        print(f" 💾 Sample paths saved to: {samples_path}")

    print("\n✅ Prediction complete!")
    print(f"{'='*60}\n")
    return results
if __name__ == "__main__":
    import sys

    # CLI usage: python <script> [SYMBOL] [SAMPLE_COUNT]
    cli_args = sys.argv[1:]
    cli_symbol = cli_args[0].upper() if cli_args else "AAPL"
    cli_samples = int(cli_args[1]) if len(cli_args) > 1 else 30

    try:
        get_prediction(cli_symbol, sample_count=cli_samples)
    except Exception as err:
        # Report a one-line error and signal failure to the shell.
        print(f"\n❌ Error: {str(err)}")
        sys.exit(1)