File size: 12,684 Bytes
590b0f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9c1f61
590b0f8
a9c1f61
590b0f8
 
 
 
 
 
 
a9c1f61
 
 
 
590b0f8
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
from flask import Flask, request, jsonify
import joblib
import numpy as np
import pandas as pd
from flask_cors import CORS
import logging
from datetime import datetime
import os
import traceback

# Initialize Flask app
# The route decorators further down register their handlers on this object.
app = Flask(__name__)
# NOTE(review): CORS is applied with no arguments, i.e. no origin restriction
# is configured here -- confirm that is intended for this deployment.
CORS(app)  # Enable CORS for all routes

# Configure logging
# Root logger at INFO; each record carries timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables for model and preprocessor
# All three stay None until load_model() succeeds; the request handlers read
# them directly, and /health uses "model is not None" as its readiness signal.
model = None
preprocessor = None
model_artifacts = None

def load_model():
    """Load the trained model and preprocessing artifacts.

    Reads the joblib bundle ``superkart_sales_forecasting_model.joblib``
    (resolved against the current working directory) and publishes its
    contents to the module-level ``model``, ``preprocessor`` and
    ``model_artifacts`` globals.

    The globals are assigned only after every required entry has been read
    successfully, so a failed load never leaves them half-updated (for
    example ``model`` set while ``preprocessor`` is still ``None``), and a
    failed reload keeps whatever was loaded before.

    Returns:
        bool: True when the bundle was loaded and published; False when the
        file is missing or loading failed (the reason is logged).
    """
    global model, preprocessor, model_artifacts

    model_path = 'superkart_sales_forecasting_model.joblib'

    if not os.path.exists(model_path):
        logger.error(f"Model file not found: {model_path}")
        return False

    try:
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only ship bundles from a trusted source.
        artifacts = joblib.load(model_path)

        # Resolve everything into locals first; a missing key raises here,
        # before any global has been touched.
        loaded_model = artifacts['model']
        loaded_preprocessor = artifacts['preprocessor']
        model_name = artifacts['model_name']
        training_date = artifacts['training_date']
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return False

    # Publish the fully validated bundle in one go.
    model_artifacts = artifacts
    model = loaded_model
    preprocessor = loaded_preprocessor

    logger.info(f"Model loaded successfully: {model_name}")
    logger.info(f"Training date: {training_date}")

    return True

def _is_real_number(value):
    """Return True when *value* is a finite int or float (bools excluded)."""
    import math  # function-scope import keeps this block self-contained

    # bool is a subclass of int, so JSON true/false would otherwise pass as 1/0.
    if isinstance(value, bool):
        return False
    if isinstance(value, int):
        return True
    # Reject NaN/Infinity: every ordering comparison with NaN is False, so a
    # NaN would slip through range checks written as "value <= 0".
    return isinstance(value, float) and math.isfinite(value)


def validate_input_data(data):
    """Validate input data for prediction.

    Checks that ``data`` is a dict containing every required field, that the
    numeric fields are finite numbers inside their allowed ranges, and that
    the categorical fields hold one of the known labels. ``Product_Type`` is
    only checked for presence.

    Args:
        data: One input record, normally the decoded JSON object of a request.

    Returns:
        tuple: ``(True, "Validation passed")`` when the record is acceptable,
        otherwise ``(False, message)`` where ``message`` names the first
        problem found.
    """
    # A non-dict payload (null, number, list, ...) is a client error; report
    # it instead of letting the membership test below raise TypeError.
    if not isinstance(data, dict):
        return False, "Input record must be a JSON object"

    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Size',
        'Store_Location_City_Type', 'Store_Type', 'Store_Age'
    ]

    # Check if all required fields are present
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"Missing required fields: {missing_fields}"

    # Validate data types and ranges
    try:
        # Numerical validations
        weight = data['Product_Weight']
        if not _is_real_number(weight) or weight <= 0:
            return False, "Product_Weight must be a positive number"

        area = data['Product_Allocated_Area']
        if not _is_real_number(area) or not (0 <= area <= 1):
            return False, "Product_Allocated_Area must be between 0 and 1"

        mrp = data['Product_MRP']
        if not _is_real_number(mrp) or mrp <= 0:
            return False, "Product_MRP must be a positive number"

        store_age = data['Store_Age']
        if not _is_real_number(store_age) or store_age < 0:
            return False, "Store_Age must be a non-negative number"

        # Categorical validations
        valid_sugar_content = ['Low Sugar', 'Regular', 'No Sugar']
        if data['Product_Sugar_Content'] not in valid_sugar_content:
            return False, f"Product_Sugar_Content must be one of: {valid_sugar_content}"

        valid_store_sizes = ['Small', 'Medium', 'High']
        if data['Store_Size'] not in valid_store_sizes:
            return False, f"Store_Size must be one of: {valid_store_sizes}"

        valid_city_types = ['Tier 1', 'Tier 2', 'Tier 3']
        if data['Store_Location_City_Type'] not in valid_city_types:
            return False, f"Store_Location_City_Type must be one of: {valid_city_types}"

        valid_store_types = ['Departmental Store', 'Supermarket Type1', 'Supermarket Type2', 'Food Mart']
        if data['Store_Type'] not in valid_store_types:
            return False, f"Store_Type must be one of: {valid_store_types}"

        return True, "Validation passed"

    except Exception as e:
        # Kept as a safety net for values whose comparisons raise (for
        # example array-like objects passed by a non-HTTP caller).
        return False, f"Validation error: {str(e)}"

