# gl-backend / app.py
# rakesh1715's picture
# Upload folder using huggingface_hub
# 20c8664 verified
from datetime import datetime
import joblib
import pandas as pd
from flask import Flask, request, jsonify
# Keys that must be present in the JSON body of a single-product prediction
# request. The order is kept as-is because it determines the order in which
# missing fields are listed in the validation error message.
REQUIRED_FIELDS = [
    "Product_Weight",
    "Product_Sugar_Content",
    "Product_Allocated_Area",
    "Product_MRP",
    "Store_Size",
    "Store_Id",
    "Store_Location_City_Type",
    "Store_Type",
    "Store_Age",
    "Product_Type_Categories",
]
# Create the Flask application object; the route decorators in this file
# register their handlers on it.
sales_prediction = Flask("Superkart Sales Prediction")

# Deserialize the trained sales prediction model once, at import time, so every
# request reuses the same in-memory object. The path is relative, so it is
# resolved against the process's current working directory.
_MODEL_FILE = "super_kart_prediction_model_v1_0.joblib"
model = joblib.load(_MODEL_FILE)
# Root route: a simple landing response for the API
@sales_prediction.get('/')
def home():
    """Return the plain-text greeting served at the API root."""
    greeting = "Welcome to the Product Sales Prediction API!"
    return greeting
# Define an endpoint to predict sales for a single product
@sales_prediction.post('/v1/predict')
def predict_product_sales():
    """Predict sales for a single product described by a JSON request body.

    The body must be a JSON object containing every key listed in
    ``REQUIRED_FIELDS``. ``Product_Weight``, ``Product_Allocated_Area`` and
    ``Product_MRP`` are converted with ``float()``, ``Store_Age`` with
    ``int()``; the remaining fields are passed to the model unchanged.

    Returns:
        ``{"Sales": <prediction>}`` on success.
        ``({"error": ...}, 400)`` when the body is not a JSON object, a
        required field is missing, or a numeric field cannot be converted.
        ``({"error": ...}, 500)`` when building the frame or the model call
        fails.
    """
    # silent=True yields None for a missing/malformed/non-JSON body instead of
    # raising an HTTP error that would otherwise be misreported as a 500.
    product_data = request.get_json(silent=True)
    if not isinstance(product_data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Validate required inputs
    missing_fields = [f for f in REQUIRED_FIELDS if f not in product_data]
    if missing_fields:
        return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400

    # Extract relevant product features from the input data. Conversion
    # failures are client errors (bad values), not server faults.
    try:
        features = {
            'Product_Weight': float(product_data['Product_Weight']),
            'Product_Sugar_Content': product_data['Product_Sugar_Content'],
            'Product_Allocated_Area': float(product_data['Product_Allocated_Area']),
            'Product_MRP': float(product_data['Product_MRP']),
            'Store_Size': product_data['Store_Size'],
            'Store_Id': product_data['Store_Id'],
            'Store_Location_City_Type': product_data['Store_Location_City_Type'],
            'Store_Type': product_data['Store_Type'],
            'Store_Age': int(product_data['Store_Age']),
            'Product_Type_Categories': product_data['Product_Type_Categories'],
        }
    except (TypeError, ValueError) as error:
        return jsonify({"error": f"Invalid field value: {error}"}), 400

    try:
        # Make a sales prediction using the trained model (single-row frame)
        data = pd.DataFrame([features])
        sales_predicted = model.predict(data).tolist()[0]
    except Exception as error:
        # Top-level boundary: record the traceback server-side and report a 500.
        sales_prediction.logger.exception("Single-product prediction failed")
        return jsonify({"error": f"Prediction failed: {str(error)}"}), 500

    # Return the prediction as a JSON response
    return jsonify({'Sales': sales_predicted})
# Define an endpoint to predict sales for a batch of products
@sales_prediction.post('/v1/bulk/predict')
def predict_multiple_products_sales():
    """Predict sales for every row of an uploaded CSV file.

    The CSV is read from the multipart form field ``file``. It must contain a
    ``Product_ID`` column (used only to key the response) and a
    ``Store_Establishment_Year`` column, which is replaced by a ``Store_Age``
    column computed as the current year minus the establishment year. All
    other columns are passed to the model as they were read.

    Returns:
        A JSON object mapping each ``Product_ID`` to its prediction. Rows that
        share a ``Product_ID`` collapse to a single entry (the last one wins).
        ``({"error": ...}, 400)`` when the upload is missing, the CSV cannot be
        parsed, a required column is absent, or an establishment year is not
        an integer value.
        ``({"error": ...}, 500)`` when the model call fails.
    """
    # .get() avoids the KeyError-based HTTP error that the old blanket handler
    # reported as a 500.
    file = request.files.get('file')
    if file is None:
        return jsonify({"error": "Missing uploaded CSV file in form field 'file'"}), 400

    # Read the file into a DataFrame
    try:
        input_data = pd.read_csv(file)
    except ValueError as error:
        # pandas' EmptyDataError / ParserError and UnicodeDecodeError are all
        # ValueError subclasses, so this covers empty and malformed uploads.
        return jsonify({"error": f"Could not read CSV file: {error}"}), 400

    missing_columns = [
        column
        for column in ("Product_ID", "Store_Establishment_Year")
        if column not in input_data.columns
    ]
    if missing_columns:
        return jsonify({"error": f"Missing required columns: {missing_columns}"}), 400

    try:
        establishment_year = input_data["Store_Establishment_Year"].astype(int)
    except (TypeError, ValueError) as error:
        return jsonify({"error": f"Invalid Store_Establishment_Year value: {error}"}), 400

    try:
        # Drop the identifier and the raw year, then append Store_Age last so
        # the feature column order matches what the previous code produced.
        features = input_data.drop(columns=["Product_ID", "Store_Establishment_Year"])
        features["Store_Age"] = datetime.now().year - establishment_year

        # Make predictions for the batch and key them by product id
        predictions = model.predict(features).tolist()
        product_id_list = input_data["Product_ID"].values.tolist()
        output_dict = dict(zip(product_id_list, predictions))
    except Exception as error:
        # Top-level boundary: record the traceback server-side and report a 500.
        sales_prediction.logger.exception("Bulk prediction failed")
        return jsonify({"error": f"Prediction failed: {str(error)}"}), 500

    return jsonify(output_dict)
# Run the Flask development server when this file is executed directly.
# Debug mode enables the Werkzeug interactive debugger, which allows arbitrary
# code execution from the browser, so it is opt-in via the FLASK_DEBUG
# environment variable instead of being hard-coded on.
if __name__ == '__main__':
    import os

    _debug_flag = os.environ.get("FLASK_DEBUG", "")
    # Any non-empty value other than 0/false/no turns debug mode on.
    _debug_enabled = bool(_debug_flag) and _debug_flag.lower() not in {"0", "false", "no"}
    sales_prediction.run(debug=_debug_enabled)