# app.py, last updated by dystopianfoe in commit e801346 ("Update app.py", verified).
import joblib
import pandas as pd
from flask import Flask, request, jsonify
# Deserialize the trained sales-prediction model once, at import time; both
# prediction handlers below use this single module-level object.
model = joblib.load("sales_prediction_model_v1_0.joblib")

# Flask application object that every route below registers on.
sales_predictor_api = Flask(import_name="Sales Predictor")
# Home route
@sales_predictor_api.get('/')
def home():
    """Return the plain-text welcome message served at the API root."""
    message = "Welcome to the Sales Prediction API"
    return message
# JSON fields read from a single-product request, in the order the feature
# frame is built. Product_Id is replaced by its two-character prefix ("Id")
# before the frame reaches the model.
_SINGLE_PRODUCT_FIELDS = (
    'Product_Weight',
    'Product_Sugar_Content',
    'Product_Allocated_Area',
    'Product_Type',
    'Product_MRP',
    'Store_Id',
    'Store_Establishment_Year',
    'Store_Size',
    'Store_Location_City_Type',
    'Store_Type',
    'Product_Id',
)


# Predict for a single product
@sales_predictor_api.post('/v1/product')
def predict_sales():
    """Predict sales for one product described by a JSON object.

    The request body must be a JSON object that contains every key in
    ``_SINGLE_PRODUCT_FIELDS``. Any extra keys are ignored.

    Returns:
        ``{"Prediction": <value>}`` on success.
        ``({"error": <message>}, 400)`` when the body is not a JSON object,
        a required field is missing, or any later step (JSON parsing,
        feature preparation, model prediction) raises.
    """
    try:
        payload = request.get_json()

        # Validate up front so clients get an actionable message instead of
        # a bare KeyError repr such as "'Product_Weight'".
        if not isinstance(payload, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        missing = [name for name in _SINGLE_PRODUCT_FIELDS if name not in payload]
        if missing:
            return jsonify(
                {"error": "Missing required field(s): " + ", ".join(missing)}
            ), 400

        # Build a one-row frame. Columns follow the tuple order, which matches
        # the order the original implementation used.
        row = {name: payload[name] for name in _SINGLE_PRODUCT_FIELDS}
        input_data = pd.DataFrame([row])

        # Swap the full product id for its first two characters. The new
        # "Id" column ends up last once Product_Id is dropped.
        input_data["Id"] = input_data["Product_Id"].str[:2]
        input_data = input_data.drop(columns="Product_Id")

        prediction = model.predict(input_data).tolist()[0]
        return jsonify({"Prediction": prediction})
    # Top-level boundary: any other failure is reported to the client as a
    # 400 with the exception text, as before.
    except Exception as e:
        return jsonify({"error": str(e)}), 400
# Predict for a batch of products
@sales_predictor_api.post('/v1/productbatch')
def predict_sales_batch():
    """Predict sales for every row of an uploaded CSV file.

    Expects a multipart upload under the form field ``file``. The CSV must
    contain a ``Product_Id`` column plus whatever feature columns the model
    consumes; the feature columns are not checked here.

    Returns:
        A JSON object mapping each ``Product_Id`` to its prediction.
        ``({"error": <message>}, 400)`` when the upload or the
        ``Product_Id`` column is missing, or when any other step raises.
    """
    try:
        # Explicit checks give clients a clear message instead of the
        # generic KeyError text produced by indexing a missing key/column.
        if 'file' not in request.files:
            return jsonify({"error": "No file uploaded under form field 'file'"}), 400
        input_data = pd.read_csv(request.files['file'])

        if "Product_Id" not in input_data.columns:
            return jsonify({"error": "CSV is missing required column 'Product_Id'"}), 400

        # Keep the raw ids to key the response, then replace the column with
        # its two-character prefix feature ("Id").
        product_ids = input_data["Product_Id"].tolist()
        input_data["Id"] = input_data["Product_Id"].str[:2]
        input_data = input_data.drop(columns="Product_Id")

        predictions = model.predict(input_data).tolist()

        # NOTE(review): duplicate Product_Id values collapse to one key here
        # (the last row wins). A list-shaped response would avoid this, but
        # it would change the API contract.
        return jsonify(dict(zip(product_ids, predictions)))
    # Top-level boundary: any other failure is reported as a 400.
    except Exception as e:
        return jsonify({"error": str(e)}), 400
# Entry point when executed directly (not when imported by a WSGI server).
if __name__ == "__main__":
    # Flask's built-in development server with debug mode on. Per the Flask
    # docs this enables the reloader and the interactive debugger, which must
    # never be reachable from untrusted networks.
    sales_predictor_api.run(debug=True)