File size: 20,160 Bytes
d3be94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
"""
Kronos Prediction Engine
Performs autoregressive financial time series prediction with probabilistic forecasts.
"""

import pandas as pd
import numpy as np
import torch
from typing import Dict, Tuple, Optional
from pathlib import Path
import warnings

from model import Kronos, KronosTokenizer, KronosPredictor
from data_fetcher import fetch_hourly_klines, get_data_info

# Silence every Python warning so console output stays readable
warnings.filterwarnings('ignore')

# Process-wide cache of already-loaded models, keyed by model ID, so that
# constructing several engines does not reload the same weights
loaded_models = {}

# Known Kronos checkpoints: display name, matching tokenizer repo,
# context length and parameter count for each model ID
MODEL_CONFIG = {
    repo_id: {
        'name': display_name,
        'tokenizer': tokenizer_repo,
        'context_length': context_length,
        'params': param_count,
    }
    for repo_id, display_name, tokenizer_repo, context_length, param_count in (
        ('NeoQuasar/Kronos-mini', 'Kronos-Mini', 'NeoQuasar/Kronos-Tokenizer-2k', 2048, '4.1M'),
        ('NeoQuasar/Kronos-small', 'Kronos-Small', 'NeoQuasar/Kronos-Tokenizer-base', 512, '24.7M'),
        ('NeoQuasar/Kronos-base', 'Kronos-Base', 'NeoQuasar/Kronos-Tokenizer-base', 512, '102.3M'),
    )
}


