itsjarvis commited on
Commit
ce22946
·
verified ·
1 Parent(s): cc29c74

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +337 -0
app.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import joblib
3
+ import numpy as np
4
+ import pandas as pd
5
+ from flask_cors import CORS
6
+ import logging
7
+ from datetime import datetime
8
+ import os
9
+ import traceback
10
+
11
+ # Initialize Flask app
12
+ app = Flask(__name__)
13
+ CORS(app) # Enable CORS for all routes
14
+
15
+ # Configure logging
16
+ logging.basicConfig(
17
+ level=logging.INFO,
18
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Global variables for model and preprocessor
23
+ model = None
24
+ preprocessor = None
25
+ model_artifacts = None
26
+
27
+ def load_model():
28
+ """Load the trained model and preprocessing artifacts."""
29
+ global model, preprocessor, model_artifacts
30
+
31
+ try:
32
+ model_path = 'superkart_sales_forecasting_model.joblib'
33
+
34
+ if not os.path.exists(model_path):
35
+ logger.error(f"Model file not found: {model_path}")
36
+ return False
37
+
38
+ # Load model artifacts
39
+ model_artifacts = joblib.load(model_path)
40
+ model = model_artifacts['model']
41
+ preprocessor = model_artifacts['preprocessor']
42
+
43
+ logger.info(f"Model loaded successfully: {model_artifacts['model_name']}")
44
+ logger.info(f"Training date: {model_artifacts['training_date']}")
45
+
46
+ return True
47
+
48
+ except Exception as e:
49
+ logger.error(f"Error loading model: {str(e)}")
50
+ return False
51
+
52
+ def validate_input_data(data):
53
+ """Validate input data for prediction."""
54
+ required_fields = [
55
+ 'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
56
+ 'Product_Type', 'Product_MRP', 'Store_Size',
57
+ 'Store_Location_City_Type', 'Store_Type', 'Store_Age'
58
+ ]
59
+
60
+ # Check if all required fields are present
61
+ missing_fields = [field for field in required_fields if field not in data]
62
+ if missing_fields:
63
+ return False, f"Missing required fields: {missing_fields}"
64
+
65
+ # Validate data types and ranges
66
+ try:
67
+ # Numerical validations
68
+ if not isinstance(data['Product_Weight'], (int, float)) or data['Product_Weight'] <= 0:
69
+ return False, "Product_Weight must be a positive number"
70
+
71
+ if not isinstance(data['Product_Allocated_Area'], (int, float)) or not (0 <= data['Product_Allocated_Area'] <= 1):
72
+ return False, "Product_Allocated_Area must be between 0 and 1"
73
+
74
+ if not isinstance(data['Product_MRP'], (int, float)) or data['Product_MRP'] <= 0:
75
+ return False, "Product_MRP must be a positive number"
76
+
77
+ if not isinstance(data['Store_Age'], (int, float)) or data['Store_Age'] < 0:
78
+ return False, "Store_Age must be a non-negative number"
79
+
80
+ # Categorical validations
81
+ valid_sugar_content = ['Low Sugar', 'Regular', 'No Sugar']
82
+ if data['Product_Sugar_Content'] not in valid_sugar_content:
83
+ return False, f"Product_Sugar_Content must be one of: {valid_sugar_content}"
84
+
85
+ valid_store_sizes = ['Small', 'Medium', 'High']
86
+ if data['Store_Size'] not in valid_store_sizes:
87
+ return False, f"Store_Size must be one of: {valid_store_sizes}"
88
+
89
+ valid_city_types = ['Tier 1', 'Tier 2', 'Tier 3']
90
+ if data['Store_Location_City_Type'] not in valid_city_types:
91
+ return False, f"Store_Location_City_Type must be one of: {valid_city_types}"
92
+
93
+ valid_store_types = ['Departmental Store', 'Supermarket Type1', 'Supermarket Type2', 'Food Mart']
94
+ if data['Store_Type'] not in valid_store_types:
95
+ return False, f"Store_Type must be one of: {valid_store_types}"
96
+
97
+ return True, "Validation passed"
98
+
99
+ except Exception as e:
100
+ return False, f"Validation error: {str(e)}"
101
+
102
+ def preprocess_for_prediction(data):
103
+ """Preprocess input data for model prediction."""
104
+ try:
105
+ # Convert to DataFrame
106
+ if isinstance(data, dict):
107
+ df = pd.DataFrame([data])
108
+ else:
109
+ df = pd.DataFrame(data)
110
+
111
+ # Feature engineering functions (must match training)
112
+ def categorize_mrp(mrp):
113
+ if mrp <= 69.0:
114
+ return 'Low'
115
+ elif mrp <= 136.0:
116
+ return 'Medium_Low'
117
+ elif mrp <= 202.0:
118
+ return 'Medium_High'
119
+ else:
120
+ return 'High'
121
+
122
+ def categorize_weight(weight):
123
+ if weight <= 8.773:
124
+ return 'Light'
125
+ elif weight <= 12.89:
126
+ return 'Medium_Light'
127
+ elif weight <= 16.95:
128
+ return 'Medium_Heavy'
129
+ else:
130
+ return 'Heavy'
131
+
132
+ def categorize_store_age(age):
133
+ if age <= 20:
134
+ return 'New'
135
+ elif age <= 30:
136
+ return 'Established'
137
+ else:
138
+ return 'Legacy'
139
+
140
+ # Add engineered features
141
+ df['Product_MRP_Category'] = df['Product_MRP'].apply(categorize_mrp)
142
+ df['Product_Weight_Category'] = df['Product_Weight'].apply(categorize_weight)
143
+ df['Store_Age_Category'] = df['Store_Age'].apply(categorize_store_age)
144
+ df['City_Store_Type'] = df['Store_Location_City_Type'] + '_' + df['Store_Type']
145
+ df['Size_Type_Interaction'] = df['Store_Size'] + '_' + df['Store_Type']
146
+
147
+ # Transform using the preprocessing pipeline
148
+ processed_data = preprocessor.transform(df)
149
+
150
+ return processed_data, None
151
+
152
+ except Exception as e:
153
+ return None, str(e)
154
+
155
+ @app.route('/', methods=['GET'])
156
+ def home():
157
+ """Home endpoint with API information."""
158
+ api_info = {
159
+ "message": "SuperKart Sales Forecasting API",
160
+ "version": "1.0",
161
+ "model_info": {
162
+ "name": model_artifacts['model_name'] if model_artifacts else "Model not loaded",
163
+ "training_date": model_artifacts['training_date'] if model_artifacts else "Unknown",
164
+ "version": model_artifacts['model_version'] if model_artifacts else "Unknown"
165
+ } if model_artifacts else {"status": "Model not loaded"},
166
+ "endpoints": {
167
+ "/": "API information",
168
+ "/health": "Health check",
169
+ "/predict": "Single prediction (POST)",
170
+ "/batch_predict": "Batch predictions (POST)",
171
+ "/model_info": "Model details"
172
+ },
173
+ "sample_input": {
174
+ "Product_Weight": 10.5,
175
+ "Product_Sugar_Content": "Low Sugar",
176
+ "Product_Allocated_Area": 0.15,
177
+ "Product_Type": "Fruits and Vegetables",
178
+ "Product_MRP": 150.0,
179
+ "Store_Size": "Medium",
180
+ "Store_Location_City_Type": "Tier 2",
181
+ "Store_Type": "Supermarket Type2",
182
+ "Store_Age": 15
183
+ }
184
+ }
185
+ return jsonify(api_info)
186
+
187
+ @app.route('/health', methods=['GET'])
188
+ def health_check():
189
+ """Health check endpoint."""
190
+ health_status = {
191
+ "status": "healthy" if model is not None else "unhealthy",
192
+ "model_loaded": model is not None,
193
+ "timestamp": datetime.now().isoformat(),
194
+ "service": "SuperKart Sales Forecasting API"
195
+ }
196
+ return jsonify(health_status)
197
+
198
+ @app.route('/model_info', methods=['GET'])
199
+ def model_info():
200
+ """Get detailed model information."""
201
+ if model_artifacts is None:
202
+ return jsonify({"error": "Model not loaded"}), 500
203
+
204
+ info = {
205
+ "model_name": model_artifacts['model_name'],
206
+ "training_date": model_artifacts['training_date'],
207
+ "model_version": model_artifacts['model_version'],
208
+ "performance_metrics": model_artifacts['performance_metrics'],
209
+ "feature_count": len(model_artifacts['feature_names']),
210
+ "model_type": type(model).__name__
211
+ }
212
+
213
+ return jsonify(info)
214
+
215
+ @app.route('/predict', methods=['POST'])
216
+ def predict():
217
+ """Single prediction endpoint."""
218
+ try:
219
+ # Get JSON data from request
220
+ data = request.get_json()
221
+
222
+ if data is None:
223
+ return jsonify({"error": "No JSON data provided"}), 400
224
+
225
+ # Validate input data
226
+ is_valid, validation_message = validate_input_data(data)
227
+ if not is_valid:
228
+ return jsonify({"error": validation_message}), 400
229
+
230
+ # Preprocess data
231
+ processed_data, error = preprocess_for_prediction(data)
232
+ if error:
233
+ return jsonify({"error": f"Preprocessing failed: {error}"}), 400
234
+
235
+ # Make prediction
236
+ prediction = model.predict(processed_data)[0]
237
+
238
+ # Prepare response
239
+ response = {
240
+ "prediction": float(prediction),
241
+ "input_data": data,
242
+ "model_info": {
243
+ "model_name": model_artifacts['model_name'],
244
+ "prediction_timestamp": datetime.now().isoformat()
245
+ }
246
+ }
247
+
248
+ logger.info(f"Prediction made: {prediction:.2f}")
249
+ return jsonify(response)
250
+
251
+ except Exception as e:
252
+ logger.error(f"Prediction error: {str(e)}")
253
+ logger.error(f"Traceback: {traceback.format_exc()}")
254
+ return jsonify({"error": f"Prediction failed: {str(e)}"}), 500
255
+
256
+ @app.route('/batch_predict', methods=['POST'])
257
+ def batch_predict():
258
+ """Batch prediction endpoint."""
259
+ try:
260
+ # Get JSON data from request
261
+ data = request.get_json()
262
+
263
+ if data is None:
264
+ return jsonify({"error": "No JSON data provided"}), 400
265
+
266
+ # Ensure data is a list
267
+ if not isinstance(data, list):
268
+ return jsonify({"error": "Data must be a list of records"}), 400
269
+
270
+ if len(data) == 0:
271
+ return jsonify({"error": "Empty data list provided"}), 400
272
+
273
+ predictions = []
274
+ errors = []
275
+
276
+ for i, record in enumerate(data):
277
+ try:
278
+ # Validate input data
279
+ is_valid, validation_message = validate_input_data(record)
280
+ if not is_valid:
281
+ errors.append(f"Record {i}: {validation_message}")
282
+ predictions.append(None)
283
+ continue
284
+
285
+ # Preprocess data
286
+ processed_data, error = preprocess_for_prediction(record)
287
+ if error:
288
+ errors.append(f"Record {i}: Preprocessing failed - {error}")
289
+ predictions.append(None)
290
+ continue
291
+
292
+ # Make prediction
293
+ prediction = model.predict(processed_data)[0]
294
+ predictions.append(float(prediction))
295
+
296
+ except Exception as e:
297
+ errors.append(f"Record {i}: {str(e)}")
298
+ predictions.append(None)
299
+
300
+ # Prepare response
301
+ response = {
302
+ "predictions": predictions,
303
+ "total_records": len(data),
304
+ "successful_predictions": len([p for p in predictions if p is not None]),
305
+ "errors": errors if errors else None,
306
+ "model_info": {
307
+ "model_name": model_artifacts['model_name'],
308
+ "prediction_timestamp": datetime.now().isoformat()
309
+ }
310
+ }
311
+
312
+ logger.info(f"Batch prediction completed: {len(predictions)} records processed")
313
+ return jsonify(response)
314
+
315
+ except Exception as e:
316
+ logger.error(f"Batch prediction error: {str(e)}")
317
+ return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500
318
+
319
+ # Initialize the model when the app starts
320
+ @app.before_first_request
321
+ def initialize():
322
+ """Initialize the model on first request."""
323
+ logger.info("Initializing SuperKart Sales Forecasting API...")
324
+ success = load_model()
325
+ if success:
326
+ logger.info("API initialization completed successfully")
327
+ else:
328
+ logger.error("API initialization failed - model could not be loaded")
329
+
330
+ if __name__ == '__main__':
331
+ # Load model
332
+ if load_model():
333
+ print("[SUCCESS] Model loaded successfully")
334
+ print("[STARTING] SuperKart Sales Forecasting API...")
335
+ app.run(host='0.0.0.0', port=8080, debug=False)
336
+ else:
337
+ print("[ERROR] Failed to load model. Please check model file.")