# thoutam's picture
# Add app.py
# ce0d6d0 verified
#!/usr/bin/env python3
"""
Gradio App for StoxChai NSE Stock Prediction Models
This will be deployed on Hugging Face Spaces for interactive testing
"""
import gradio as gr
import numpy as np
import pandas as pd
import joblib
import os
from typing import Dict, List, Tuple
import warnings
# Suppress all Python warnings for the whole process.
# NOTE(review): with no category/module filter this also hides warnings
# emitted by third-party libraries loaded below — confirm that is intended.
warnings.filterwarnings("ignore")

# Module-level state filled in by load_models():
#   models         -- display name -> loaded estimator object
#   feature_scaler -- transformer applied to the flattened input window
models: Dict[str, object] = {}
feature_scaler = None

# Names of the 16 input features. Order matters: values are passed to the
# models positionally, in exactly this sequence.
feature_names: List[str] = [
    "OpnPric",
    "HghPric",
    "LwPric",
    "LastPric",
    "PrvsClsgPric",
    "Price_Range",
    "Price_Change",
    "Price_Change_Pct",
    "Volume_Price_Ratio",
    "SMA_5",
    "SMA_20",
    "Price_Momentum",
    "Volume_MA",
    "Volume_Ratio",
    "TtlTradgVol",
    "TtlTrfVal",
]
def load_models(model_dir: str = ".") -> bool:
    """Load the feature scaler and every available trained model from disk.

    Populates the module-level ``feature_scaler`` and ``models`` globals.
    Each artifact is loaded in its own ``try`` so that one missing or corrupt
    file does not stop the remaining models from loading.

    Args:
        model_dir: Directory holding the ``*.joblib`` artifacts. Defaults to
            the current working directory (the original lookup location).

    Returns:
        True only when the scaler AND at least one model are loaded, i.e.
        when ``predict_stock_price`` can actually produce a prediction;
        False otherwise.
    """
    global models, feature_scaler

    # Load the feature scaler; predict_stock_price refuses to run without it.
    scaler_path = os.path.join(model_dir, 'feature_scaler.joblib')
    if os.path.exists(scaler_path):
        try:
            feature_scaler = joblib.load(scaler_path)
            print("βœ… Feature scaler loaded")
        except Exception as e:
            print(f"❌ Error loading feature scaler: {e}")
    else:
        print(f"❌ Feature scaler not found: {scaler_path}")

    # Display name -> artifact filename. Missing files are simply skipped so
    # only the models that were exported are offered in the UI.
    model_files = {
        'RandomForest': 'randomforest_model.joblib',
        'GradientBoosting': 'gradientboosting_model.joblib',
        'LinearRegression': 'linearregression_model.joblib',
        'Ridge': 'ridge_model.joblib',
        'Lasso': 'lasso_model.joblib',
        'SVR': 'svr_model.joblib',
        'XGBoost': 'xgboost_model.joblib',
        'LightGBM': 'lightgbm_model.joblib'
    }
    for model_name, filename in model_files.items():
        model_path = os.path.join(model_dir, filename)
        if not os.path.exists(model_path):
            continue
        try:
            models[model_name] = joblib.load(model_path)
            print(f"βœ… {model_name} model loaded")
        except Exception as e:
            # Skip this artifact but keep loading the others.
            print(f"❌ Error loading {model_name} model: {e}")

    print(f"🎯 Successfully loaded {len(models)} models")
    return feature_scaler is not None and bool(models)
