"""
Nifty 50 ensemble inference script.

Usage:
    python inference.py --data data.csv --model models_v2_1m --horizon 1
"""
| import argparse |
| import pandas as pd |
| import numpy as np |
| from nifty_ensemble_v2 import NiftyEnsembleV2 |
|
|
# Display names for the integer regime ids returned by the pipeline.
# NOTE(review): this id -> name assignment is carried over unchanged from the
# original script; confirm it matches how NiftyEnsembleV2 numbers its regimes.
REGIME_LABELS = {0: 'Bull', 1: 'Bear', 2: 'Transitional', 3: 'High-Vol', 4: 'Low-Vol'}

# Default cut-offs on proba_up for the discrete trading signal.
DEFAULT_UP_THRESHOLD = 0.6
DEFAULT_DOWN_THRESHOLD = 0.4

# Columns written to the output CSV (in this order) when present in the frame.
OUTPUT_COLUMNS = ['Open', 'High', 'Low', 'Close', 'target', 'pred_up',
                  'proba_up', 'regime', 'regime_label', 'signal']


def _regime_label(regime):
    """Return the display name for a regime id, or ``Regime_<id>`` if unknown."""
    return REGIME_LABELS.get(regime, f'Regime_{regime}')


def _classify_signal(proba_up, up_threshold=DEFAULT_UP_THRESHOLD,
                     down_threshold=DEFAULT_DOWN_THRESHOLD):
    """Map up-probabilities to 'STRONG_UP', 'STRONG_DOWN' or 'NEUTRAL'.

    Values strictly above ``up_threshold`` are STRONG_UP, values strictly below
    ``down_threshold`` are STRONG_DOWN, and everything else (including values
    equal to a threshold, and NaN) is NEUTRAL. Returns a NumPy string array
    with one label per input value.
    """
    proba = np.asarray(proba_up, dtype=float)
    return np.where(proba > up_threshold, 'STRONG_UP',
                    np.where(proba < down_threshold, 'STRONG_DOWN', 'NEUTRAL'))


def main():
    """Run the ensemble on a CSV of OHLCV bars and write per-row predictions.

    Reads ``--data`` (indexed by its ``Datetime`` column, parsed as dates),
    loads the model from ``--model``, and writes the OHLC columns (plus
    ``target`` if the pipeline's frame contains it) together with the predicted
    direction, up-probability, regime id/label and a discrete signal to
    ``--output``. A short summary is printed to stdout.
    """
    parser = argparse.ArgumentParser(description='Nifty 50 Directional Prediction')
    parser.add_argument('--data', required=True, help='CSV with OHLCV columns (Datetime,Open,High,Low,Close,Volume)')
    parser.add_argument('--model', required=True, help='Model directory (e.g., models_v2_1m or models_v2_5m)')
    parser.add_argument('--horizon', type=int, default=1, help='Prediction horizon in minutes')
    parser.add_argument('--output', default='predictions.csv', help='Output CSV path')
    parser.add_argument('--up-threshold', type=float, default=DEFAULT_UP_THRESHOLD,
                        help='proba_up above this is labelled STRONG_UP (default: %(default)s)')
    parser.add_argument('--down-threshold', type=float, default=DEFAULT_DOWN_THRESHOLD,
                        help='proba_up below this is labelled STRONG_DOWN (default: %(default)s)')
    args = parser.parse_args()
    if not 0.0 <= args.down_threshold <= args.up_threshold <= 1.0:
        parser.error('thresholds must satisfy 0 <= --down-threshold <= --up-threshold <= 1')

    df = pd.read_csv(args.data, index_col='Datetime', parse_dates=True)
    print(f"[Inference] Loaded {len(df)} rows from {args.data}")

    # NOTE(review): --horizon is never passed to the pipeline -- predict() is
    # called with the data frame only -- so it cannot influence the output.
    # Report that explicitly instead of silently ignoring the flag.
    print(f"[Inference] Requested horizon: {args.horizon} min "
          f"(informational only: not passed to the model)")

    pipe = NiftyEnsembleV2()
    pipe.load(args.model)

    # Assumes 'pred', 'proba' (n_rows x n_classes, column 1 read as P(up)) and
    # 'regime' are row-aligned with results['df'] -- TODO confirm with the
    # NiftyEnsembleV2.predict contract.
    results = pipe.predict(df)
    df_out = results['df'].copy()
    df_out['pred_up'] = results['pred']
    df_out['proba_up'] = results['proba'][:, 1]
    df_out['regime'] = results['regime']
    df_out['regime_label'] = df_out['regime'].map(_regime_label)
    df_out['signal'] = _classify_signal(df_out['proba_up'],
                                        args.up_threshold, args.down_threshold)

    # Only write the columns the pipeline actually produced (e.g. 'target'
    # may be absent).
    available = [c for c in OUTPUT_COLUMNS if c in df_out.columns]
    df_out[available].to_csv(args.output)

    print(f"[Inference] Saved predictions to {args.output}")
    if df_out.empty:
        # Avoid printing nan% / nan summaries for an empty result.
        print("  No rows to summarise.")
        return
    print(f"  Up predictions: {df_out['pred_up'].sum()} ({df_out['pred_up'].mean()*100:.1f}%)")
    print(f"  Mean up-probability: {df_out['proba_up'].mean():.3f}")
    print(f"  Regime distribution:\n{df_out['regime_label'].value_counts().to_string()}")
|
|
if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, not on import.
    main()
|
|