# ramzai9's picture
# Upload folder using huggingface_hub
# 1081883 verified
# ============================================================
# SuperKart Sales Prediction - Flask REST API
# ============================================================
# API Structure:
# GET / -> Welcome message
# GET /health -> Returns API health status (JSON)
# POST /v1/predict -> Predict total sales for a product-store pair
#
# Request Body for POST /v1/predict (Content-Type: application/json):
# {
# "Product_Weight": float (4.0 - 22.0 kg)
# "Product_Sugar_Content": string ("Low Sugar" | "Regular" | "No Sugar")
# "Product_Allocated_Area": float (0.004 - 0.298 display ratio)
# "Product_MRP": float (31.0 - 266.0, INR)
# "Store_Size": string ("High" | "Medium" | "Low")
# "Store_Location_City_Type": string ("Tier 1" | "Tier 2" | "Tier 3")
# "Store_Type": string ("Departmental Store" | "Supermarket Type1"
# | "Supermarket Type2" | "Food Mart")
# "Product_Id_char": string ("FD" | "DR" | "NC")
# "Store_Age_Years": int (0 - 100)
# "Product_Type_Category": string ("Perishables" | "Non Perishables")
# }
#
# Responses:
# 200 OK -> {"Sales": <float>}
# 400 Bad Req -> {"error": "<validation message>"}
# 500 Error -> {"error": "Internal server error: <detail>"}
# ============================================================
import numpy as np
import joblib
import pandas as pd
from flask import Flask, request, jsonify
# Initialize Flask app
# The Flask application object; the route decorators below register their
# handlers on it.
superkart_api = Flask("SuperKartAPI")
# Load the trained XGBoost pipeline model
# This runs once, at import time, so a missing or unreadable model file makes
# the module fail to import. The relative path is resolved against the process
# working directory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code - only ship a model artifact that comes from a trusted source.
model = joblib.load("xgb_tuned_model.joblib")
# -----------------------------------------------------------
# Valid categorical values used for input validation
# -----------------------------------------------------------
VALID_SUGAR_CONTENT = {"Low Sugar", "Regular", "No Sugar"}
VALID_STORE_SIZE = {"High", "Medium", "Low"}
VALID_CITY_TYPE = {"Tier 1", "Tier 2", "Tier 3"}
VALID_STORE_TYPE = {
    "Departmental Store",
    "Supermarket Type1",
    "Supermarket Type2",
    "Food Mart",
}
VALID_PRODUCT_CHAR = {"FD", "DR", "NC"}
VALID_TYPE_CATEGORY = {"Perishables", "Non Perishables"}

# Kept as a list (not a set) so the "missing fields" message lists the fields
# in a stable, documented order.
REQUIRED_FIELDS = [
    "Product_Weight",
    "Product_Sugar_Content",
    "Product_Allocated_Area",
    "Product_MRP",
    "Store_Size",
    "Store_Location_City_Type",
    "Store_Type",
    "Product_Id_char",
    "Store_Age_Years",
    "Product_Type_Category",
]


def _coerce_number(value, caster):
    """Convert *value* with *caster* (``float`` or ``int``), rejecting booleans.

    ``bool`` is a subclass of ``int`` in Python, so a JSON ``true`` would
    otherwise be silently accepted as the number 1.
    """
    if isinstance(value, bool):
        raise TypeError("boolean is not a valid number")
    return caster(value)


def _format_choices(allowed):
    """Render a set of allowed strings in sorted order.

    Set iteration order varies between interpreter runs; sorting keeps the
    error message identical from one process to the next.
    """
    return "{" + ", ".join(repr(choice) for choice in sorted(allowed)) + "}"


