stoxchai-nse-predictor / inference.py
thoutam's picture
Upload inference.py with huggingface_hub
f31795e verified
#!/usr/bin/env python3
"""
Hugging Face compatible inference script for StoxChai NSE Stock Prediction Models
"""
import os
import joblib
import numpy as np
import pandas as pd
from typing import Dict, List, Union, Optional
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO. This runs at import time, so it also
# applies when the module is imported rather than run as a script; per the
# logging docs it is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
class StoxChaiStockPredictor:
    """
    Hugging Face compatible stock price predictor.

    Loads one shared feature scaler (an object exposing ``transform``) and up
    to eight regression models (objects exposing ``predict``) from
    ``model_dir``. It serves single-model predictions and an ensemble average
    computed from one day's worth of 16 engineered features.
    """

    def __init__(self, model_dir: str = "."):
        """
        Initialize the predictor and eagerly load models from disk.

        Args:
            model_dir: Directory containing ``feature_scaler.joblib`` and the
                ``*_model.joblib`` files.

        Raises:
            FileNotFoundError: If ``feature_scaler.joblib`` is missing.
        """
        self.model_dir = model_dir
        # model name -> loaded model object
        self.models: Dict[str, object] = {}
        # model name -> scaler used for that model (currently always the shared one)
        self.scalers: Dict[str, object] = {}
        # Order matters: predict() expects features in exactly this order.
        self.feature_names = [
            'OpnPric', 'HghPric', 'LwPric', 'LastPric', 'PrvsClsgPric',
            'Price_Range', 'Price_Change', 'Price_Change_Pct', 'Volume_Price_Ratio',
            'SMA_5', 'SMA_20', 'Price_Momentum', 'Volume_MA', 'Volume_Ratio',
            'TtlTradgVol', 'TtlTrfVal'
        ]
        self._load_models()

    def _load_models(self):
        """
        Load the shared feature scaler and every model file that exists.

        Missing model files are skipped with a warning. A missing scaler is
        fatal because every model is fed scaled features.

        Raises:
            FileNotFoundError: If ``feature_scaler.joblib`` is not in ``model_dir``.
            Exception: Any error from ``joblib.load`` is logged and re-raised.
        """
        # SECURITY NOTE: joblib.load is pickle-based and can execute arbitrary
        # code. Only point model_dir at files from a trusted source.
        try:
            # Load the single feature scaler used for all models.
            scaler_path = os.path.join(self.model_dir, 'feature_scaler.joblib')
            if not os.path.exists(scaler_path):
                raise FileNotFoundError("feature_scaler.joblib not found")
            self.feature_scaler = joblib.load(scaler_path)
            logger.info("Loaded feature scaler")

            # Traditional ML models; each one is optional.
            model_files = {
                'randomforest': 'randomforest_model.joblib',
                'gradientboosting': 'gradientboosting_model.joblib',
                'linearregression': 'linearregression_model.joblib',
                'ridge': 'ridge_model.joblib',
                'lasso': 'lasso_model.joblib',
                'svr': 'svr_model.joblib',
                'xgboost': 'xgboost_model.joblib',
                'lightgbm': 'lightgbm_model.joblib',
            }
            for model_name, filename in model_files.items():
                model_path = os.path.join(self.model_dir, filename)
                if not os.path.exists(model_path):
                    logger.warning("Missing model file for %s: %s", model_name, filename)
                    continue
                self.models[model_name] = joblib.load(model_path)
                # All models share the same scaler.
                self.scalers[model_name] = self.feature_scaler
                logger.info("Loaded %s model", model_name)

            logger.info("Successfully loaded %d models", len(self.models))
            if not self.models:
                # Not fatal here, but every predict() call will fail.
                logger.warning("No model files found in %s", self.model_dir)
        except Exception as e:
            logger.error("Error loading models: %s", e)
            raise

    def predict(self, features: Union[List, np.ndarray], model_name: str = "randomforest",
                lookback_days: int = 5) -> float:
        """
        Make a prediction using the specified model.

        Args:
            features: One day's features as a flat sequence, in the order
                returned by ``get_feature_names()`` (16 values).
            model_name: Name of a loaded model (see ``get_available_models()``).
            lookback_days: Number of times the day's features are repeated to
                form the model input (default: 5). Must be >= 1.

        Returns:
            The model's prediction as a Python ``float``.

        Raises:
            ValueError: If ``model_name`` is not loaded, ``features`` is not a
                1-D sequence of exactly ``len(feature_names)`` values, or
                ``lookback_days`` < 1.
        """
        if model_name not in self.models:
            raise ValueError(f"Model '{model_name}' not found. Available models: {list(self.models.keys())}")

        n_features = len(self.feature_names)
        features = np.asarray(features, dtype=float)
        if features.shape != (n_features,):
            # Report the whole shape: shape[0] is misleading for 2-D input and
            # raises IndexError for a scalar.
            raise ValueError(
                f"Expected {n_features} features as a 1-D sequence, "
                f"got array of shape {features.shape}"
            )
        if lookback_days < 1:
            raise ValueError(f"lookback_days must be >= 1, got {lookback_days}")

        # Repeat the same day's features lookback_days times to simulate a
        # history window, then flatten to one row of lookback_days * n_features.
        # NOTE(review): presumably the scaler/models were fit on 5 * 16 columns,
        # so other lookback_days values would be rejected downstream — confirm.
        sequence = np.tile(features, (lookback_days, 1))
        flattened_features = sequence.reshape(1, -1)

        scaled_features = self.scalers[model_name].transform(flattened_features)
        prediction = self.models[model_name].predict(scaled_features)
        return float(prediction[0])

    def predict_all_models(self, features: Union[List, np.ndarray],
                           lookback_days: int = 5) -> Dict[str, Optional[float]]:
        """
        Make predictions using all available models.

        Args:
            features: Input features (16 values in ``get_feature_names()`` order).
            lookback_days: Number of days to look back (default: 5).

        Returns:
            Mapping of model name to prediction, or ``None`` for a model that
            raised. When at least one model succeeds, an extra ``'ensemble'``
            key holds the mean of the successful predictions as a ``float``.
        """
        predictions: Dict[str, Optional[float]] = {}
        for model_name in self.models:
            try:
                predictions[model_name] = self.predict(features, model_name, lookback_days)
            except Exception as e:
                # Best-effort: one failing model must not block the others.
                logger.warning("Error with %s: %s", model_name, e)
                predictions[model_name] = None

        successful_predictions = [p for p in predictions.values() if p is not None]
        if successful_predictions:
            # Cast so the ensemble is a plain float like the other entries.
            predictions['ensemble'] = float(np.mean(successful_predictions))
        return predictions

    def get_feature_names(self) -> List[str]:
        """Return a copy of the feature names in the order predict() expects."""
        return self.feature_names.copy()

    def get_available_models(self) -> List[str]:
        """Return the names of the models that were successfully loaded."""
        return list(self.models)

    def get_model_info(self) -> Dict[str, Dict]:
        """Return, per loaded model, its class name, feature count and scaler class name."""
        return {
            model_name: {
                'type': type(model).__name__,
                'features': len(self.feature_names),
                'scaler_type': type(self.feature_scaler).__name__,
            }
            for model_name, model in self.models.items()
        }