class KronosPredictionEngine:
    """
    Prediction engine for Kronos model.
    Handles model loading, data preparation, and probabilistic forecasting.
    """
    
    def __init__(self, 
                 tokenizer_id: str = "NeoQuasar/Kronos-Tokenizer-base",
                 model_id: str = "NeoQuasar/Kronos-small",
                 model_path: Optional[str] = None,
                 device: str = "cpu",
                 max_context: int = 512,
                 lookback: int = 400):
        """
        Initialize the prediction engine.
        
        Args:
            tokenizer_id (str): HuggingFace tokenizer model ID (deprecated if model_path provided)
            model_id (str): HuggingFace model ID (deprecated if model_path provided)
            model_path (str): Model path (e.g., 'NeoQuasar/Kronos-small'). Overrides model_id if provided.
            device (str): Device to run on ('cpu', 'cuda', 'mps')
            max_context (int): Maximum context length for the model
            lookback (int): Lookback window for historical data (default: 400)
        """
        # Use model_path if provided, otherwise use model_id
        if model_path:
            model_id = model_path
        
        # Get model configuration
        if model_id in MODEL_CONFIG:
            config = MODEL_CONFIG[model_id]
            tokenizer_id = config['tokenizer']
            max_context = config['context_length']
            model_name = config['name']
        else:
            model_name = model_id
        
        print(f"🤖 Preparing Kronos models...")
        print(f"   Model: {model_name} ({model_id})")
        print(f"   Tokenizer: {tokenizer_id}")
        
        self.device = device
        self.lookback = lookback
        self.max_context = max_context  # Store for use in prepare_data truncation
        self.pred_len = 24
        self.model_id = model_id
        self.tokenizer_id = tokenizer_id
        
        try:
            # Check if model is already loaded
            if model_id in loaded_models:
                print(f"   ♻️  Using cached model instance...")
                cached = loaded_models[model_id]
                self.tokenizer = cached['tokenizer']
                self.model = cached['model']
                self.predictor = cached['predictor']
                print(f"✅ Models loaded from cache")
            else:
                print(f"   📥 Loading model from HuggingFace (this may take a minute)...")
                # Load tokenizer
                tokenizer = KronosTokenizer.from_pretrained(tokenizer_id)
                
                # Load model with OOM error handling
                try:
                    model = Kronos.from_pretrained(model_id)
                except RuntimeError as e:
                    if 'out of memory' in str(e).lower() or 'cuda out of memory' in str(e).lower():
                        print(f"❌ Out of Memory Error: The {model_name} model is too large for your system.")
                        print(f"   💡 Try a smaller model:")
                        print(f"      - NeoQuasar/Kronos-mini (4.1M) - Most memory efficient")
                        print(f"      - NeoQuasar/Kronos-small (24.7M) - Balanced")
                        if device == 'cuda':
                            print(f"   💡 Or switch to CPU mode (slower but uses less GPU memory)")
                        raise RuntimeError(
                            f"Out of Memory: {model_name} is too large. Try a smaller model (Kronos-mini or Kronos-small) "
                            f"or switch to CPU device."
                        )
                    else:
                        raise
                
                # Create predictor
                predictor = KronosPredictor(
                    model,
                    tokenizer,
                    device=device,
                    max_context=max_context
                )
                
                # Cache the loaded models
                loaded_models[model_id] = {
                    'tokenizer': tokenizer,
                    'model': model,
                    'predictor': predictor
                }
                
                self.tokenizer = tokenizer
                self.model = model
                self.predictor = predictor
                print(f"✅ Models loaded successfully on {device}")
                
        except RuntimeError as e:
            if 'Out of Memory' in str(e):
                raise e
            print(f"❌ Failed to load models: {str(e)}")
            raise
    
    def prepare_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series, pd.Series]:
        """
        Prepare data for prediction.
        Automatically pads DataFrame to 400 rows if insufficient data.
        
        Args:
            df (pd.DataFrame): Input data with columns: timestamps, open, high, low, close, volume
        
        Returns:
            Tuple[pd.DataFrame, pd.Series, pd.Series]: (x_df, x_timestamp, y_timestamp)
        """
        min_lookback = 50  # Minimum data points for model to work
        target_lookback = 400  # Target context window
        
        if len(df) < min_lookback:
            raise ValueError(
                f"Insufficient data: need at least {min_lookback} rows, got {len(df)}"
            )
        
        # Pad DataFrame to target_lookback if insufficient
        if len(df) < target_lookback:
            print(f"⚠️  Data has {len(df)} rows, padding to {target_lookback}...")
            df = self._pad_dataframe(df, target_lookback)
            print(f"✅ DataFrame padded to {len(df)} rows")
        
        # Truncate to max_context (384 tokens) — the model only attends to this window anyway.
        # Using fewer tokens dramatically speeds up the attention computation.
        truncate_to = min(self.lookback, self.max_context, len(df) - self.pred_len)
        
        if truncate_to < min_lookback:
            raise ValueError(
                f"Insufficient data: need at least {min_lookback + self.pred_len} rows for lookback + prediction, got {len(df)}"
            )
        
        # Use last truncate_to points as input
        x_df = df[['open', 'high', 'low', 'close', 'volume']].iloc[-truncate_to:].copy()
        x_timestamp = df['timestamps'].iloc[-truncate_to:].copy()
        
        # Generate future timestamps for prediction
        last_timestamp = df['timestamps'].iloc[-1]
        if len(df) > 1:
            # Use the minimum positive time diff across all rows to avoid
            # overnight/weekend gaps skewing the forecast frequency
            all_diffs = df['timestamps'].diff().dropna()
            positive_diffs = all_diffs[all_diffs > pd.Timedelta(0)]
            time_diff = positive_diffs.min() if len(positive_diffs) > 0 else pd.Timedelta(hours=1)
        else:
            time_diff = pd.Timedelta(hours=1)
        
        y_timestamp = pd.date_range(
            start=last_timestamp + time_diff,
            periods=self.pred_len,
            freq=time_diff
        )
        
        return x_df, x_timestamp, y_timestamp
    
    def _pad_dataframe(self, df: pd.DataFrame, target_rows: int = 400) -> pd.DataFrame:
        """
        Pad DataFrame to target_rows by duplicating the earliest row.
        
        Args:
            df (pd.DataFrame): Original DataFrame
            target_rows (int): Target number of rows
        
        Returns:
            pd.DataFrame: Padded DataFrame
        """
        if len(df) >= target_rows:
            return df
        
        rows_needed = target_rows - len(df)
        
        # Get the earliest row for padding
        earliest_row = df.iloc[0].copy()
        
        # Calculate timestamp interval
        if len(df) > 1:
            time_diff = df.iloc[1]['timestamps'] - df.iloc[0]['timestamps']
        else:
            time_diff = pd.Timedelta(hours=1)
        
        # Create padding rows
        padding_rows = []
        for i in range(rows_needed):
            padded_row = earliest_row.copy()
            padded_row['timestamps'] = earliest_row['timestamps'] - (time_diff * (rows_needed - i))
            padding_rows.append(padded_row)
        
        # Combine padding with original data
        padding_df = pd.DataFrame(padding_rows)
        result = pd.concat([padding_df, df], ignore_index=True)
        
        return result
    
    def predict(self, 
                df: pd.DataFrame,
                sample_count: int = 30,
                temperature: float = 1.0,
                top_p: float = 0.9) -> Dict:
        """
        Generate probabilistic predictions.
        
        Args:
            df (pd.DataFrame): Historical OHLCV data
            sample_count (int): Number of sample paths (default: 30)
            temperature (float): Sampling temperature (default: 1.0)
            top_p (float): Nucleus sampling parameter (default: 0.9)
        
        Returns:
            Dict: Prediction results including mean, std, percentiles, and all samples
        """
        print(f"\n🔮 Generating {sample_count} sample paths for {self.pred_len}-hour forecast...")
        
        # Prepare data
        x_df, x_timestamp, y_timestamp = self.prepare_data(df)
        
        # Ensure timestamps are Series, not DatetimeIndex
        if isinstance(x_timestamp, pd.DatetimeIndex):
            x_timestamp = pd.Series(x_timestamp.values, name='timestamps')
        if isinstance(y_timestamp, pd.DatetimeIndex):
            y_timestamp = pd.Series(y_timestamp.values, name='timestamps')
        
        # Each call with sample_count=1 draws an independent stochastic sample.
        # auto_regressive_inference averages internally when sample_count>1, so
        # calling once with sample_count=N would collapse all variance → std=0.
        # We need independent calls to preserve the distribution for confidence intervals.
        predictions_list = []
        print(f"   Generating samples: ", end="", flush=True)
        for i in range(sample_count):
            if (i + 1) % max(1, sample_count // 5) == 0:
                print(f"{i+1}...", end="", flush=True)
            try:
                pred_df = self.predictor.predict(
                    df=x_df,
                    x_timestamp=x_timestamp,
                    y_timestamp=y_timestamp,
                    pred_len=self.pred_len,
                    T=temperature,
                    top_p=top_p,
                    sample_count=1,
                    verbose=False
                )
                predictions_list.append(pred_df)
            except Exception as e:
                print(f"\n⚠️  Sample {i+1} failed: {str(e)}, skipping...")
                continue
        print("✅")
        
        if not predictions_list:
            raise RuntimeError("All predictions failed")
        
        print(f"✅ Successfully generated {len(predictions_list)} samples")
        results = self._aggregate_predictions(predictions_list, y_timestamp)
        
        return results
    
    def _aggregate_predictions(self, 
                               predictions_list: list,
                               y_timestamp: pd.Series) -> Dict:
        """
        Aggregate multiple sample predictions into probabilistic forecast.
        
        Args:
            predictions_list (list): List of prediction DataFrames
            y_timestamp (pd.Series): Future timestamps
        
        Returns:
            Dict: Aggregated statistics and forecasts
        """
        # Stack all predictions
        samples = {}
        for col in predictions_list[0].columns:
            samples[col] = np.array([pred[col].values for pred in predictions_list])
        
        # Calculate statistics
        results = {
            'timestamps': np.array([ts.isoformat() if hasattr(ts, 'isoformat') else str(ts) 
                                   for ts in y_timestamp]),
            'samples': {}
        }
        
        for col in samples.keys():
            data = samples[col]
            
            results[col] = {
                'mean': np.mean(data, axis=0),
                'std': np.std(data, axis=0),
                'median': np.median(data, axis=0),
                'q5': np.percentile(data, 5, axis=0),    # 5th percentile
                'q25': np.percentile(data, 25, axis=0),  # 25th percentile
                'q75': np.percentile(data, 75, axis=0),  # 75th percentile
                'q95': np.percentile(data, 95, axis=0),  # 95th percentile
            }
            
            results['samples'][col] = data
        
        # Create summary DataFrame
        summary_df = pd.DataFrame({
            'timestamps': results['timestamps'],
            'open_mean': results['open']['mean'],
            'open_std': results['open']['std'],
            'high_mean': results['high']['mean'],
            'high_std': results['high']['std'],
            'low_mean': results['low']['mean'],
            'low_std': results['low']['std'],
            'close_mean': results['close']['mean'],
            'close_std': results['close']['std'],
            'close_q5': results['close']['q5'],
            'close_q25': results['close']['q25'],
            'close_q75': results['close']['q75'],
            'close_q95': results['close']['q95'],
            'volume_mean': results['volume']['mean'],
            'volume_std': results['volume']['std'],
        })
        
        results['summary_df'] = summary_df
        
        return results
    
    def print_forecast(self, results: Dict) -> None:
        """
        Print formatted forecast results.
        
        Args:
            results (Dict): Prediction results from predict()
        """
        df = results['summary_df']
        
        print("\n📊 Probabilistic Forecast Summary:")
        print("=" * 100)
        print(f"{'Time':<22} {'Close (Mean)':<12} {'±Std':<10} {'[5%, 95%]':<20}")
        print("-" * 100)
        
        for idx, row in df.iterrows():
            ts = row['timestamps'][:16] if isinstance(row['timestamps'], str) else str(row['timestamps'])[:16]
            close = row['close_mean']
            std = row['close_std']
            q5 = row['close_q5']
            q95 = row['close_q95']
            
            print(f"{ts:<22} ${close:>10.2f} ±{std:>8.2f} [{q5:>8.2f}, {q95:>8.2f}]")
        
        print("=" * 100)


def get_prediction(symbol: str = None,
                   data_path: str = None,
                   periods: int = 500,
                   sample_count: int = 30,
                   temperature: float = 1.0,
                   top_p: float = 0.9,
                   save_results: bool = True,
                   lookback: int = 400) -> Dict:
    """
    Run the full pipeline: load data, build the engine, forecast, report, save.

    Exactly one of ``symbol`` and ``data_path`` must be given. Progress is
    printed for each stage; a failure in any stage is printed and re-raised.

    Args:
        symbol (str): Stock ticker (e.g., 'AAPL', 'BTC-USD'), fetched via fetch_hourly_klines.
        data_path (str): Path to a CSV file with OHLCV data and a 'timestamps' column.
        periods (int): Number of historical periods to fetch (default: 500). Ignored if data_path provided.
        sample_count (int): Number of sample paths (default: 30)
        temperature (float): Sampling temperature (default: 1.0)
        top_p (float): Nucleus sampling parameter (default: 0.9)
        save_results (bool): Whether to write the summary CSV and sample-path
            archive under ``predictions/`` (default: True)
        lookback (int): Lookback window passed to the engine (default: 400)

    Returns:
        Dict: Prediction results with mean, std, and confidence intervals

    Raises:
        ValueError: If neither or both of ``symbol`` and ``data_path`` are provided.

    Example:
        >>> results = get_prediction(symbol='AAPL')
        >>> results = get_prediction(data_path='examples/data/XSHG_5min_600977.csv', sample_count=30)
        >>> results = get_prediction(symbol='BTC-USD', sample_count=50, lookback=100)
    """
    has_symbol, has_path = bool(symbol), bool(data_path)
    if not (has_symbol or has_path):
        raise ValueError("Either 'symbol' or 'data_path' must be provided")
    if has_symbol and has_path:
        raise ValueError("Provide only one of 'symbol' or 'data_path', not both")

    divider = "=" * 60
    print("\n🚀 Kronos Prediction Engine")
    print(divider)

    # Stage 1: obtain the historical frame
    print("\n1️⃣  Loading historical data...")
    try:
        if not data_path:
            # Fetch from yfinance
            history = fetch_hourly_klines(symbol, periods=periods)
            data_source = f"ticker: {symbol}"
        else:
            # Read the CSV and order it chronologically
            history = pd.read_csv(data_path)
            history['timestamps'] = pd.to_datetime(history['timestamps'])
            history = history.sort_values('timestamps').reset_index(drop=True)
            data_source = f"file: {data_path}"

        info = get_data_info(history)
        print(f"   ✅ Loaded {info['total_rows']} records from {data_source}")
        print(f"   📅 Date range: {info['start_date']} to {info['end_date']}")
        print(f"   💰 Price range: ${info['price_range_min']:.2f} - ${info['price_range_max']:.2f}")
    except Exception as load_error:
        print(f"   ❌ Failed to load data: {load_error}")
        raise

    # Stage 2: build the engine with the requested lookback
    print("\n2️⃣  Initializing Kronos prediction engine...")
    try:
        engine = KronosPredictionEngine(lookback=lookback)
    except Exception as init_error:
        print(f"   ❌ Failed to initialize engine: {init_error}")
        raise

    # Stage 3: sample the forecast
    print("\n3️⃣  Generating probabilistic forecast...")
    try:
        results = engine.predict(
            history,
            sample_count=sample_count,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as predict_error:
        print(f"   ❌ Prediction failed: {predict_error}")
        raise

    # Stage 4: console report
    print("\n4️⃣  Forecast Summary")
    engine.print_forecast(results)

    # Stage 5: optional persistence
    if save_results:
        print("\n5️⃣  Saving results...")
        output_name = Path(data_path).stem if not symbol else symbol
        output_dir = Path('predictions')
        output_path = output_dir / f"{output_name}_forecast.csv"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        results['summary_df'].to_csv(output_path, index=False)
        print(f"   💾 Results saved to: {output_path}")

        # Full sample paths go next to the summary
        samples_path = output_path.parent / f"{output_name}_samples.npz"
        np.savez(samples_path, **results['samples'])
        print(f"   💾 Sample paths saved to: {samples_path}")

    print("\n✅ Prediction complete!")
    print(divider + "\n")

    return results


if __name__ == "__main__":
    import sys
    
    # Get symbol from command line or use default
    symbol = sys.argv[1].upper() if len(sys.argv) > 1 else "AAPL"
    sample_count = int(sys.argv[2]) if len(sys.argv) > 2 else 30
    
    try:
        results = get_prediction(symbol, sample_count=sample_count)
    except Exception as e:
        print(f"\n❌ Error: {str(e)}")
        sys.exit(1)