# Import necessary libraries
import numpy as np
import joblib  # For loading the serialized model
import pandas as pd  # For data manipulation
from flask import Flask, request, jsonify  # For creating the Flask API

# Initialize the Flask application
app = Flask("SuperKart Sales Predictor")

# Load the trained model (done once, at import time, and shared by all requests)
model = joblib.load("superkart_prediction_model_v1_0.joblib")

# Fields read from the single-prediction JSON payload. The same names become
# the columns of the one-row DataFrame handed to the model.
_FEATURE_COLUMNS = (
    'Product_Weight',
    'Product_Sugar_Content',
    'Product_Allocated_Area',
    'Product_Type',
    'Product_MRP',
    'Store_Size',
    'Store_Location_City_Type',
    'Store_Type',
    'store_age',
)


# --- Home route ---
@app.get('/')
def home():
    """Health check endpoint for the API.

    Returns:
        A plain-text welcome message.
    """
    return "Welcome to the SuperKart Sales Prediction API!"


# --- Single prediction endpoint ---
@app.post('/v1/predict')
def predict_sales_revenue():
    """Predict total sales revenue for a single product/store combination.

    Expects a JSON object containing every field listed in
    ``_FEATURE_COLUMNS``.

    Returns:
        On success, JSON ``{'Predicted_Sales_Revenue': <float>}`` with the
        prediction rounded to 2 decimal places.
        On failure, JSON ``{'error': <message>}`` with HTTP status 400.
    """
    # silent=True returns None for a missing/invalid JSON body instead of
    # raising, so the client gets one clear message rather than a
    # "'NoneType' object is not subscriptable" error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Report every absent field at once instead of failing on the first one.
    missing = [name for name in _FEATURE_COLUMNS if name not in data]
    if missing:
        message = 'Missing required field(s): ' + ', '.join(missing)
        return jsonify({'error': message}), 400

    try:
        # Build a one-row DataFrame from the expected fields only
        sample = {name: data[name] for name in _FEATURE_COLUMNS}
        input_df = pd.DataFrame([sample])

        # Make prediction
        prediction = model.predict(input_df)[0]

        # Return result
        return jsonify({'Predicted_Sales_Revenue': round(float(prediction), 2)})
    except Exception as e:
        # Route-level boundary: any remaining failure is reported to the
        # caller as a 400 with the exception text, as before.
        return jsonify({'error': str(e)}), 400


# --- Batch prediction endpoint ---
@app.post('/v1/predictbatch')
def predict_sales_batch():
    """Predict total sales revenue for multiple entries from a CSV file.

    Expects a file upload under the key 'file'. The parsed CSV is passed to
    the model as-is.

    Returns:
        On success, a JSON object mapping each row to its prediction
        (rounded to 2 decimal places). Keys are the values of the CSV's
        'id' column when that column is present, otherwise
        'Prediction_1', 'Prediction_2', ...
        On failure, JSON ``{'error': <message>}`` with HTTP status 400.
    """
    uploaded = request.files.get('file')
    if uploaded is None:
        return jsonify({'error': "No file uploaded under the key 'file'"}), 400

    try:
        input_data = pd.read_csv(uploaded)

        # Generate predictions
        predictions = [
            round(float(p), 2) for p in model.predict(input_data).tolist()
        ]

        # Include IDs if present
        if 'id' in input_data.columns:
            # NOTE: the response is a dict keyed by id, so rows that share an
            # id collapse to the prediction of the last such row.
            result = dict(zip(input_data['id'].tolist(), predictions))
        else:
            result = {
                'Prediction_' + str(i): p
                for i, p in enumerate(predictions, start=1)
            }

        return jsonify(result)
    except Exception as e:
        # Route-level boundary: CSV parsing or model failures are reported to
        # the caller as a 400 with the exception text, as before.
        return jsonify({'error': str(e)}), 400