File size: 3,901 Bytes
5fc6306
f969e9d
5fc6306
 
 
f969e9d
 
70d620a
f969e9d
5fc6306
 
 
 
f969e9d
 
 
5fc6306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cbb02b
5fc6306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cbb02b
5fc6306
6cbb02b
 
 
5fc6306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb8e0f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Backend_files/app.py
import joblib
import pandas as pd
from flask import Flask, request, jsonify
from flask_cors import CORS
from sklearn.base import BaseEstimator, TransformerMixin

from custom_transformer import StoreAgeAdder, OutlierCapper

# Create the Flask application object; the routes below register on it.
Superkart_Sales_Predictor_API = Flask("Superkart Sales Predictor")

# Allow cross-origin requests so a separately hosted frontend can call the API.
CORS(Superkart_Sales_Predictor_API)

# Deserialize the trained model once, at import time, so every request reuses
# the same object. joblib.load unpickles the file, so only load trusted files.
# NOTE(review): the custom_transformer import above is presumably required so
# the pickled pipeline's custom classes can be resolved here - confirm.
_MODEL_FILE = 'Random_Forest_Model.pkl'
Random_Forest_Loaded_Model = joblib.load(_MODEL_FILE)

# Root route (GET /): returns a fixed welcome message.
@Superkart_Sales_Predictor_API.get('/')
def home():
    """Respond to ``GET /`` with a plain-text welcome banner.

    Returns:
        str: The fixed greeting, which Flask sends as the response body.
    """
    welcome_message = "Welcome to SuperKart Sales Predictor API!"
    return welcome_message

# Define an endpoint to predict the Superkart sales
@Superkart_Sales_Predictor_API.post('/predict')
def predict():
    """Predict sales for a single product/store record.

    Expects the request body to be a JSON object containing every field in
    ``required_fields``. Extra keys are ignored: only the required fields are
    passed to the model, as one row with columns in ``required_fields`` order.

    Returns:
        200 with ``{"prediction": [...]}`` (the model output as a list),
        400 with ``{"error": ...}`` if the body is not a JSON object or a
        required field is missing,
        500 with ``{"error": ...}`` if building the frame or predicting raises.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type'
    ]

    # silent=True makes get_json() return None for a missing, malformed or
    # non-JSON body instead of raising an HTTP error. Previously that error
    # was caught by the broad except below and reported as a 500.
    Product_And_Store_data = request.get_json(silent=True)

    # Validate input: must be a JSON object. For other types, `in` would
    # raise (e.g. None) or do substring matching (str).
    if not isinstance(Product_And_Store_data, dict):
        return jsonify({"error": "Input must be a JSON object"}), 400

    missing = [f for f in required_fields if f not in Product_And_Store_data]
    if missing:
        return jsonify({"error": f"Missing fields: {missing}"}), 400

    try:
        # One-row DataFrame holding exactly the required fields, in list order.
        sample_df = pd.DataFrame(
            [{field: Product_And_Store_data[field] for field in required_fields}]
        )

        prediction = Random_Forest_Loaded_Model.predict(sample_df)

        return jsonify({'prediction': prediction.tolist()})

    except Exception as e:
        # Record the traceback server-side; the client still gets the message.
        Superkart_Sales_Predictor_API.logger.exception("Single prediction failed")
        return jsonify({"error": str(e)}), 500

@Superkart_Sales_Predictor_API.post('/predict_batch')
def predict_batch():
    """Predict sales for a batch of product/store records.

    Expects the request body to be a JSON array of objects, each containing
    every field in ``required_fields``. Extra keys are ignored, so the model
    always receives exactly these columns in this order, whatever key order
    or extra keys the client sent.

    Returns:
        200 with ``{"predictions": [...]}`` (one entry per record, in input
        order; an empty list for an empty batch),
        400 with ``{"error": ...}`` if the body is not a JSON array, a record
        is not a JSON object, or a record lacks a required field,
        500 with ``{"error": ...}`` if building the frame or predicting raises.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type'
    ]

    # silent=True: a missing, malformed or non-JSON body yields None instead of
    # an HTTP exception, so it is reported as a 400 below rather than a 500.
    batch_data = request.get_json(silent=True)

    # Validate input type
    if not isinstance(batch_data, list):
        return jsonify({"error": "Input must be a list of records"}), 400

    # Check each record: it must be an object and carry every required field.
    for i, record in enumerate(batch_data):
        if not isinstance(record, dict):
            return jsonify({"error": f"Record {i} must be a JSON object"}), 400
        missing = [f for f in required_fields if f not in record]
        if missing:
            return jsonify({"error": f"Missing fields in record {i}: {missing}"}), 400

    # Nothing to predict. Return early rather than handing the model a
    # zero-row frame (scikit-learn estimators typically reject those).
    if not batch_data:
        return jsonify({'predictions': []})

    try:
        # columns= selects only the required keys from each dict, in a fixed
        # order, so extra or reordered client keys cannot change the schema.
        df = pd.DataFrame(batch_data, columns=required_fields)

        predictions = Random_Forest_Loaded_Model.predict(df)

        # Return list of predictions
        return jsonify({'predictions': predictions.tolist()})

    except Exception as e:
        # Record the traceback server-side; the client still gets the message.
        Superkart_Sales_Predictor_API.logger.exception("Batch prediction failed")
        return jsonify({"error": str(e)}), 500

# When executed directly, start Flask's built-in server with debug disabled,
# listening on all interfaces on port 7860.
if __name__ == '__main__':
    Superkart_Sales_Predictor_API.run(
        host='0.0.0.0',
        port=7860,
        debug=False,
    )