def preprocess_for_prediction(data):
    """Preprocess input data for model prediction."""
    try:
        # Convert to DataFrame
        if isinstance(data, dict):
            df = pd.DataFrame([data])
        else:
            df = pd.DataFrame(data)
        
        # Feature engineering functions (must match training)
        def categorize_mrp(mrp):
            if mrp <= 69.0:
                return 'Low'
            elif mrp <= 136.0:
                return 'Medium_Low'
            elif mrp <= 202.0:
                return 'Medium_High'
            else:
                return 'High'
        
        def categorize_weight(weight):
            if weight <= 8.773:
                return 'Light'
            elif weight <= 12.89:
                return 'Medium_Light'
            elif weight <= 16.95:
                return 'Medium_Heavy'
            else:
                return 'Heavy'
        
        def categorize_store_age(age):
            if age <= 20:
                return 'New'
            elif age <= 30:
                return 'Established'
            else:
                return 'Legacy'
        
        # Add engineered features
        df['Product_MRP_Category'] = df['Product_MRP'].apply(categorize_mrp)
        df['Product_Weight_Category'] = df['Product_Weight'].apply(categorize_weight)
        df['Store_Age_Category'] = df['Store_Age'].apply(categorize_store_age)
        df['City_Store_Type'] = df['Store_Location_City_Type'] + '_' + df['Store_Type']
        df['Size_Type_Interaction'] = df['Store_Size'] + '_' + df['Store_Type']
        
        # Transform using the preprocessing pipeline
        processed_data = preprocessor.transform(df)
        
        return processed_data, None
        
    except Exception as e:
        return None, str(e)

@app.route('/', methods=['GET'])
def home():
    """Home endpoint with API information.

    Returns a JSON document describing the service: a summary of the loaded
    model (or a "not loaded" status), the available endpoints, and a sample
    request payload for the prediction endpoints.
    """
    if model_artifacts:
        # Metadata entries are looked up with a fallback so that a bundle
        # lacking one of them (load_model never reads 'model_version') does
        # not turn this informational endpoint into an error response.
        # Assumes the bundle is a dict, as load_model indexes it by key.
        model_details = {
            "name": model_artifacts.get('model_name', "Unknown"),
            "training_date": model_artifacts.get('training_date', "Unknown"),
            "version": model_artifacts.get('model_version', "Unknown")
        }
    else:
        model_details = {"status": "Model not loaded"}

    api_info = {
        "message": "SuperKart Sales Forecasting API",
        "version": "1.0",
        "model_info": model_details,
        "endpoints": {
            "/": "API information",
            "/health": "Health check",
            "/predict": "Single prediction (POST)",
            "/batch_predict": "Batch predictions (POST)",
            "/model_info": "Model details"
        },
        "sample_input": {
            "Product_Weight": 10.5,
            "Product_Sugar_Content": "Low Sugar",
            "Product_Allocated_Area": 0.15,
            "Product_Type": "Fruits and Vegetables",
            "Product_MRP": 150.0,
            "Store_Size": "Medium",
            "Store_Location_City_Type": "Tier 2",
            "Store_Type": "Supermarket Type2",
            "Store_Age": 15
        }
    }
    return jsonify(api_info)

@app.route('/health', methods=['GET'])
def health_check():
    """Report whether a model is currently loaded, with a timestamp."""
    # The service counts as healthy exactly when the global model is set.
    is_loaded = model is not None
    status_label = "healthy" if is_loaded else "unhealthy"

    return jsonify({
        "status": status_label,
        "model_loaded": is_loaded,
        "timestamp": datetime.now().isoformat(),
        "service": "SuperKart Sales Forecasting API",
    })

@app.route('/model_info', methods=['GET'])
def model_info():
    """Return metadata about the loaded model as JSON.

    Responds with status 500 and an error payload while no model bundle has
    been loaded.
    """
    artifacts = model_artifacts
    if artifacts is None:
        return jsonify({"error": "Model not loaded"}), 500

    # Entries copied verbatim from the bundle, in response order.
    copied_keys = ('model_name', 'training_date', 'model_version', 'performance_metrics')
    details = {key: artifacts[key] for key in copied_keys}

    # Derived entries.
    details["feature_count"] = len(artifacts['feature_names'])
    details["model_type"] = type(model).__name__

    return jsonify(details)

