# Uploaded to the Hugging Face Hub by Prateek2001 via huggingface_hub
# (commit 5cbf060, verified).
"""
Nifty 50 Ensemble Inference Script
Usage:
python inference.py --data data.csv --model models_v2_1m --horizon 1
"""
import argparse
import pandas as pd
import numpy as np
from nifty_ensemble_v2 import NiftyEnsembleV2
def _build_parser():
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser(description='Nifty 50 Directional Prediction')
    parser.add_argument('--data', required=True, help='CSV with OHLCV columns (Datetime,Open,High,Low,Close,Volume)')
    parser.add_argument('--model', required=True, help='Model directory (e.g., models_v2_1m or models_v2_5m)')
    # NOTE(review): --horizon is parsed but never forwarded to the pipeline in
    # this script. Confirm whether NiftyEnsembleV2 should receive it, or whether
    # the horizon is fixed by the model directory chosen with --model.
    parser.add_argument('--horizon', type=int, default=1, help='Prediction horizon in minutes')
    parser.add_argument('--output', default='predictions.csv', help='Output CSV path')
    parser.add_argument('--up-threshold', type=float, default=0.6,
                        help='proba_up strictly above this is labelled STRONG_UP (default: 0.6)')
    parser.add_argument('--down-threshold', type=float, default=0.4,
                        help='proba_up strictly below this is labelled STRONG_DOWN (default: 0.4)')
    return parser


def _annotate_predictions(results, up_threshold, down_threshold):
    """Return a copy of ``results['df']`` with prediction, regime and signal columns.

    Args:
        results: Mapping returned by ``NiftyEnsembleV2.predict``; this function
            reads its ``'df'``, ``'pred'``, ``'proba'`` and ``'regime'`` entries.
        up_threshold: ``proba_up`` strictly above this gives ``STRONG_UP``.
        down_threshold: ``proba_up`` strictly below this gives ``STRONG_DOWN``.

    Returns:
        A new DataFrame with added columns ``pred_up``, ``proba_up``,
        ``regime``, ``regime_label`` and ``signal``.
    """
    df_out = results['df'].copy()
    df_out['pred_up'] = results['pred']
    # np.asarray lets any 2-D array-like (ndarray, DataFrame, nested lists) be
    # sliced positionally. Column 1 is treated as P(up); this assumes the
    # model's class order is [down, up] -- TODO confirm against the model.
    df_out['proba_up'] = np.asarray(results['proba'])[:, 1]
    df_out['regime'] = results['regime']

    # Regime labels; ids outside this table fall back to 'Regime_<id>'.
    regime_labels = {0: 'Bull', 1: 'Bear', 2: 'Transitional', 3: 'High-Vol', 4: 'Low-Vol'}
    df_out['regime_label'] = df_out['regime'].map(lambda x: regime_labels.get(x, f'Regime_{x}'))

    # Signal strength. The STRONG_UP test is evaluated first, so it wins if both
    # conditions hold. NaN probabilities satisfy neither and become NEUTRAL.
    df_out['signal'] = np.where(df_out['proba_up'] > up_threshold, 'STRONG_UP',
                                np.where(df_out['proba_up'] < down_threshold, 'STRONG_DOWN', 'NEUTRAL'))
    return df_out


def _print_summary(df_out):
    """Print the up-prediction count/share, mean up-probability and regime counts."""
    print(f" Up predictions: {df_out['pred_up'].sum()} ({df_out['pred_up'].mean()*100:.1f}%)")
    print(f" Mean up-probability: {df_out['proba_up'].mean():.3f}")
    print(f" Regime distribution:\n{df_out['regime_label'].value_counts().to_string()}")


def main():
    """Run the Nifty 50 ensemble on a CSV of bars and save the predictions.

    Reads command-line arguments (see ``--help``), loads ``--data`` with its
    ``Datetime`` column parsed as the index, and loads the model from
    ``--model``. It then predicts, writes the selected output columns to
    ``--output``, and prints a short summary.

    Exits through ``argparse`` (status 2) if the thresholds do not satisfy
    ``0 <= --down-threshold <= --up-threshold <= 1`` or the input CSV
    contains no data.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # A NaN threshold also fails this chained comparison, which is intended.
    if not 0.0 <= args.down_threshold <= args.up_threshold <= 1.0:
        parser.error('thresholds must satisfy 0 <= --down-threshold <= --up-threshold <= 1')

    # Load data
    df = pd.read_csv(args.data, index_col='Datetime', parse_dates=True)
    print(f"[Inference] Loaded {len(df)} rows from {args.data}")
    if df.empty:
        # Stop here rather than hand an empty frame to the model and then
        # print NaN summary statistics.
        parser.error(f'no data found in {args.data}')

    # Load model
    pipe = NiftyEnsembleV2()
    pipe.load(args.model)

    # Predict
    results = pipe.predict(df)
    df_out = _annotate_predictions(results, args.up_threshold, args.down_threshold)

    # Keep relevant columns; any not present in the pipeline's frame are skipped.
    out_cols = ['Open', 'High', 'Low', 'Close', 'target', 'pred_up', 'proba_up',
                'regime', 'regime_label', 'signal']
    available = [c for c in out_cols if c in df_out.columns]
    df_out[available].to_csv(args.output)
    print(f"[Inference] Saved predictions to {args.output}")
    _print_summary(df_out)
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()