File size: 2,583 Bytes
c9b211d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Import necessary libraries
import numpy as np
import joblib  # For loading the serialized model
import pandas as pd  # For data manipulation
from flask import Flask, request, jsonify  # For creating the Flask API

# Initialize the Flask application.
# NOTE(review): Flask conventionally receives __name__ here; this string is used
# as the app's import name (and therefore its app.name / logger name) — confirm
# that is intentional.
app = Flask("SuperKart Sales Predictor")

# Load the trained model once at import time so every request reuses it.
# The path is relative, so it resolves against the process's current working
# directory: the server must be started from the directory holding the file.
# NOTE(review): joblib.load unpickles the file — only ever load trusted artifacts.
model = joblib.load("superkart_prediction_model_v1_0.joblib")

# --- Home route ---
@app.get('/')
def home():
    """Return a fixed greeting; doubles as a simple liveness/health check."""
    greeting = "Welcome to the SuperKart Sales Prediction API!"
    return greeting


# --- Single prediction endpoint ---
# Request fields fed to the model, in the column order of the one-row
# DataFrame handed to model.predict.
_FEATURE_COLUMNS = (
    'Product_Weight',
    'Product_Sugar_Content',
    'Product_Allocated_Area',
    'Product_Type',
    'Product_MRP',
    'Store_Size',
    'Store_Location_City_Type',
    'Store_Type',
    'store_age',
)


@app.post('/v1/predict')
def predict_sales_revenue():
    """
    Predicts total sales revenue for a single product/store combination.

    Expects a JSON object body containing every field in _FEATURE_COLUMNS;
    any additional fields are ignored.

    Returns:
        200 with {'Predicted_Sales_Revenue': <float rounded to 2 decimals>}.
        400 with {'error': <message>} when the body is not a JSON object,
        a required field is missing, or the model rejects the input.
    """
    # silent=True yields None (instead of raising) for a wrong content type or
    # malformed JSON, so all bad bodies get the same clear message below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    missing = [col for col in _FEATURE_COLUMNS if col not in data]
    if missing:
        return jsonify({'error': 'Missing required field(s): ' + ', '.join(missing)}), 400

    try:
        # One-row DataFrame whose column order matches _FEATURE_COLUMNS
        input_df = pd.DataFrame([{col: data[col] for col in _FEATURE_COLUMNS}])
        prediction = model.predict(input_df)[0]
        return jsonify({'Predicted_Sales_Revenue': round(float(prediction), 2)})

    except Exception as e:
        # Invalid field values (e.g. unknown categories, non-numeric strings)
        # surface here from the model, so they are reported as client errors.
        app.logger.exception('Single prediction failed')
        return jsonify({'error': str(e)}), 400


# --- Batch prediction endpoint ---
@app.post('/v1/predictbatch')
def predict_sales_batch():
    """
    Predicts total sales revenue for multiple entries from a CSV file.

    Expects a multipart file upload under the key 'file'. Every CSV column
    except an optional 'id' column is passed to the model as input.

    Returns:
        200 with a JSON object mapping each row's 'id' (when the CSV has an
        'id' column) or a generated 1-based 'Prediction_<n>' key to its
        predicted revenue, rounded to 2 decimals. A CSV with a header but no
        rows yields an empty object.
        400 with {'error': <message>} when no file is uploaded, it cannot be
        parsed, or the model rejects its contents.
    """
    file = request.files.get('file')
    if file is None:
        return jsonify({'error': "No file uploaded under the key 'file'"}), 400

    try:
        input_data = pd.read_csv(file)

        # Nothing to predict: return an empty mapping rather than a model error
        if input_data.empty:
            return jsonify({})

        # 'id' is a row identifier, not a model feature; keep it out of predict
        features = input_data.drop(columns=['id'], errors='ignore')
        predictions = [round(float(p), 2) for p in model.predict(features).tolist()]

        if 'id' in input_data.columns:
            # NOTE: duplicate ids collapse to a single key (the last row wins)
            result = dict(zip(input_data['id'].tolist(), predictions))
        else:
            result = {f'Prediction_{i}': p for i, p in enumerate(predictions, start=1)}

        return jsonify(result)

    except Exception as e:
        app.logger.exception('Batch prediction failed')
        return jsonify({'error': str(e)}), 400