File size: 3,512 Bytes
0148d6b
644c847
 
 
 
 
 
 
 
 
 
cdefe97
455e0c5
644c847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146438a
cd620ca
cdefe97
 
d87bf7f
6d4e93b
d87bf7f
 
 
 
 
 
cd620ca
6d4e93b
d87bf7f
 
6d4e93b
d87bf7f
cd620ca
6d4e93b
d87bf7f
 
644c847
 
cdefe97
644c847
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99

import os

import joblib
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify

# Initialize Flask app
superkart_api = Flask("superkart_sales_api")

# Load the trained model, which must live in the same folder as this script.
# The path is resolved from this file's location rather than the process's
# current working directory, so the API also starts when launched from
# elsewhere (e.g. `python path/to/app.py`).
# NOTE: joblib.load unpickles the file, which can execute arbitrary code —
# only ever point this at a model artifact you produced/trust.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "superkart_prediction.joblib"
)
try:
    model = joblib.load(MODEL_PATH)
    print("✅ Model loaded successfully.")
except Exception as e:
    # Fail fast: none of the endpoints can serve predictions without a model.
    print("❌ Model load failed:", e)
    raise

# Liveness endpoint: a plain GET on the root confirms the backend is serving.
@superkart_api.get('/')
def home():
    """Return a fixed banner string so callers can verify the API is up."""
    banner = "✅ You are on Sales Prediction API for SuperKart"
    return banner

# Prediction endpoint
@superkart_api.post('/v1/predict')
def predict_sales():
    """Predict sales for one product/store record posted as a JSON object.

    The payload must be a JSON object containing every key listed in
    ``required_fields``. Type coercion before prediction:
    ``Product_Weight``, ``Product_Allocated_Area`` and ``Product_MRP`` become
    float, and ``Store_Age_Years`` becomes int. ``Product_Allocated_Area`` is
    additionally transformed with ``np.log1p``. The result is wrapped in a
    one-row DataFrame and passed to ``model.predict``.

    Returns:
        ``({"Predicted_Sales": <first model output>})`` with status 200;
        ``{"error": ...}`` with status 400 when the body is missing, is not
        valid JSON, is not a JSON object, lacks required fields, or holds a
        value that cannot be converted in a numeric field;
        ``{"error": ...}`` with status 500 for any other failure (e.g. raised
        by ``model.predict``).
    """
    # NOTE(review): /v1/sales_batch applies np.exp to the model output and does
    # not log1p-transform Product_Allocated_Area, while this endpoint does the
    # reverse. Confirm which matches how the model was trained and align them.
    try:
        # silent=True yields None for a missing body, a wrong Content-Type or
        # malformed JSON instead of raising an HTTP error, which the broad
        # handler below would otherwise misreport as a 500.
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({'error': "No JSON payload received"}), 400
        # A JSON array or scalar would slip past the membership check and then
        # fail on indexing; reject it up front as a client error.
        if not isinstance(data, dict):
            return jsonify({'error': "JSON payload must be an object"}), 400

        print("Raw incoming data:", data)

        required_fields = [
            'Product_Id_char',
            'Product_Weight',
            'Product_Sugar_Content',
            'Product_Allocated_Area',
            'Product_MRP',
            'Store_Size',
            'Store_Location_City_Type',
            'Store_Type',
            'Store_Age_Years',
            'Product_Type_Category'
        ]

        missing_fields = [f for f in required_fields if f not in data]
        if missing_fields:
            return jsonify({'error': f"Missing fields: {missing_fields}"}), 400

        try:
            sample = {
                'Product_Id_char': data['Product_Id_char'],
                'Product_Weight': float(data['Product_Weight']),
                'Product_Sugar_Content': data['Product_Sugar_Content'],
                'Product_Allocated_Area': np.log1p(float(data['Product_Allocated_Area'])),
                'Product_MRP': float(data['Product_MRP']),
                'Store_Size': data['Store_Size'],
                'Store_Location_City_Type': data['Store_Location_City_Type'],
                'Store_Type': data['Store_Type'],
                'Store_Age_Years': int(data['Store_Age_Years']),
                'Product_Type_Category': data['Product_Type_Category']
            }
        except (TypeError, ValueError, OverflowError) as e:
            # null / non-numeric text / non-finite value in a numeric field:
            # this is bad client input, not a server fault.
            return jsonify({'error': f"Invalid field value: {e}"}), 400

        input_df = pd.DataFrame([sample])
        print("Transformed input for model:\n", input_df)

        prediction = model.predict(input_df).tolist()[0]
        return jsonify({'Predicted_Sales': prediction})

    except Exception as e:
        # Top-level request boundary: log and report anything unexpected.
        print("❌ Error during prediction:", str(e))
        return jsonify({'error': f"Prediction failed: {str(e)}"}), 500

# BATCH SALES PREDICTION
@superkart_api.route("/v1/sales_batch", methods=["POST"])
def predict_sales_batch():
    """Predict sales for every row of an uploaded CSV file.

    The CSV must be sent as multipart form data under the key ``file``. The
    parsed DataFrame is passed to ``model.predict`` unchanged. Each model
    output is mapped through ``np.exp`` and rounded to 2 decimals.

    Returns:
        ``{"predictions": {key: value, ...}}`` with status 200. Keys come from
        the first of the ``id`` / ``ID`` / ``Product_Id`` columns present,
        stringified. If no such column exists, or its values are not unique,
        keys are the row positions ``"0"``, ``"1"``, ...
        ``{"error": ...}`` with status 400 on any failure.
    """
    # NOTE(review): this endpoint exponentiates the model output and does not
    # log1p-transform Product_Allocated_Area, unlike /v1/predict. It also
    # passes any ID column straight to the model. Confirm both against the
    # trained pipeline.
    try:
        file = request.files.get("file")
        if file is None:
            return jsonify({"error": "No CSV file uploaded under key 'file'"}), 400
        df = pd.read_csv(file)
        log_preds = model.predict(df).tolist()
        predictions = [round(float(np.exp(p)), 2) for p in log_preds]

        result = None
        id_col = next((c for c in ("id", "ID", "Product_Id") if c in df.columns), None)
        if id_col is not None:
            ids = df[id_col].astype(str).tolist()
            # Duplicate IDs would collapse into a single dict key and silently
            # drop predictions, so only key by ID when every ID is distinct.
            if len(set(ids)) == len(ids):
                result = dict(zip(ids, predictions))
        if result is None:
            result = {str(i): pred for i, pred in enumerate(predictions)}
        return jsonify({"predictions": result}), 200
    except Exception as e:
        # Log like /v1/predict does so failures are not silently swallowed.
        print("❌ Error during batch prediction:", str(e))
        return jsonify({"error": str(e)}), 400

# Local testing
if __name__ == '__main__':
    # The Werkzeug interactive debugger lets anyone who can reach the server
    # run arbitrary Python. Because we listen on all interfaces (0.0.0.0),
    # debug mode is opt-in: set FLASK_DEBUG=1 (or true/yes/on) for local work.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0").strip().lower() in {"1", "true", "yes", "on"}
    superkart_api.run(debug=debug_enabled, host='0.0.0.0', port=7860)