# Author: muthukumar22 — "Upload folder using huggingface_hub" (commit c9b211d, verified)
# Import necessary libraries
import numpy as np
import joblib # For loading the serialized model
import pandas as pd # For data manipulation
from flask import Flask, request, jsonify # For creating the Flask API
# Serialized model artifact. The path is relative, so it is resolved against
# the process's current working directory when the module is imported.
MODEL_PATH = "superkart_prediction_model_v1_0.joblib"

# Flask application object; the route handlers below are registered on it.
app = Flask("SuperKart Sales Predictor")

# Loaded once at import time and shared by both prediction endpoints.
# joblib deserializes via pickle, so only point this at a trusted artifact.
model = joblib.load(MODEL_PATH)
# --- Home route ---
@app.get('/')
def home():
    """
    Health-check endpoint.

    Returns a fixed plain-text greeting so callers can confirm the service is up.
    """
    greeting = "Welcome to the SuperKart Sales Prediction API!"
    return greeting
# --- Single prediction endpoint ---
@app.post('/v1/predict')
def predict_sales_revenue():
    """
    Predicts total sales revenue for a single product/store combination.

    Expects a JSON object containing every feature named in ``required_fields``
    below. Any additional keys are ignored.

    Returns:
        200 with ``{'Predicted_Sales_Revenue': <prediction rounded to 2 dp>}``.
        400 with ``{'error': <message>}`` when the body is not a JSON object,
        when required fields are missing, or when prediction raises.
    """
    # Feature columns, in the order they are placed in the DataFrame given to the model.
    required_fields = (
        'Product_Weight',
        'Product_Sugar_Content',
        'Product_Allocated_Area',
        'Product_Type',
        'Product_MRP',
        'Store_Size',
        'Store_Location_City_Type',
        'Store_Type',
        'store_age',
    )

    # silent=True makes a missing or malformed JSON body come back as None
    # instead of raising, so it can be reported in the same {'error': ...} shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    # Report every missing field at once instead of a bare KeyError for the first one.
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({'error': 'Missing required field(s): ' + ', '.join(missing)}), 400

    try:
        input_df = pd.DataFrame([{field: data[field] for field in required_fields}])
        prediction = model.predict(input_df)[0]
        predicted_revenue = round(float(prediction), 2)
    except Exception as e:
        # Kept as 400 to preserve the endpoint's original contract.
        # NOTE(review): this also reports genuine server-side model failures as
        # client errors; consider a 500 for those.
        return jsonify({'error': str(e)}), 400

    return jsonify({'Predicted_Sales_Revenue': predicted_revenue})
# --- Batch prediction endpoint ---
@app.post('/v1/predictbatch')
def predict_sales_batch():
    """
    Predicts total sales revenue for multiple entries from a CSV file.

    Expects a file upload under the key 'file'. The CSV must contain the same
    feature columns that /v1/predict sends to the model. An optional 'id'
    column keys the results; any other columns are ignored.

    Returns:
        200 with a JSON object that maps each row's 'id' to its prediction
        (rounded to 2 dp). Without an 'id' column, the keys are
        'Prediction_1', 'Prediction_2', ... in row order.
        400 with ``{'error': <message>}`` when no file is uploaded, the CSV
        cannot be read, required columns are missing, or prediction raises.
    """
    # Same feature columns, in the same order, as the single-prediction endpoint.
    feature_columns = [
        'Product_Weight',
        'Product_Sugar_Content',
        'Product_Allocated_Area',
        'Product_Type',
        'Product_MRP',
        'Store_Size',
        'Store_Location_City_Type',
        'Store_Type',
        'store_age',
    ]

    file = request.files.get('file')
    if file is None:
        return jsonify({'error': "No file uploaded under the key 'file'."}), 400

    try:
        input_data = pd.read_csv(file)

        missing = [col for col in feature_columns if col not in input_data.columns]
        if missing:
            return jsonify({'error': 'Missing required column(s): ' + ', '.join(missing)}), 400

        # Pass only the model's features. Without this selection, the 'id'
        # column (and any other extra column) would reach model.predict as if
        # it were a feature.
        raw_predictions = model.predict(input_data[feature_columns])
        predictions = [round(float(p), 2) for p in raw_predictions]

        if 'id' in input_data.columns:
            # NOTE(review): duplicate ids collapse here; a later row overwrites
            # the prediction of an earlier row with the same id.
            result = dict(zip(input_data['id'].tolist(), predictions))
        else:
            result = {f'Prediction_{i}': p for i, p in enumerate(predictions, start=1)}
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 400