File size: 3,901 Bytes
5fc6306 f969e9d 5fc6306 f969e9d 70d620a f969e9d 5fc6306 f969e9d 5fc6306 6cbb02b 5fc6306 6cbb02b 5fc6306 6cbb02b 5fc6306 eb8e0f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
# Backend_files/app.py
import joblib
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS
from sklearn.base import BaseEstimator, TransformerMixin
from custom_transformer import StoreAgeAdder, OutlierCapper
# Create the Flask application instance that every route below registers on.
Superkart_Sales_Predictor_API = Flask("Superkart Sales Predictor")

# Allow cross-origin requests so a separately hosted frontend can call the API
# (optional; only needed when frontend and backend live on different origins).
CORS(Superkart_Sales_Predictor_API)

# Deserialize the trained model once at import time; all requests share it.
# NOTE(review): the pickle presumably references the custom transformers
# imported above (StoreAgeAdder, OutlierCapper), so that import likely has to
# stay in place for loading to succeed — confirm against the training code.
# The path is relative to the current working directory.
_MODEL_FILENAME = 'Random_Forest_Model.pkl'
Random_Forest_Loaded_Model = joblib.load(_MODEL_FILENAME)
# Landing route: answers GET / with a fixed greeting.
@Superkart_Sales_Predictor_API.get('/')
def home():
    """Serve the API's root URL.

    Handles GET requests to ``/`` and replies with a plain-text welcome
    message. It can also serve as a simple check that the service is reachable.
    """
    greeting = "Welcome to SuperKart Sales Predictor API!"
    return greeting
# Define an endpoint to predict SuperKart sales for a single product/store record.
@Superkart_Sales_Predictor_API.post('/predict')
def predict():
    """Predict sales for one product/store record.

    Expects the request body to be a JSON object containing every key in
    ``required_fields``. Extra keys are ignored.

    Returns:
        200 with ``{"prediction": [...]}`` (the model output as a list) on success.
        400 with ``{"error": ...}`` if the body is not a JSON object or a
        required field is missing.
        500 with ``{"error": ...}`` if building the frame or predicting fails.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type',
    ]

    # silent=True yields None (instead of raising) for malformed JSON or a
    # non-JSON content type. Those client errors become a 400 below rather
    # than being caught by the generic handler and reported as a 500.
    Product_And_Store_data = request.get_json(silent=True)
    # Reject null/list/number/string bodies: on a string, `in` would do a
    # substring check, and on null or a number it would raise TypeError.
    if not isinstance(Product_And_Store_data, dict):
        return jsonify({"error": "Input must be a JSON object"}), 400

    missing = [f for f in required_fields if f not in Product_And_Store_data]
    if missing:
        return jsonify({"error": f"Missing fields: {missing}"}), 400

    try:
        # One-row frame whose columns follow the fixed order of required_fields.
        sample_df = pd.DataFrame(
            [{field: Product_And_Store_data[field] for field in required_fields}]
        )
        prediction = Random_Forest_Loaded_Model.predict(sample_df)
        return jsonify({'prediction': prediction.tolist()})
    except Exception as e:
        # Keep the traceback in the server log; the client still gets the message.
        Superkart_Sales_Predictor_API.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
@Superkart_Sales_Predictor_API.post('/predict_batch')
def predict_batch():
    """Predict sales for a list of product/store records.

    Expects the request body to be a JSON array of objects. Each object must
    contain every key in ``required_fields``; extra keys are ignored, just as
    in the single-record ``/predict`` endpoint.

    Returns:
        200 with ``{"predictions": [...]}``, one entry per input record in
        input order. An empty array yields an empty list.
        400 with ``{"error": ...}`` if the body is not a JSON array, a record
        is not an object, or a record is missing a required field.
        500 with ``{"error": ...}`` if building the frame or predicting fails.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type',
    ]

    # silent=True yields None for malformed JSON or a wrong content type, so
    # those cases are reported as 400 instead of a 500 from the generic handler.
    batch_data = request.get_json(silent=True)
    if not isinstance(batch_data, list):
        return jsonify({"error": "Input must be a list of records"}), 400

    # Nothing to score. Calling the model with zero rows would fail with a 500.
    if not batch_data:
        return jsonify({'predictions': []})

    for i, record in enumerate(batch_data):
        # A non-object record would make the `in` check below raise or misbehave.
        if not isinstance(record, dict):
            return jsonify({"error": f"Record {i} must be a JSON object"}), 400
        missing = [f for f in required_fields if f not in record]
        if missing:
            return jsonify({"error": f"Missing fields in record {i}: {missing}"}), 400

    try:
        # Select exactly the required columns in a fixed order, so the frame
        # matches what /predict builds regardless of client key order/extras.
        df = pd.DataFrame(
            [{field: record[field] for field in required_fields} for record in batch_data]
        )
        predictions = Random_Forest_Loaded_Model.predict(df)
        return jsonify({'predictions': predictions.tolist()})
    except Exception as e:
        # Keep the traceback in the server log; the client still gets the message.
        Superkart_Sales_Predictor_API.logger.exception("Batch prediction failed")
        return jsonify({"error": str(e)}), 500
# Start Flask's built-in server when this file is executed directly.
# debug is deliberately False: the app listens on all interfaces (0.0.0.0),
# and the interactive debugger must never be reachable from outside.
if __name__ == '__main__':
    Superkart_Sales_Predictor_API.run(debug=False, host='0.0.0.0', port=7860)