Spaces:
Running
Running
# ============================================================
# SuperKart Sales Prediction - Flask REST API
# ============================================================
# API Structure:
#   GET  /            -> Welcome message
#   GET  /health      -> Returns API health status (JSON)
#   POST /v1/predict  -> Predict total sales for a product-store pair
#
# Request body for POST /v1/predict (Content-Type: application/json).
# All numeric ranges are inclusive:
# {
#   "Product_Weight":           float  (4.0 - 22.0 kg)
#   "Product_Sugar_Content":    string ("Low Sugar" | "Regular" | "No Sugar")
#   "Product_Allocated_Area":   float  (0.004 - 0.298 display ratio)
#   "Product_MRP":              float  (31.0 - 266.0, INR)
#   "Store_Size":               string ("High" | "Medium" | "Low")
#   "Store_Location_City_Type": string ("Tier 1" | "Tier 2" | "Tier 3")
#   "Store_Type":               string ("Departmental Store" | "Supermarket Type1" |
#                                       "Supermarket Type2" | "Food Mart")
#   "Product_Id_char":          string ("FD" | "DR" | "NC")
#   "Store_Age_Years":          int    (0 - 100)
#   "Product_Type_Category":    string ("Perishables" | "Non Perishables")
# }
#
# Responses:
#   200 OK          -> {"Sales": <float>}
#   400 Bad Request -> {"error": "<validation message>"}
#   500 Error       -> {"error": "Internal server error: <detail>"}
# ============================================================
| import numpy as np | |
| import joblib | |
| import pandas as pd | |
| from flask import Flask, request, jsonify | |
# Flask application instance for the SuperKart sales API.
superkart_api = Flask(import_name="SuperKartAPI")

# Trained XGBoost pipeline, deserialized once at import time and then shared
# by all requests through the module-level ``model`` name.
# NOTE: the filename is relative, so it is resolved against the process's
# current working directory. joblib files are pickle-based; only load a
# trusted, locally produced artifact here.
_MODEL_FILENAME = "xgb_tuned_model.joblib"
model = joblib.load(_MODEL_FILENAME)
# -----------------------------------------------------------
# Valid categorical values used for input validation
# -----------------------------------------------------------
VALID_SUGAR_CONTENT = {"Low Sugar", "Regular", "No Sugar"}
VALID_STORE_SIZE = {"High", "Medium", "Low"}
VALID_CITY_TYPE = {"Tier 1", "Tier 2", "Tier 3"}
VALID_STORE_TYPE = {"Departmental Store", "Supermarket Type1",
                    "Supermarket Type2", "Food Mart"}
VALID_PRODUCT_CHAR = {"FD", "DR", "NC"}
VALID_TYPE_CATEGORY = {"Perishables", "Non Perishables"}

# Keys that validate_input() requires in every payload.
REQUIRED_FIELDS = [
    "Product_Weight", "Product_Sugar_Content", "Product_Allocated_Area",
    "Product_MRP", "Store_Size", "Store_Location_City_Type", "Store_Type",
    "Product_Id_char", "Store_Age_Years", "Product_Type_Category",
]