def main() -> int:
    """
    Smoke-test the predictor against models in the current directory.

    Loads the available models, prints a RandomForest prediction (if that
    model is present), then prints per-model and ensemble predictions for a
    hard-coded sample feature vector.

    Returns:
        Process exit status: 0 on success, 1 if anything raised.
    """
    try:
        # Initialize predictor
        predictor = StoxChaiStockPredictor()
        print("🎯 StoxChai NSE Stock Price Predictor")
        print("=" * 50)

        # Show available models
        available = predictor.get_available_models()
        print(f"Available models: {', '.join(available)}")
        print(f"Features required: {len(predictor.get_feature_names())}")

        # Example features (you would replace these with real data)
        example_features = [
            100.0,     # OpnPric
            105.0,     # HghPric
            98.0,      # LwPric
            102.0,     # LastPric
            100.0,     # PrvsClsgPric
            7.0,       # Price_Range
            2.0,       # Price_Change
            2.0,       # Price_Change_Pct
            1.5,       # Volume_Price_Ratio
            101.0,     # SMA_5
            100.5,     # SMA_20
            0.01,      # Price_Momentum
            1000.0,    # Volume_MA
            1.2,       # Volume_Ratio
            1200.0,    # TtlTradgVol
            120000.0,  # TtlTrfVal
        ]
        print("\nExample prediction with sample data:")
        print(f"Features: {example_features}")

        # Single model prediction. Models are optional at load time, so only
        # demo RandomForest when it was actually loaded.
        if "randomforest" in available:
            rf_pred = predictor.predict(example_features, "randomforest")
            print(f"RandomForest prediction: ₹{rf_pred:.2f}")
        else:
            print("RandomForest model not available; skipping single-model demo")

        # All models prediction
        all_preds = predictor.predict_all_models(example_features)
        print("\nAll model predictions:")
        for name, value in all_preds.items():
            if value is not None:
                print(f" {name}: ₹{value:.2f}")

        print("\n✅ Inference test completed successfully!")
        return 0
    except Exception as e:
        # Top-level boundary: report to the user and log with traceback.
        print(f"❌ Error: {e}")
        logger.exception("Error in main: %s", e)
        return 1


if __name__ == "__main__":
    # Propagate main()'s status so shell/CI callers can detect failures.
    raise SystemExit(main())