def predict_stock_price(features: List[float], model_name: str = "RandomForest") -> Tuple[str, Dict]:
    """Predict a price with one model and collect every loaded model's output.

    The single day of features is repeated over the lookback window the
    models were trained with, flattened, scaled with ``feature_scaler`` and
    passed to each model in ``models``.

    Args:
        features: One value per entry of ``feature_names``, in that order.
        model_name: Key into ``models`` used for the headline prediction.

    Returns:
        ``(message, predictions)``. ``message`` is a user-facing status line.
        ``predictions`` maps every loaded model name to its prediction
        (``None`` if that model failed), plus an ``'Ensemble'`` entry holding
        the mean of the successful predictions. On any error, ``predictions``
        is an empty dict and ``message`` describes the problem.
    """
    # Training used a 5-day lookback: 5 stacked rows of 16 features = 80 inputs.
    lookback_days = 5
    try:
        if not models or feature_scaler is None:
            return "❌ Models not loaded. Please check the model files.", {}
        if model_name not in models:
            return f"❌ Model '{model_name}' not found.", {}

        features = np.asarray(features, dtype=float)
        expected = len(feature_names)
        # Reject nested / scalar / wrongly sized input before it reaches the scaler.
        if features.ndim != 1 or features.size != expected:
            return f"❌ Expected {expected} features, got {features.size}", {}
        if not np.all(np.isfinite(features)):
            return "❌ All features must be finite numbers.", {}

        # Only one day is available from the UI, so it is repeated for every
        # lookback step.
        window = np.tile(features, (lookback_days, 1))  # (lookback_days, n_features)
        scaled_features = feature_scaler.transform(window.reshape(1, -1))

        # Headline prediction: a failure here is reported to the user via the
        # outer handler rather than being recorded as None.
        prediction = models[model_name].predict(scaled_features)[0]

        # Comparison table in load order; reuse the headline result instead
        # of predicting with the selected model a second time.
        all_predictions = {}
        for name, model_obj in models.items():
            if name == model_name:
                all_predictions[name] = prediction
                continue
            try:
                all_predictions[name] = model_obj.predict(scaled_features)[0]
            except Exception as e:
                print(f"❌ {name} prediction failed: {e}")
                all_predictions[name] = None

        successful_predictions = [p for p in all_predictions.values() if p is not None]
        if successful_predictions:
            all_predictions['Ensemble'] = np.mean(successful_predictions)

        return f"βœ… {model_name} Prediction: β‚Ή{prediction:.2f}", all_predictions
    except Exception as e:
        return f"❌ Prediction error: {e}", {}
def create_sample_data():
    """Return a fresh list of illustrative values for the 16 input features.

    Values are listed in the same order as the feature input fields. A new
    list is built on every call, so callers may modify it freely.
    """
    # (feature, demo value) pairs kept side by side so the ordering is explicit.
    demo_row = (
        ("OpnPric", 100.0),
        ("HghPric", 105.0),
        ("LwPric", 98.0),
        ("LastPric", 102.0),
        ("PrvsClsgPric", 100.0),
        ("Price_Range", 7.0),
        ("Price_Change", 2.0),
        ("Price_Change_Pct", 2.0),
        ("Volume_Price_Ratio", 1.5),
        ("SMA_5", 101.0),
        ("SMA_20", 100.5),
        ("Price_Momentum", 0.01),
        ("Volume_MA", 1000.0),
        ("Volume_Ratio", 1.2),
        ("TtlTradgVol", 1200.0),
        ("TtlTrfVal", 120000.0),
    )
    return [value for _feature, value in demo_row]
def format_predictions(predictions: Dict) -> str:
    """Render a ``model name -> prediction`` mapping as a plain-text listing.

    Each entry becomes one line with the name padded to 15 characters. A
    ``None`` value (a model that failed to predict) is shown as an error
    marker instead of a price. An empty or falsy mapping yields a short
    placeholder message.
    """
    if not predictions:
        return "No predictions available"

    pieces = ["πŸ“Š All Model Predictions:\n\n"]
    for label, value in predictions.items():
        shown = "❌ Error" if value is None else f"β‚Ή{value:8.2f}"
        pieces.append(f"β€’ {label:15s}: {shown}\n")
    return "".join(pieces)
