Pushpak21's picture
Update app.py
6cbb02b verified
# Backend_files/app.py
import joblib
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS
from sklearn.base import BaseEstimator, TransformerMixin
from custom_transformer import StoreAgeAdder, OutlierCapper
# Initialize the Flask app (the string is passed as Flask's import_name)
Superkart_Sales_Predictor_API = Flask("Superkart Sales Predictor")
CORS(Superkart_Sales_Predictor_API) # Enable CORS for frontend integration (optional)
# Load the trained model once at import time so every request reuses the
# same object. The relative path is resolved against the working directory.
# NOTE(review): the custom_transformer imports above are presumably required
# so joblib can unpickle custom steps stored in the model - confirm before
# removing them as "unused".
Random_Forest_Loaded_Model = joblib.load('Random_Forest_Model.pkl')
# Define a route for the home page (GET request)
@Superkart_Sales_Predictor_API.get('/')
def home():
    """
    Handle GET requests to the root URL ('/').

    Returns a plain-text welcome message.
    """
    greeting = "Welcome to SuperKart Sales Predictor API!"
    return greeting
# Define an endpoint to predict the Superkart sales
@Superkart_Sales_Predictor_API.post('/predict')
def predict():
    """
    Predict sales for a single product/store record.

    Expects the request body to be a JSON object containing every key in
    ``required_fields``. Those values are placed in a one-row DataFrame
    (columns in ``required_fields`` order) and passed to
    ``Random_Forest_Loaded_Model.predict``.

    Returns:
        200 and ``{"prediction": [...]}`` on success.
        400 and ``{"error": ...}`` when the body is not a JSON object or
        required fields are missing.
        500 and ``{"error": ...}`` when anything else fails.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type',
    ]
    try:
        # silent=True returns None for a missing/invalid JSON body instead of
        # raising, so the problem is reported as a client error (400) rather
        # than being converted to a 500 by the generic handler below.
        Product_And_Store_data = request.get_json(silent=True)

        # Membership tests and key lookups below only make sense on a dict;
        # reject lists, strings, numbers and null up front.
        if not isinstance(Product_And_Store_data, dict):
            return jsonify({"error": "Input must be a JSON object"}), 400

        # Validate input
        missing = [f for f in required_fields if f not in Product_And_Store_data]
        if missing:
            return jsonify({"error": f"Missing fields: {missing}"}), 400

        # Convert to a one-row DataFrame holding only the required fields
        sample_df = pd.DataFrame([
            {field: Product_And_Store_data[field] for field in required_fields}
        ])

        # Predict
        prediction = Random_Forest_Loaded_Model.predict(sample_df)

        # Return response
        return jsonify({'prediction': prediction.tolist()})
    except Exception as e:
        # Top-level boundary: record the traceback server-side, then report
        # the failure to the client.
        Superkart_Sales_Predictor_API.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
@Superkart_Sales_Predictor_API.post('/predict_batch')
def predict_batch():
    """
    Predict sales for a batch of product/store records.

    Expects the request body to be a non-empty JSON list of objects, each
    containing every key in ``required_fields``. The records are converted
    to a DataFrame restricted to ``required_fields`` (in that order) and
    passed to ``Random_Forest_Loaded_Model.predict``.

    Returns:
        200 and ``{"predictions": [...]}`` on success.
        400 and ``{"error": ...}`` when the body is not a list, is empty,
        contains a non-object record, or a record is missing fields.
        500 and ``{"error": ...}`` when anything else fails.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type',
    ]
    try:
        # silent=True returns None for a missing/invalid JSON body instead of
        # raising, so it falls into the 400 branch below rather than the
        # generic 500 handler.
        batch_data = request.get_json(silent=True)

        # Validate input type
        if not isinstance(batch_data, list):
            return jsonify({"error": "Input must be a list of records"}), 400

        # An empty batch gives the model nothing to predict on.
        if not batch_data:
            return jsonify({"error": "Input list must contain at least one record"}), 400

        # Check each record
        for i, record in enumerate(batch_data):
            # A non-dict record would make the membership test below raise
            # or give meaningless results.
            if not isinstance(record, dict):
                return jsonify({"error": f"Record {i} must be a JSON object"}), 400
            missing = [f for f in required_fields if f not in record]
            if missing:
                return jsonify({"error": f"Missing fields in record {i}: {missing}"}), 400

        # Convert list of dicts to DataFrame, keeping only the required
        # fields in a fixed order so the model sees the same columns as the
        # single-record endpoint regardless of extra keys in the payload.
        df = pd.DataFrame(batch_data, columns=required_fields)

        # Predict
        predictions = Random_Forest_Loaded_Model.predict(df)

        # Return list of predictions
        return jsonify({'predictions': predictions.tolist()})
    except Exception as e:
        # Top-level boundary: record the traceback server-side, then report
        # the failure to the client.
        Superkart_Sales_Predictor_API.logger.exception("Batch prediction failed")
        return jsonify({"error": str(e)}), 500
# When executed directly, start the Flask server with debug disabled,
# listening on all interfaces at port 7860.
if __name__ == "__main__":
    Superkart_Sales_Predictor_API.run(host="0.0.0.0", port=7860, debug=False)