Spaces:
Sleeping
Sleeping
| """ | |
| SuperKart Sales Prediction Flask API | |
| This Flask application provides a REST API for predicting product sales using a pre-trained | |
| Random Forest model. The API accepts product and store features and returns predicted sales revenue. | |
| """ | |
| import os | |
| import joblib | |
| import pandas as pd | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import logging | |
| from typing import Any, Dict | |
| from pydantic import BaseModel, ValidationError, field_validator | |
| from datetime import datetime | |
# Root logging at INFO level; this module logs under its own name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application, with CORS enabled so browser front-ends can call it.
app = Flask(__name__)
CORS(app)

# Module-level state filled in by load_model(); both are None until then.
model = feature_columns = None
# Fields a client supplies in each prediction request.
USER_INPUT_FEATURES = [
    "Product_Weight",
    "Product_Sugar_Content",
    "Product_Allocated_Area",
    "Product_Type",
    "Product_MRP",
    "Store_Establishment_Year",
    "Store_Size",
    "Store_Location_City_Type",
    "Store_Type",
]

# Columns handed to the model: the client fields without the establishment
# year, followed by the derived "Store_Age" column.
MODEL_FEATURES = [
    name for name in USER_INPUT_FEATURES if name != "Store_Establishment_Year"
] + ["Store_Age"]
| # Pydantic model for input validation | |
class PredictionInput(BaseModel):
    """Request schema for a single sales prediction.

    Every field has a validator registered with ``field_validator``. A
    validator raises ``ValueError`` for a value outside the accepted range or
    category list; pydantic surfaces that as a ``ValidationError`` when the
    model is constructed.
    """

    Product_Weight: float
    Product_Sugar_Content: str
    Product_Allocated_Area: float
    Product_Type: str
    Product_MRP: float
    Store_Establishment_Year: int
    Store_Size: str
    Store_Location_City_Type: str
    Store_Type: str

    @field_validator("Product_Weight")
    @classmethod
    def validate_product_weight(cls, v: float) -> float:
        """Accept only weights in the closed range 4.0-22.0."""
        if v <= 0:
            raise ValueError("Product_Weight must be greater than 0")
        # Chained comparison so NaN (which fails every comparison) is rejected.
        if not 4.0 <= v <= 22.0:
            raise ValueError("Product_Weight must be between 4.0 and 22.0")
        return v

    @field_validator("Product_Allocated_Area")
    @classmethod
    def validate_allocated_area(cls, v: float) -> float:
        """Accept only area ratios in the closed range 0-1."""
        if not 0 <= v <= 1:
            raise ValueError("Product_Allocated_Area must be between 0 and 1")
        return v

    @field_validator("Product_MRP")
    @classmethod
    def validate_mrp(cls, v: float) -> float:
        """Accept only prices in the closed range 31.0-266.0."""
        if v <= 0:
            raise ValueError("Product_MRP must be greater than 0")
        if not 31.0 <= v <= 266.0:
            raise ValueError("Product_MRP must be between 31.0 and 266.0")
        return v

    @field_validator("Store_Establishment_Year")
    @classmethod
    def validate_establishment_year(cls, v: int) -> int:
        """Accept only the establishment years listed in ``valid_years``."""
        valid_years = [1987, 1998, 1999, 2009]
        if v not in valid_years:
            raise ValueError(f"Store_Establishment_Year must be one of: {valid_years}")
        return v

    @field_validator("Product_Sugar_Content")
    @classmethod
    def validate_sugar_content(cls, v: str) -> str:
        """Accept only the sugar-content labels listed in ``valid``."""
        valid = ["Low Sugar", "Regular", "No Sugar"]
        if v not in valid:
            raise ValueError(f"Product_Sugar_Content must be one of: {valid}")
        return v

    @field_validator("Product_Type")
    @classmethod
    def validate_product_type(cls, v: str) -> str:
        """Accept only the product categories listed in ``valid``."""
        valid = [
            "Dairy",
            "Soft Drinks",
            "Meat",
            "Fruits and Vegetables",
            "Household",
            "Baking Goods",
            "Snack Foods",
            "Frozen Foods",
            "Breakfast",
            "Health and Hygiene",
            "Hard Drinks",
            "Canned",
            "Bread",
            "Starchy Foods",
            "Others",
            "Seafood",
        ]
        if v not in valid:
            raise ValueError(f"Product_Type must be one of: {valid}")
        return v

    @field_validator("Store_Size")
    @classmethod
    def validate_store_size(cls, v: str) -> str:
        """Accept only the store sizes listed in ``valid``."""
        valid = ["Small", "Medium", "High"]
        if v not in valid:
            raise ValueError(f"Store_Size must be one of: {valid}")
        return v

    @field_validator("Store_Location_City_Type")
    @classmethod
    def validate_city_type(cls, v: str) -> str:
        """Accept only the city tiers listed in ``valid``."""
        valid = ["Tier 1", "Tier 2", "Tier 3"]
        if v not in valid:
            raise ValueError(f"Store_Location_City_Type must be one of: {valid}")
        return v

    @field_validator("Store_Type")
    @classmethod
    def validate_store_type(cls, v: str) -> str:
        """Accept only the store types listed in ``valid``."""
        valid = [
            "Supermarket Type1",
            "Supermarket Type2",
            "Supermarket Type3",
            "Departmental Store",
            "Food Mart",
        ]
        if v not in valid:
            raise ValueError(f"Store_Type must be one of: {valid}")
        return v
