# Uploaded by zezkcy via huggingface_hub ("Upload folder using huggingface_hub"), commit d30c591 (verified).
# Import necessary libraries
import numpy as np
import joblib
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS # βœ… Import CORS
# Initialize the Flask app that serves the prediction endpoints below.
# NOTE(review): Flask's import_name is conventionally __name__; with an arbitrary
# string Flask presumably falls back to the working directory as the app root.
# That is harmless for this JSON-only API, but confirm before adding templates or static files.
sales_prediction_api = Flask("SuperKart Sales Price Prediction API")
CORS(sales_prediction_api)  # Enable CORS on every route of the app
# Load the trained model once at import time; both prediction endpoints reuse this object.
# The path is relative, so it is resolved against the process's working directory.
# joblib.load unpickles the file: only ever load model files from a trusted source.
model = joblib.load("sales_price_prediction_model_v1_0.joblib")
# Home route
@sales_prediction_api.get('/')
def home():
    """Return the plain-text welcome message served at the API root."""
    welcome_message = "Welcome to SuperKart Sales Price Prediction API!"
    return welcome_message
# Single prediction endpoint
@sales_prediction_api.post('/v1/sales')
def predict_sales_price():
    """Predict sales for one product/store combination sent as a JSON object.

    The request body must be a JSON object containing every feature field
    below plus ``Store_ID``. ``Store_ID`` is not passed to the model directly;
    it is expanded into the one-hot columns ``Store_OUT001`` .. ``Store_OUT004``.
    These columns are appended after the feature fields, in that order.

    Returns:
        200 with ``{"Predicted Sales": <float rounded to 2 decimals>}``.
        400 with ``{"error": ...}`` when the body is not a JSON object, a
        required field is missing, or ``Store_ID`` is not a known store.
        500 with ``{"error": ...}`` when building the frame or predicting fails.
    """
    # silent=True makes malformed JSON or a non-JSON content type yield None
    # instead of raising, so the client gets a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Model input columns, in the order they are placed in the DataFrame.
    feature_fields = [
        'Product_Weight', 'Product_Allocated_Area', 'Product_MRP',
        'Store_Establishment_Year', 'Product_Sugar_Content', 'Store_Size',
        'Store_Location_City_Type', 'Store_Type', 'Product_Type',
    ]
    required_fields = feature_fields + ['Store_ID']
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return jsonify({"error": f"Missing fields: {', '.join(missing_fields)}"}), 400

    # Store IDs with a one-hot column in the model input
    # (presumably the stores seen during training -- confirm against the training notebook).
    known_store_ids = ['OUT001', 'OUT002', 'OUT003', 'OUT004']
    store_id = data['Store_ID']
    if str(store_id) not in known_store_ids:
        # An unknown ID would otherwise encode as all zeros and silently
        # produce a meaningless prediction.
        return jsonify({
            "error": f"Unknown Store_ID '{store_id}'. "
                     f"Expected one of: {', '.join(known_store_ids)}"
        }), 400

    input_sample = {field: data[field] for field in feature_fields}
    for known_id in known_store_ids:
        input_sample[f'Store_{known_id}'] = 1 if known_id == str(store_id) else 0

    try:
        input_df = pd.DataFrame([input_sample])
        prediction = model.predict(input_df)[0]
        predicted_sales = round(float(prediction), 2)
    except Exception as e:
        # Request boundary: keep the traceback in the server log and report
        # the message to the client, as before.
        sales_prediction_api.logger.exception("Single sales prediction failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"Predicted Sales": predicted_sales})
# Batch prediction endpoint
@sales_prediction_api.post('/v1/salesbatch')
def predict_sales_price_batch():
    """Predict sales for every row of an uploaded CSV file.

    Expects a multipart upload under the form field ``file``. The CSV must have
    an ``id`` column. The parsed frame, including ``id``, is passed to
    ``model.predict`` unchanged.

    Returns:
        200 with a JSON list of ``{"id": ..., "Predicted Sales": ...}``, one
        entry per row in file order, or an empty list for a header-only CSV.
        400 with ``{"error": ...}`` when the upload is missing, cannot be
        parsed as CSV, or has no ``id`` column.
        500 with ``{"error": ...}`` when the model fails on the data.
    """
    if 'file' not in request.files:
        return jsonify({"error": "CSV file not found."}), 400
    file = request.files['file']
    try:
        df = pd.read_csv(file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # Empty upload, malformed CSV, or non-text bytes: a client error, not a server fault.
        return jsonify({"error": f"Could not parse uploaded CSV: {e}"}), 400
    if 'id' not in df.columns:
        return jsonify({"error": "Missing 'id' column in uploaded CSV"}), 400
    if df.empty:
        # Nothing to predict. Skip the model, since estimators such as
        # scikit-learn's raise on zero-sample input.
        return jsonify([])

    # NOTE(review): unlike /v1/sales, no Store_ID one-hot encoding happens here,
    # so the CSV must already match the model's expected columns -- confirm the
    # schema with whoever produces these files.
    try:
        preds = model.predict(df)
    except Exception as e:
        # Request boundary: log the traceback and return a JSON error.
        sales_prediction_api.logger.exception("Batch sales prediction failed")
        return jsonify({"error": str(e)}), 500

    results = [
        {"id": row_id, "Predicted Sales": round(float(pred), 2)}
        for row_id, pred in zip(df['id'], preds)
    ]
    return jsonify(results)
# Entry point for running the development server directly (local testing only).
if __name__ == '__main__':
    # debug=True enables Flask's reloader and the interactive Werkzeug debugger,
    # which must never be reachable from an untrusted network.
    sales_prediction_api.run(debug=True)