# Create the Gradio interface
def create_interface():
    """Build (but do not launch) the StoxChai Gradio Blocks app.

    Calls ``load_models()`` first. It then lays out:
    - one numeric input per entry of ``feature_names``, prefilled from
      ``create_sample_data()``;
    - a model selector and the action buttons;
    - the result panes.

    The buttons are wired to handlers that read the live component values.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    # Load models first
    models_ok = load_models()
    if not models_ok:
        # gr.Warning is designed to be raised inside event listeners (it shows
        # a toast in the browser); at build time nobody would see it, so log
        # to the console and render a banner in the page instead.
        print("⚠️ Failed to load models. Please check the model files.")

    # Build the sample row once instead of once per input field.
    sample_values = create_sample_data()

    # Prefer RandomForest as the default, but never pre-select a model that
    # is not among the loaded choices.
    if models:
        model_choices = list(models.keys())
        default_model = "RandomForest" if "RandomForest" in models else model_choices[0]
    else:
        model_choices = ["RandomForest"]
        default_model = None

    with gr.Blocks(
        title="StoxChai NSE Stock Price Predictor",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
        max-width: 1200px !important;
        }
        .feature-input {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        padding: 20px;
        border-radius: 15px;
        margin: 10px 0;
        }
        .prediction-output {
        background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
        padding: 20px;
        border-radius: 15px;
        margin: 10px 0;
        }
        """
    ) as app:
        gr.Markdown("""
        # 🎯 StoxChai NSE Stock Price Predictor
        ## Overview
        This app allows you to test our trained machine learning models for predicting Indian stock prices using NSE bhavcopy data.
        **Models Available**: RandomForest, GradientBoosting, LinearRegression, Ridge, Lasso, SVR, XGBoost, LightGBM
        **Data Source**: NSE Bhavcopy (Jan 1 - Aug 20, 2025) covering 3,257 Indian equity stocks
        """)
        if not models_ok:
            gr.Markdown("⚠️ **Failed to load models.** Please check the model files.")

        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### πŸ“Š Input Features")
                gr.Markdown("""
                Enter the 16 features in order:
                1. **OpnPric** - Opening Price
                2. **HghPric** - High Price
                3. **LwPric** - Low Price
                4. **LastPric** - Last Price
                5. **PrvsClsgPric** - Previous Close
                6. **Price_Range** - High - Low
                7. **Price_Change** - Current - Previous Close
                8. **Price_Change_Pct** - Price Change %
                9. **Volume_Price_Ratio** - Trading Value / Volume
                10. **SMA_5** - 5-day Moving Average
                11. **SMA_20** - 20-day Moving Average
                12. **Price_Momentum** - Current / SMA_5 - 1
                13. **Volume_MA** - 5-day Volume Average
                14. **Volume_Ratio** - Current Volume / Volume_MA
                15. **TtlTradgVol** - Total Trading Volume
                16. **TtlTrfVal** - Total Trading Value
                """)

                # Feature inputs, in the positional order predict_stock_price expects.
                feature_inputs = []
                for i, name in enumerate(feature_names):
                    with gr.Row():
                        gr.Markdown(f"**{name}:**")
                        input_field = gr.Number(
                            value=sample_values[i],
                            label=f"{i+1}. {name}",
                            precision=2,
                            scale=1
                        )
                    feature_inputs.append(input_field)

                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_model,
                    label="Select Model",
                    info="Choose which model to use for prediction"
                )

                # Buttons
                with gr.Row():
                    predict_btn = gr.Button("πŸš€ Make Prediction", variant="primary", size="lg")
                    sample_btn = gr.Button("πŸ“‹ Load Sample Data", variant="secondary")
                    clear_btn = gr.Button("πŸ—‘οΈ Clear All", variant="secondary")

            with gr.Column(scale=1):
                gr.Markdown("### 🎯 Prediction Results")
                # Main prediction output
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    placeholder="Click 'Make Prediction' to see results...",
                    lines=3,
                    max_lines=5
                )
                # All model predictions
                all_predictions_output = gr.Textbox(
                    label="All Model Predictions",
                    placeholder="All model predictions will appear here...",
                    lines=15,
                    max_lines=20
                )
                # Model info
                gr.Markdown("### πŸ“‹ Model Information")
                model_info = gr.Markdown(f"""
                **Models Loaded**: {len(models)}/8
                **Features**: 16 technical indicators
                **Training Data**: 464,548 samples
                **Coverage**: 3,257 Indian stocks
                **Period**: Jan 1 - Aug 20, 2025
                """)

        # Event handlers
        def predict_handler(*values):
            """Predict from the values currently shown in the UI.

            Gradio passes the current value of every component listed in the
            listener's ``inputs``: the feature fields, then the model
            dropdown. (``component.value`` would only be the initial value
            and would ignore user edits.)
            """
            *raw_features, model_name = values
            # An emptied Number field arrives as None; float(None) would crash.
            if any(v is None for v in raw_features):
                return (
                    f"❌ Please fill in all {len(raw_features)} features.",
                    format_predictions({}),
                )
            features = [float(v) for v in raw_features]
            result, predictions = predict_stock_price(features, model_name or "RandomForest")
            return result, format_predictions(predictions)

        def sample_handler():
            """Refill every feature field with the demo values."""
            return create_sample_data()

        def clear_handler():
            """Reset every feature field to 0.0."""
            return [0.0] * len(feature_inputs)

        predict_btn.click(
            fn=predict_handler,
            inputs=feature_inputs + [model_dropdown],
            outputs=[prediction_output, all_predictions_output]
        )
        sample_btn.click(
            fn=sample_handler,
            outputs=feature_inputs
        )
        clear_btn.click(
            fn=clear_handler,
            outputs=feature_inputs
        )

        # Footer
        gr.Markdown("""
        ---
        **Disclaimer**: These predictions are for educational and research purposes only.
        Do not use them as investment advice. Stock markets are inherently unpredictable.
        **Model Details**: All models were trained on NSE bhavcopy data from 2025 and use
        16 technical features with 5-day lookback sequences.
        """)
    return app
# Create and launch the app
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, without a public share
    # link, and report server-side errors in the browser.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app = create_interface()
    app.launch(**launch_kwargs)