|
|
|
|
|
import joblib |
|
|
import pandas as pd |
|
|
from flask import Flask, request, jsonify |
|
|
from flask_cors import CORS |
|
|
from sklearn.base import BaseEstimator, TransformerMixin |
|
|
|
|
|
from custom_transformer import StoreAgeAdder, OutlierCapper |
|
|
|
|
|
|
|
|
# Application object; the route decorators in this module register their
# handlers on it.
Superkart_Sales_Predictor_API = Flask(
    "Superkart Sales Predictor"
)

# Turn on cross-origin resource sharing for the app using flask-cors defaults.
CORS(Superkart_Sales_Predictor_API)

# Load the serialized model once, at import time, so request handlers can
# reuse the same object. joblib.load unpickles the file, so the .pkl must come
# from a trusted source.
# NOTE(review): the custom_transformer import at the top of the file is
# presumably required for the pickle to resolve its classes -- confirm.
Random_Forest_Loaded_Model = joblib.load(
    'Random_Forest_Model.pkl'
)
|
|
|
|
|
|
|
|
@Superkart_Sales_Predictor_API.get('/')
def home():
    """Handle GET requests to the root URL ('/').

    Returns:
        The plain welcome string for the API.
    """
    welcome_message = "Welcome to SuperKart Sales Predictor API!"
    return welcome_message
|
|
|
|
|
|
|
|
@Superkart_Sales_Predictor_API.post('/predict')
def predict():
    """Run the loaded model on a single record posted as a JSON object.

    The request body must be a JSON object containing every key listed in
    ``required_fields``. Only those keys are forwarded to the model.

    Returns:
        * 200 and ``{"prediction": [...]}`` on success.
        * 400 and ``{"error": ...}`` when the body is not a JSON object or
          required fields are missing.
        * 500 and ``{"error": ...}`` when building the frame or running the
          model fails.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type',
    ]

    # silent=True makes get_json return None instead of raising when the body
    # is absent, has the wrong content type, or is malformed, so a bad request
    # is answered with 400 rather than falling into the 500 handler below.
    Product_And_Store_data = request.get_json(silent=True)
    if not isinstance(Product_And_Store_data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    missing = [f for f in required_fields if f not in Product_And_Store_data]
    if missing:
        return jsonify({"error": f"Missing fields: {missing}"}), 400

    try:
        # One-row frame restricted to the required fields, in the order they
        # are listed above.
        sample_df = pd.DataFrame([
            {field: Product_And_Store_data[field] for field in required_fields}
        ])
        prediction = Random_Forest_Loaded_Model.predict(sample_df)
        return jsonify({'prediction': prediction.tolist()})
    except Exception as e:
        # Route-level boundary: record the traceback server-side, then report
        # the failure to the caller in the same shape as before.
        Superkart_Sales_Predictor_API.logger.exception(
            "Prediction failed for /predict"
        )
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
@Superkart_Sales_Predictor_API.post('/predict_batch')
def predict_batch():
    """Run the loaded model on a JSON list of records.

    The request body must be a non-empty JSON array whose items are JSON
    objects, each containing every key listed in ``required_fields``. Only
    those keys are forwarded to the model, in the listed order.

    Returns:
        * 200 and ``{"predictions": [...]}`` on success.
        * 400 and ``{"error": ...}`` when the body is not a list, the list is
          empty, an item is not an object, or an item lacks required fields.
        * 500 and ``{"error": ...}`` when building the frame or running the
          model fails.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type',
    ]

    # silent=True makes get_json return None instead of raising on a missing
    # or malformed body; None then fails the list check and yields a 400.
    batch_data = request.get_json(silent=True)
    if not isinstance(batch_data, list):
        return jsonify({"error": "Input must be a list of records"}), 400

    # An empty batch would produce a frame with no rows for the model, so it
    # is rejected up front as a client error.
    if not batch_data:
        return jsonify(
            {"error": "Input list must contain at least one record"}
        ), 400

    for i, record in enumerate(batch_data):
        # Guard the membership test below: on a string it would do substring
        # matching, and on a number it would raise TypeError.
        if not isinstance(record, dict):
            return jsonify(
                {"error": f"Record {i} must be a JSON object"}
            ), 400
        missing = [f for f in required_fields if f not in record]
        if missing:
            return jsonify(
                {"error": f"Missing fields in record {i}: {missing}"}
            ), 400

    try:
        # columns= selects only the required fields, in a fixed order, so
        # extra keys sent by the client never reach the model.
        df = pd.DataFrame(batch_data, columns=required_fields)
        predictions = Random_Forest_Loaded_Model.predict(df)
        return jsonify({'predictions': predictions.tolist()})
    except Exception as e:
        # Route-level boundary: record the traceback server-side, then report
        # the failure to the caller in the same shape as before.
        Superkart_Sales_Predictor_API.logger.exception(
            "Prediction failed for /predict_batch"
        )
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
# Start Flask's built-in server only when this file is executed as a script;
# importing the module does not start a server.
if __name__ == '__main__':
    Superkart_Sales_Predictor_API.run(
        host='0.0.0.0',  # listen on all network interfaces
        port=7860,
        debug=False,
    )