def load_model(model_path: str):
    """Load the serialized model at ``model_path`` into module state.

    On success the module-level ``model`` holds whatever ``joblib.load``
    returned and ``feature_columns`` is set to ``MODEL_FEATURES``. Every
    failure, including a missing file, is logged and reported through the
    return value instead of being raised.

    Args:
        model_path (str): Path to the model file.

    Returns:
        bool: True if the model loaded successfully, False otherwise.
    """
    global model, feature_columns
    try:
        if os.path.exists(model_path):
            model = joblib.load(model_path)
        else:
            # Routed through the handler below so a missing file is logged
            # the same way as any other load failure.
            raise FileNotFoundError(f"Model file not found at: {model_path}")
        logger.info(f"β Model loaded successfully from: {model_path}")

        feature_columns = MODEL_FEATURES
        logger.info(f"π Model features: {MODEL_FEATURES}")
        logger.info(f"π User input features: {USER_INPUT_FEATURES}")
        return True
    except Exception as exc:
        logger.error(f"β Error loading model: {exc}")
        return False
def convert_establishment_year_to_age(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``data`` with the establishment year turned into an age.

    If ``"Store_Establishment_Year"`` is present it is removed from the copy
    and ``"Store_Age"`` is added, computed as the current calendar year minus
    that value. Without the key the copy is returned unchanged. The input
    mapping itself is never modified.
    """
    converted = data.copy()
    try:
        year_established = converted.pop("Store_Establishment_Year")
    except KeyError:
        return converted
    converted["Store_Age"] = datetime.now().year - year_established
    return converted
def preprocess_input(data: Dict[str, Any]) -> pd.DataFrame:
    """Build the single-row DataFrame that is passed to the model.

    The establishment year is first replaced by ``Store_Age``; the resulting
    record becomes one row whose columns are selected, in order, by
    ``MODEL_FEATURES``.
    """
    record = convert_establishment_year_to_age(data)
    return pd.DataFrame([record])[MODEL_FEATURES]
def health_check():
    """Report that the service is running and whether a model is loaded."""
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file; confirm it is attached to `app`.
    payload = {
        "status": "healthy",
        "message": "SuperKart Sales Prediction API is running",
        "model_loaded": model is not None,
    }
    return jsonify(payload)
def predict():
    """Predict sales for one product/store record sent as a JSON object.

    Returns:
        A JSON response, paired with a status code on failure:
        * 500 if no model is loaded or prediction raises unexpectedly,
        * 400 if the body is missing, not valid JSON, not a JSON object, or
          fails ``PredictionInput`` validation,
        * otherwise the rounded prediction together with the validated
          input (default 200).
    """
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file; confirm it is attached to `app`.
    if model is None:
        return jsonify({"error": "Model not loaded. Please check server logs."}), 500
    try:
        # silent=True returns None for malformed JSON or a non-JSON content
        # type instead of raising, so those cases get the 400 below rather
        # than being converted to a 500 by the broad handler at the bottom.
        data = request.get_json(silent=True)
        if not data:
            return jsonify(
                {
                    "error": "No data provided. Please send JSON data in the request body."
                }
            ), 400
        # A JSON array or scalar cannot be unpacked into the schema.
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400

        try:
            validated = PredictionInput(**data)
        except ValidationError as ve:
            # The error context may hold the raw exception object, which is
            # not JSON serializable; leave it out of the response.
            return jsonify(
                {
                    "error": "Input validation failed",
                    "details": ve.errors(include_context=False),
                }
            ), 400

        # Dump once; preprocess_input works on a copy, so the same dict can be
        # echoed back to the client unchanged.
        features = validated.model_dump()
        input_df = preprocess_input(features)
        prediction = model.predict(input_df)
        predicted_sales = float(prediction[0])

        response = {
            "predicted_sales": round(predicted_sales, 2),
            "currency": "USD",
            "input_features": features,
            "status": "success",
        }
        logger.info(f"β Prediction successful: ${predicted_sales:.2f}")
        return jsonify(response)
    except Exception as e:
        logger.error(f"β Prediction error: {str(e)}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500
def get_features():
    """Describe the input fields a prediction request must contain.

    The JSON response lists the required field names, a human-readable
    description of each, and one complete example payload.
    """
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file; confirm it is attached to `app`.
    descriptions = {
        "Product_Weight": "Weight of the product (4.0-22.0 kg)",
        "Product_Sugar_Content": "Sugar content (Low Sugar, Regular, No Sugar)",
        "Product_Allocated_Area": "Allocated display area ratio (0.0-1.0)",
        "Product_Type": "Product category (16 types: Dairy, Soft Drinks, Meat, etc.)",
        "Product_MRP": "Maximum retail price (31.0-266.0 USD)",
        "Store_Establishment_Year": "Year store was established (1987, 1998, 1999, 2009)",
        "Store_Size": "Store size (Small, Medium, High)",
        "Store_Location_City_Type": "City type (Tier 1, Tier 2, Tier 3)",
        "Store_Type": "Store type (Supermarket Type1/2/3, Departmental Store, Food Mart)",
    }
    example = {
        "Product_Weight": 12.66,
        "Product_Sugar_Content": "Low Sugar",
        "Product_Allocated_Area": 0.027,
        "Product_Type": "Frozen Foods",
        "Product_MRP": 117.08,
        "Store_Establishment_Year": 2009,
        "Store_Size": "Medium",
        "Store_Location_City_Type": "Tier 2",
        "Store_Type": "Supermarket Type2",
    }
    return jsonify(
        {
            "required_features": USER_INPUT_FEATURES,
            "feature_descriptions": descriptions,
            "example_input": example,
        }
    )
def predict_batch():
    """Predict sales for several records sent as ``{"predictions": [...]}``.

    Each array element is validated and predicted independently; a failing
    element is recorded under ``errors`` with its index and does not abort
    the rest of the batch.

    Returns:
        A JSON response, paired with a status code on failure:
        * 500 if no model is loaded or an unexpected error escapes the loop,
        * 400 if the body is missing, not valid JSON, not an object with a
          ``"predictions"`` key, or that key is not an array,
        * otherwise a summary with per-item ``results`` and ``errors``
          (default 200).
    """
    # NOTE(review): no route registration for this handler is visible in this
    # view of the file; confirm it is attached to `app`.
    if model is None:
        return jsonify({"error": "Model not loaded. Please check server logs."}), 500
    try:
        # silent=True returns None for malformed JSON or a non-JSON content
        # type instead of raising, so those cases get a 400 rather than a 500.
        data = request.get_json(silent=True)
        # The isinstance check keeps scalar bodies (e.g. a bare number) from
        # raising on the membership test.
        if not isinstance(data, dict) or "predictions" not in data:
            return jsonify(
                {
                    "error": 'No data provided. Please send JSON with "predictions" array.'
                }
            ), 400

        predictions_data = data["predictions"]
        if not isinstance(predictions_data, list):
            return jsonify({"error": "Predictions must be an array of objects."}), 400

        results = []
        errors = []
        for i, item in enumerate(predictions_data):
            if not isinstance(item, dict):
                errors.append(
                    {
                        "index": i,
                        "error": "Each prediction must be a JSON object.",
                        "input": item,
                    }
                )
                continue
            try:
                try:
                    validated = PredictionInput(**item)
                except ValidationError as ve:
                    # Drop the error context: it may hold a raw exception
                    # object, which would make the final jsonify call fail
                    # for the whole batch.
                    errors.append(
                        {
                            "index": i,
                            "error": ve.errors(include_context=False),
                            "input": item,
                        }
                    )
                    continue

                # Dump once; preprocess_input works on a copy.
                features = validated.model_dump()
                input_df = preprocess_input(features)
                prediction = model.predict(input_df)
                predicted_sales = float(prediction[0])
                results.append(
                    {
                        "index": i,
                        "predicted_sales": round(predicted_sales, 2),
                        "input_features": features,
                    }
                )
            except Exception as e:
                # Best effort per item: record the failure and keep going.
                errors.append({"index": i, "error": str(e), "input": item})

        response = {
            "successful_predictions": len(results),
            "failed_predictions": len(errors),
            "results": results,
            "errors": errors,
            "status": "completed",
        }
        logger.info(
            f"β Batch prediction completed: {len(results)} successful, {len(errors)} failed"
        )
        return jsonify(response)
    except Exception as e:
        logger.error(f"β Batch prediction error: {str(e)}")
        return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500
# Load the model at import time so it is also available when the module is
# imported by a WSGI server (e.g. Gunicorn) rather than run as a script.
if not load_model("./superkart_model.joblib"):
    logger.error("β Failed to load model. Application may not work properly.")

if __name__ == "__main__":
    # Runs only when the script is executed directly.
    logger.info("π Starting SuperKart Sales Prediction API...")
    # The server binds to every interface, and Flask's debug mode enables an
    # interactive debugger that can execute arbitrary code. Debug is therefore
    # off unless explicitly requested through the FLASK_DEBUG variable.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)