def validate_input(data):
    """Validate the decoded JSON request payload.

    Checks are applied in this order and stop at the first failure:
    payload is a JSON object, all required fields are present, numeric
    fields are convertible, numeric fields are within range, categorical
    fields hold one of the allowed strings.

    Args:
        data: The decoded request body. Anything that is not a ``dict`` is
            rejected rather than raising.

    Returns:
        ``(True, None)`` if the payload is valid, otherwise
        ``(False, error_message)``.
    """
    # 0. The payload must be a JSON object. Scalars and arrays are valid JSON
    #    but would raise (or behave oddly) on the field lookups below.
    if not isinstance(data, dict):
        return False, "Request body must be a JSON object"

    # 1. Check all required fields are present
    missing = [field for field in REQUIRED_FIELDS if field not in data]
    if missing:
        return False, f"Missing required fields: {missing}"

    # 2. Numeric type validation (each value is converted exactly once)
    try:
        weight = _coerce_number(data["Product_Weight"], float)
        allocated_area = _coerce_number(data["Product_Allocated_Area"], float)
        mrp = _coerce_number(data["Product_MRP"], float)
        store_age = _coerce_number(data["Store_Age_Years"], int)
    except (ValueError, TypeError, OverflowError) as e:
        # OverflowError covers int(float("inf")); Python's JSON decoder
        # accepts the non-standard ``Infinity`` literal.
        return False, f"Invalid numeric value: {str(e)}"

    # 3. Range validation (NaN fails every comparison, so it is rejected here)
    if not (4.0 <= weight <= 22.0):
        return False, "Product_Weight must be between 4.0 and 22.0"
    if not (0.004 <= allocated_area <= 0.298):
        return False, "Product_Allocated_Area must be between 0.004 and 0.298"
    if not (31.0 <= mrp <= 266.0):
        return False, "Product_MRP must be between 31.0 and 266.0"
    if not (0 <= store_age <= 100):
        return False, "Store_Age_Years must be between 0 and 100"

    # 4. Categorical value validation. The isinstance check runs first so an
    #    unhashable value (list/dict) cannot raise on the set lookup.
    categorical_rules = (
        ("Product_Sugar_Content", VALID_SUGAR_CONTENT),
        ("Store_Size", VALID_STORE_SIZE),
        ("Store_Location_City_Type", VALID_CITY_TYPE),
        ("Store_Type", VALID_STORE_TYPE),
        ("Product_Id_char", VALID_PRODUCT_CHAR),
        ("Product_Type_Category", VALID_TYPE_CATEGORY),
    )
    for field, allowed in categorical_rules:
        value = data[field]
        if not isinstance(value, str) or value not in allowed:
            return False, f"{field} must be one of {_format_choices(allowed)}"

    return True, None
# -----------------------------------------------------------
# Routes
# -----------------------------------------------------------
@superkart_api.get("/")
def home():
    """Welcome endpoint.

    Returns:
        A plain-text greeting that points clients at the prediction route.
    """
    # The original literal contained "\!", which is not a valid escape
    # sequence: Python keeps the backslash in the response text and warns
    # about the invalid escape at compile time.
    return (
        "Welcome to the SuperKart Sales Prediction API! "
        "Send a POST request to /v1/predict to get a sales forecast."
    )
@superkart_api.get("/health")
def health():
    """Health-check endpoint.

    Returns:
        A JSON response carrying a fixed status string and the model label.
    """
    status_report = {
        "status": "healthy",
        "model": "XGBoost SuperKart Pipeline",
    }
    return jsonify(status_report)
@superkart_api.post("/v1/predict")
def predict_sales():
    """Predict total product-store sales (Product_Store_Sales_Total).

    Expects a JSON object with all 10 required feature fields.

    Returns:
        200 with ``{"Sales": <float>}``, the prediction rounded to 2 decimal
        places; 400 with ``{"error": ...}`` when the body is not a JSON
        object or fails validation; 500 with
        ``{"error": "Internal server error: <detail>"}`` on any unexpected
        failure.
    """
    try:
        # force=True parses the body regardless of Content-Type. silent=True
        # makes a malformed body come back as None; without it get_json
        # raises BadRequest, which the blanket handler below would report as
        # a 500 instead of the intended 400.
        data = request.get_json(force=True, silent=True)
        if data is None:
            return jsonify({"error": "Request body must be valid JSON"}), 400
        # Arrays and scalars are valid JSON but cannot carry named fields.
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        # Validate input payload
        is_valid, error_msg = validate_input(data)
        if not is_valid:
            return jsonify({"error": error_msg}), 400

        # Build typed DataFrame (order must match model pipeline)
        sample = {
            "Product_Weight": float(data["Product_Weight"]),
            "Product_Sugar_Content": data["Product_Sugar_Content"],
            "Product_Allocated_Area": float(data["Product_Allocated_Area"]),
            "Product_MRP": float(data["Product_MRP"]),
            "Store_Size": data["Store_Size"],
            "Store_Location_City_Type": data["Store_Location_City_Type"],
            "Store_Type": data["Store_Type"],
            "Product_Id_char": data["Product_Id_char"],
            "Store_Age_Years": int(data["Store_Age_Years"]),
            "Product_Type_Category": data["Product_Type_Category"],
        }
        input_df = pd.DataFrame([sample])

        # Generate and return prediction; float() turns the model's scalar
        # output into a plain Python float before rounding.
        prediction = float(model.predict(input_df)[0])
        return jsonify({"Sales": round(prediction, 2)}), 200
    except Exception as e:
        # Top-level boundary: keep the documented JSON 500 contract, but also
        # record the traceback server-side so failures are diagnosable.
        superkart_api.logger.exception("Prediction request failed")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
if __name__ == "__main__":
    import os

    # Debug mode turns on the Werkzeug interactive debugger, which lets
    # anyone who can reach the server execute arbitrary code. It is therefore
    # opt-in: set FLASK_DEBUG=1 (or true/yes/on) for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    superkart_api.run(debug=debug_enabled)