File size: 3,581 Bytes
72bce8e
 
 
48bfd00
 
1c2785f
48bfd00
 
1c2785f
48bfd00
1c2785f
 
 
 
 
 
 
48bfd00
 
 
 
72bce8e
 
48bfd00
7dce041
72bce8e
48bfd00
1c2785f
48bfd00
72bce8e
 
 
 
1c2785f
 
 
72bce8e
 
 
 
48bfd00
 
 
 
72bce8e
48bfd00
72bce8e
 
1c2785f
 
48bfd00
 
 
1c2785f
48bfd00
 
 
72bce8e
 
 
1c2785f
72bce8e
 
1c2785f
 
 
b3ffc44
72bce8e
 
73544cf
 
 
 
 
 
48bfd00
06b2bff
4ca11a2
48bfd00
 
72bce8e
 
48bfd00
 
 
1c2785f
48bfd00
 
 
 
72bce8e
48bfd00
8b86899
72bce8e
1ff2e82
72bce8e
 
1ff2e82
 
48bfd00
72bce8e
 
 
 
 
1c2785f
72bce8e
 
48bfd00
 
 
72bce8e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pandas as pd
from flask import Flask, request, jsonify
import joblib
import numpy as np
import logging
import sys

# -----------------------
# Setup logger (stdout captured by HF Spaces)
# -----------------------
logger = logging.getLogger("capacity_logger")
logger.setLevel(logging.INFO)
# Guard against duplicate handlers: this module-level setup runs again on
# re-import / hot-reload, and each extra StreamHandler would duplicate
# every log line emitted through this logger.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)  # important for HF Spaces
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# -----------------------
# Initialize Flask app
# -----------------------
# NOTE(review): the argument is Flask's import_name, conventionally the
# module name (__name__); a display string works but affects Flask's
# root-path/resource lookup — confirm this is intentional.
app = Flask("Store Capacity Predictor")

# Load the trained pipeline
# Loaded once at import time, so a missing or corrupt model file makes the
# service fail fast at startup rather than on the first request.
pipeline = joblib.load("catbooster_model_v1_0.joblib")

# -----------------------
# Home route
# -----------------------
@app.get('/')
def home():
    """Landing / health-check endpoint returning a plain greeting string."""
    greeting = "Welcome to the Store Capacity Prediction API"
    return greeting

# -----------------------
# Single prediction
# -----------------------
@app.post('/v1/predict')
def predict_capacity():
    """Predict store capacity for a single record supplied as a JSON object.

    Expects a JSON body containing the model's feature columns, including a
    "Date" field parseable by pandas. Returns {"Predicted_Capacity": <int>}
    on success, or {"error": <message>} with HTTP 400 on any failure.
    """
    try:
        sales_data = request.get_json()
        # get_json() yields None for a missing/empty/non-JSON body; fail
        # early with a clear message instead of a cryptic DataFrame error.
        if sales_data is None:
            return jsonify({'error': 'Request body must be JSON'}), 400

        input_data = pd.DataFrame([sales_data])
        input_data["Date"] = pd.to_datetime(input_data["Date"])

        logger.info("Single prediction input:\n%s", input_data)

        # Predict (single row -> take the first/only value)
        prediction = pipeline.predict(input_data).tolist()[0]

        # Sanitize prediction: NaN/inf -> 0, otherwise clamp to [0, 10000].
        if not np.isfinite(prediction):
            logger.warning("Single prediction invalid (%s), replacing with 0", prediction)
            prediction = 0
        else:
            prediction = int(np.clip(prediction, 0, 10000))

        logger.info("Single prediction output: %s", prediction)

        return jsonify({'Predicted_Capacity': prediction})

    except Exception as e:
        # Boundary handler: log the full traceback, return the message only.
        logger.error("Error in single prediction: %s", e, exc_info=True)
        return jsonify({'error': str(e)}), 400

# -----------------------
# Batch prediction
# -----------------------
@app.post('/v1/predict_batch')
def predict_capacity_batch():
    """Predict capacities for a CSV uploaded as the multipart field "file".

    The CSV must contain the model's feature columns (including "Date" and
    the identifying columns BU/Store/LocationType/Slot). Returns an HTML
    table of those identifiers plus Predicted_Capacity, or
    {"error": <message>} with HTTP 400 on any failure.
    """
    try:
        # Fail early with a clear message if the upload field is absent;
        # request.files['file'] would otherwise surface a bare KeyError
        # whose str() is just "'file'" in the error response.
        if 'file' not in request.files:
            return jsonify({"error": "Missing 'file' upload field"}), 400
        file = request.files['file']

        # Read the file into a DataFrame
        input_data = pd.read_csv(file)
        input_data["Date"] = pd.to_datetime(input_data["Date"])
        # SpecialEvent may be blank in the CSV; the model expects strings.
        input_data["SpecialEvent"] = input_data["SpecialEvent"].fillna("").astype(str)

        logger.info("Batch input shape: %s", input_data.shape)
        logger.info("Batch input preview:\n%s", input_data.head())

        predictions = pipeline.predict(input_data).tolist()

        # Sanitize each prediction: NaN/inf -> 0, otherwise clamp to [0, 10000].
        clean_predictions = []
        for idx, p in enumerate(predictions):
            if not np.isfinite(p):
                logger.warning("Row %d prediction invalid (%s), replacing with 0", idx, p)
                clean_predictions.append(0)
            else:
                clean_predictions.append(int(np.clip(p, 0, 10000)))

        logger.info("Batch predictions: %s", clean_predictions)

        output_df = pd.DataFrame({
            "BU": input_data["BU"],
            "Date": input_data["Date"],
            "Store": input_data["Store"],
            "LocationType": input_data["LocationType"],
            "Slot": input_data["Slot"],
            "Predicted_Capacity": clean_predictions
        })

        return output_df.to_html(index=False)

    except Exception as e:
        # Boundary handler: log the full traceback, return the message only.
        logger.error("Error in batch prediction: %s", e, exc_info=True)
        return jsonify({"error": str(e)}), 400

# -----------------------
# Run Flask app
# -----------------------
if __name__ == '__main__':
    # NOTE(review): debug=True enables the interactive Werkzeug debugger and
    # auto-reloader — unsafe for production; confirm this entry point is
    # dev-only (HF Spaces typically launches via a production WSGI server).
    app.run(debug=True)