def validate_input(data):
    """Validate the JSON request payload.

    Checks, in order: the payload is a JSON object, every field in
    ``REQUIRED_FIELDS`` is present, the numeric fields convert to numbers,
    those numbers lie within their inclusive ranges, and each categorical
    field is one of its allowed strings.

    Args:
        data: The decoded JSON body (may be any JSON value).

    Returns:
        ``(True, None)`` if valid; ``(False, error_message)`` describing the
        first problem found otherwise. Never raises for malformed input.
    """
    # 0. Lists/strings/numbers are valid JSON but not a feature record; without
    #    this guard `f not in data` raises (int) or does substring tests (str).
    if not isinstance(data, dict):
        return False, "Request body must be a JSON object"

    # 1. Check all required fields are present
    missing = [f for f in REQUIRED_FIELDS if f not in data]
    if missing:
        return False, f"Missing required fields: {missing}"

    # (field, cast, inclusive minimum, inclusive maximum)
    numeric_specs = (
        ("Product_Weight", float, 4.0, 22.0),
        ("Product_Allocated_Area", float, 0.004, 0.298),
        ("Product_MRP", float, 31.0, 266.0),
        ("Store_Age_Years", int, 0, 100),
    )

    # 2. Numeric type validation (all fields are parsed before any range check)
    parsed = {}
    for field, cast, _low, _high in numeric_specs:
        raw = data[field]
        # bool subclasses int, so JSON true/false would otherwise become 1/0.
        if isinstance(raw, bool):
            return False, (f"Invalid numeric value: {field} must be a number, "
                           f"not a boolean")
        try:
            parsed[field] = cast(raw)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError: int(float('inf')) for Infinity / 1e400 in the JSON.
            return False, f"Invalid numeric value: {str(e)}"

    # 3. Range validation (NaN fails every comparison, so it is rejected here)
    for field, _cast, low, high in numeric_specs:
        if not (low <= parsed[field] <= high):
            return False, f"{field} must be between {low} and {high}"

    # 4. Categorical value validation
    categorical_specs = (
        ("Product_Sugar_Content", VALID_SUGAR_CONTENT),
        ("Store_Size", VALID_STORE_SIZE),
        ("Store_Location_City_Type", VALID_CITY_TYPE),
        ("Store_Type", VALID_STORE_TYPE),
        ("Product_Id_char", VALID_PRODUCT_CHAR),
        ("Product_Type_Category", VALID_TYPE_CATEGORY),
    )
    for field, allowed in categorical_specs:
        value = data[field]
        # The str check comes first: an unhashable value (list/dict) would make
        # the set-membership test raise TypeError instead of failing cleanly.
        if not isinstance(value, str) or value not in allowed:
            # sorted() keeps the message stable (set order varies per process).
            return False, f"{field} must be one of {sorted(allowed)}"

    return True, None
# -----------------------------------------------------------
# Routes
# -----------------------------------------------------------
@superkart_api.route("/", methods=["GET"])
def home():
    """Welcome endpoint (``GET /``).

    Returns a plain-text greeting that points clients at ``POST /v1/predict``.
    """
    return ("Welcome to the SuperKart Sales Prediction API! "
            "Send a POST request to /v1/predict to get a sales forecast.")
@superkart_api.route("/health", methods=["GET"])
def health():
    """Health-check endpoint (``GET /health``).

    Returns a static JSON status object. It does not call the model, so it
    only confirms that the process is up and serving requests.
    """
    return jsonify({"status": "healthy", "model": "XGBoost SuperKart Pipeline"})
@superkart_api.route("/v1/predict", methods=["POST"])
def predict_sales():
    """Predict total product-store sales (``Product_Store_Sales_Total``).

    Expects a JSON object with all 10 required feature fields. The body is
    parsed as JSON regardless of the request's Content-Type.

    Returns:
        ``({"Sales": <float rounded to 2 dp>}, 200)`` on success;
        ``({"error": <message>}, 400)`` for malformed JSON, a non-object body,
        or a payload that fails ``validate_input``;
        ``({"error": "Internal server error: <detail>"}, 500)`` if building the
        input frame or running the model raises.
    """
    try:
        # silent=True: malformed JSON yields None (-> 400 below) instead of
        # raising BadRequest, which the generic handler would turn into a 500.
        data = request.get_json(force=True, silent=True)
        if data is None:
            return jsonify({"error": "Request body must be valid JSON"}), 400
        # Arrays, strings and numbers are valid JSON but not a feature record.
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        # Validate input payload
        is_valid, error_msg = validate_input(data)
        if not is_valid:
            return jsonify({"error": error_msg}), 400

        # Build a one-row, typed DataFrame (column order must match the model
        # pipeline, so keep this dict in the same field order).
        sample = {
            "Product_Weight": float(data["Product_Weight"]),
            "Product_Sugar_Content": data["Product_Sugar_Content"],
            "Product_Allocated_Area": float(data["Product_Allocated_Area"]),
            "Product_MRP": float(data["Product_MRP"]),
            "Store_Size": data["Store_Size"],
            "Store_Location_City_Type": data["Store_Location_City_Type"],
            "Store_Type": data["Store_Type"],
            "Product_Id_char": data["Product_Id_char"],
            "Store_Age_Years": int(data["Store_Age_Years"]),
            "Product_Type_Category": data["Product_Type_Category"],
        }
        input_df = pd.DataFrame([sample])

        # Generate and return prediction
        prediction = float(model.predict(input_df)[0])
        return jsonify({"Sales": round(prediction, 2)}), 200
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # before returning the documented 500 payload.
        superkart_api.logger.exception("Prediction failed")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
if __name__ == "__main__":
    import os

    # Local development server only. The Werkzeug debugger lets anyone who can
    # reach it execute arbitrary code, so it is opt-in (FLASK_DEBUG=1/true/yes)
    # instead of always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    superkart_api.run(debug=debug_enabled)