@app.route('/predict', methods=['POST'])
def predict():
    """Single prediction endpoint.

    Expects a JSON object holding one input record. Responds with:

    * 200 and the prediction, the echoed input and model metadata on success;
    * 400 when the body is not valid JSON, fails validation, or cannot be
      preprocessed;
    * 500 when the model is not loaded or an unexpected error occurs.
    """
    try:
        # Get JSON data from request
        # silent=True makes get_json() return None for a malformed body or a
        # non-JSON content type instead of raising; without it the error would
        # be caught by the handler below and reported as a 500.
        data = request.get_json(silent=True)

        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        # Validate input data
        is_valid, validation_message = validate_input_data(data)
        if not is_valid:
            return jsonify({"error": validation_message}), 400

        # A missing model is a server-side condition; report it the same way
        # /model_info does rather than as a client "Preprocessing failed".
        if model is None or preprocessor is None:
            return jsonify({"error": "Model not loaded"}), 500

        # Preprocess data
        processed_data, error = preprocess_for_prediction(data)
        if error:
            return jsonify({"error": f"Preprocessing failed: {error}"}), 400

        # Make prediction (single row in, first element out)
        prediction = model.predict(processed_data)[0]

        # Prepare response
        response = {
            "prediction": float(prediction),
            "input_data": data,
            "model_info": {
                "model_name": model_artifacts['model_name'],
                "prediction_timestamp": datetime.now().isoformat()
            }
        }

        logger.info(f"Prediction made: {prediction:.2f}")
        return jsonify(response)

    except Exception as e:
        # Top-level boundary: log with traceback, answer with a JSON error.
        logger.error(f"Prediction error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500

@app.route('/batch_predict', methods=['POST'])
def batch_predict():
    """Batch prediction endpoint.

    Expects a non-empty JSON list of input records. Each record is validated,
    preprocessed and predicted independently; a record that fails contributes
    ``None`` to ``predictions`` and a message to ``errors`` without aborting
    the rest of the batch.

    Responds with 400 for a missing/invalid/empty payload, 500 when the model
    is not loaded or an unexpected error occurs, and 200 otherwise.
    """
    try:
        # Get JSON data from request
        # silent=True makes get_json() return None for a malformed body or a
        # non-JSON content type instead of raising; without it the error would
        # be caught by the handler below and reported as a 500.
        data = request.get_json(silent=True)

        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        # Ensure data is a list
        if not isinstance(data, list):
            return jsonify({"error": "Data must be a list of records"}), 400

        if len(data) == 0:
            return jsonify({"error": "Empty data list provided"}), 400

        # Fail once with a clear message instead of once per record.
        if model is None or preprocessor is None:
            return jsonify({"error": "Model not loaded"}), 500

        predictions = []
        errors = []

        for i, record in enumerate(data):
            try:
                # Validate input data
                is_valid, validation_message = validate_input_data(record)
                if not is_valid:
                    errors.append(f"Record {i}: {validation_message}")
                    predictions.append(None)
                    continue

                # Preprocess data
                processed_data, error = preprocess_for_prediction(record)
                if error:
                    errors.append(f"Record {i}: Preprocessing failed - {error}")
                    predictions.append(None)
                    continue

                # Make prediction
                prediction = model.predict(processed_data)[0]
                predictions.append(float(prediction))

            except Exception as e:
                # Per-record isolation: note the failure and keep going.
                errors.append(f"Record {i}: {str(e)}")
                predictions.append(None)

        # Prepare response
        response = {
            "predictions": predictions,
            "total_records": len(data),
            "successful_predictions": sum(p is not None for p in predictions),
            "errors": errors if errors else None,
            "model_info": {
                "model_name": model_artifacts['model_name'],
                "prediction_timestamp": datetime.now().isoformat()
            }
        }

        logger.info(f"Batch prediction completed: {len(predictions)} records processed")
        return jsonify(response)

    except Exception as e:
        # Top-level boundary: log with traceback (as /predict does), answer
        # with a JSON error.
        logger.error(f"Batch prediction error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500

# Initialize the model when the app starts (Flask 3.x compatible)
def initialize():
    """Load the model once and log how API start-up went."""
    logger.info("Initializing SuperKart Sales Forecasting API...")

    # Guard clause: report the failure and stop; the app keeps running
    # without a model in that case.
    if not load_model():
        logger.error("API initialization failed - model could not be loaded")
        return

    logger.info("API initialization completed successfully")

# Call initialization immediately when module loads, so the model is also
# available when the app is imported by a WSGI server instead of run directly.
with app.app_context():
    initialize()

if __name__ == '__main__':
    # initialize() above has already attempted the load; reuse its result and
    # only retry when it failed, instead of reading the bundle a second time.
    if model is not None or load_model():
        print("[SUCCESS] Model loaded successfully")
        print("[STARTING] SuperKart Sales Forecasting API...")
        # Listens on all interfaces, port 8080, with the debugger disabled.
        app.run(host='0.0.0.0', port=8080, debug=False)
    else:
        print("[ERROR] Failed to load model. Please check model file.")