# Hugging Face Spaces page header ("Spaces: Sleeping") — extraction artifact, not code.
| from flask import Flask, request, jsonify | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from flask_cors import CORS | |
| import logging | |
| from datetime import datetime | |
| import os | |
| import traceback | |
# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables for model and preprocessor.
# All three stay None until load_model() succeeds; handlers below treat
# None as "model not loaded".
model = None
preprocessor = None
model_artifacts = None
def load_model():
    """Load the trained model and preprocessing artifacts from disk.

    Populates the module-level ``model``, ``preprocessor`` and
    ``model_artifacts`` globals from the joblib bundle. The globals are
    only assigned after every required key has been read, so a corrupt
    or incomplete bundle cannot leave them half-populated.

    Returns:
        bool: True if the model was loaded successfully, False otherwise.
    """
    global model, preprocessor, model_artifacts

    model_path = 'superkart_sales_forecasting_model.joblib'
    if not os.path.exists(model_path):
        logger.error("Model file not found: %s", model_path)
        return False

    try:
        artifacts = joblib.load(model_path)
        # Read all required keys before touching the globals, so a
        # KeyError here cannot leave the module in a partially-loaded state.
        loaded_model = artifacts['model']
        loaded_preprocessor = artifacts['preprocessor']

        model = loaded_model
        preprocessor = loaded_preprocessor
        model_artifacts = artifacts

        logger.info("Model loaded successfully: %s", artifacts['model_name'])
        logger.info("Training date: %s", artifacts['training_date'])
        return True
    except Exception:
        # logger.exception records the full traceback, unlike
        # logger.error(str(e)) which discarded it.
        logger.exception("Error loading model")
        # Keep the globals in a consistent "not loaded" state.
        model = preprocessor = model_artifacts = None
        return False
def validate_input_data(data):
    """Validate a single prediction record.

    Checks that all required fields are present, that numeric fields are
    real numbers (bools are rejected — ``bool`` is a subclass of ``int``,
    so the original ``isinstance`` checks accepted ``True`` as a weight)
    within their valid ranges, and that categorical fields hold one of
    their allowed values.

    Args:
        data (dict): candidate record with product and store fields.

    Returns:
        tuple: ``(True, "Validation passed")`` when valid, otherwise
        ``(False, <error message>)`` describing the first failure.
    """
    required_fields = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Size',
        'Store_Location_City_Type', 'Store_Type', 'Store_Age'
    ]

    # Check if all required fields are present
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"Missing required fields: {missing_fields}"

    try:
        def _is_number(value):
            # bool passes isinstance(..., int); True/False are not valid measurements.
            return isinstance(value, (int, float)) and not isinstance(value, bool)

        # (field, predicate, message) — evaluated in the original check order,
        # returning the first failing rule's message.
        numeric_rules = [
            ('Product_Weight', lambda v: _is_number(v) and v > 0,
             "Product_Weight must be a positive number"),
            ('Product_Allocated_Area', lambda v: _is_number(v) and 0 <= v <= 1,
             "Product_Allocated_Area must be between 0 and 1"),
            ('Product_MRP', lambda v: _is_number(v) and v > 0,
             "Product_MRP must be a positive number"),
            ('Store_Age', lambda v: _is_number(v) and v >= 0,
             "Store_Age must be a non-negative number"),
        ]
        for field, predicate, message in numeric_rules:
            if not predicate(data[field]):
                return False, message

        # Categorical fields must match one of the allowed values exactly.
        # (Product_Type is required above but has no closed value list here.)
        categorical_rules = [
            ('Product_Sugar_Content', ['Low Sugar', 'Regular', 'No Sugar']),
            ('Store_Size', ['Small', 'Medium', 'High']),
            ('Store_Location_City_Type', ['Tier 1', 'Tier 2', 'Tier 3']),
            ('Store_Type', ['Departmental Store', 'Supermarket Type1',
                            'Supermarket Type2', 'Food Mart']),
        ]
        for field, allowed in categorical_rules:
            if data[field] not in allowed:
                return False, f"{field} must be one of: {allowed}"

        return True, "Validation passed"
    except Exception as e:
        return False, f"Validation error: {str(e)}"
def preprocess_for_prediction(data):
    """Engineer features and run the fitted preprocessing pipeline.

    Args:
        data: a single record (dict) or a list/sequence of records.

    Returns:
        tuple: ``(processed_array, None)`` on success, or
        ``(None, error_message)`` on any failure.
    """
    try:
        # Fail with a clear message instead of the AttributeError that
        # None.transform(...) would raise when the model isn't loaded.
        if preprocessor is None:
            return None, "Preprocessing pipeline is not loaded"

        # Convert to DataFrame
        df = pd.DataFrame([data]) if isinstance(data, dict) else pd.DataFrame(data)

        # Binning thresholds below must match the ones used at training
        # time — do not change them independently of the model bundle.
        def categorize_mrp(mrp):
            if mrp <= 69.0:
                return 'Low'
            if mrp <= 136.0:
                return 'Medium_Low'
            if mrp <= 202.0:
                return 'Medium_High'
            return 'High'

        def categorize_weight(weight):
            if weight <= 8.773:
                return 'Light'
            if weight <= 12.89:
                return 'Medium_Light'
            if weight <= 16.95:
                return 'Medium_Heavy'
            return 'Heavy'

        def categorize_store_age(age):
            if age <= 20:
                return 'New'
            if age <= 30:
                return 'Established'
            return 'Legacy'

        # Add engineered categorical features
        df['Product_MRP_Category'] = df['Product_MRP'].apply(categorize_mrp)
        df['Product_Weight_Category'] = df['Product_Weight'].apply(categorize_weight)
        df['Store_Age_Category'] = df['Store_Age'].apply(categorize_store_age)
        # Interaction features combining store attributes
        df['City_Store_Type'] = df['Store_Location_City_Type'] + '_' + df['Store_Type']
        df['Size_Type_Interaction'] = df['Store_Size'] + '_' + df['Store_Type']

        # Transform using the fitted preprocessing pipeline
        return preprocessor.transform(df), None
    except Exception as e:
        return None, str(e)
def home():
    """Describe the API: model status, endpoints, and a sample payload."""
    # NOTE(review): no @app.route decorator is visible for this handler —
    # it may have been lost in extraction; confirm routing before deploy.
    if model_artifacts:
        model_section = {
            "name": model_artifacts['model_name'],
            "training_date": model_artifacts['training_date'],
            "version": model_artifacts['model_version']
        }
    else:
        model_section = {"status": "Model not loaded"}

    endpoint_map = {
        "/": "API information",
        "/health": "Health check",
        "/predict": "Single prediction (POST)",
        "/batch_predict": "Batch predictions (POST)",
        "/model_info": "Model details"
    }

    example_payload = {
        "Product_Weight": 10.5,
        "Product_Sugar_Content": "Low Sugar",
        "Product_Allocated_Area": 0.15,
        "Product_Type": "Fruits and Vegetables",
        "Product_MRP": 150.0,
        "Store_Size": "Medium",
        "Store_Location_City_Type": "Tier 2",
        "Store_Type": "Supermarket Type2",
        "Store_Age": 15
    }

    return jsonify({
        "message": "SuperKart Sales Forecasting API",
        "version": "1.0",
        "model_info": model_section,
        "endpoints": endpoint_map,
        "sample_input": example_payload
    })
def health_check():
    """Report service liveness and whether the model is loaded."""
    loaded = model is not None
    payload = {
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "timestamp": datetime.now().isoformat(),
        "service": "SuperKart Sales Forecasting API"
    }
    return jsonify(payload)
def model_info():
    """Return metadata for the loaded model, or a 500 if none is loaded."""
    if model_artifacts is None:
        return jsonify({"error": "Model not loaded"}), 500

    details = {
        "model_name": model_artifacts['model_name'],
        "training_date": model_artifacts['training_date'],
        "model_version": model_artifacts['model_version'],
        "performance_metrics": model_artifacts['performance_metrics'],
        "feature_count": len(model_artifacts['feature_names']),
        "model_type": type(model).__name__
    }
    return jsonify(details)
def predict():
    """Handle a single-record prediction request (POST, JSON body).

    Validates and preprocesses the record, then returns the forecast
    alongside the echoed input and model metadata. Returns 400 on bad
    input and 500 on unexpected failures.
    """
    try:
        payload = request.get_json()
        if payload is None:
            return jsonify({"error": "No JSON data provided"}), 400

        ok, message = validate_input_data(payload)
        if not ok:
            return jsonify({"error": message}), 400

        features, preprocess_error = preprocess_for_prediction(payload)
        if preprocess_error:
            return jsonify({"error": f"Preprocessing failed: {preprocess_error}"}), 400

        forecast = model.predict(features)[0]
        logger.info(f"Prediction made: {forecast:.2f}")

        return jsonify({
            "prediction": float(forecast),
            "input_data": payload,
            "model_info": {
                "model_name": model_artifacts['model_name'],
                "prediction_timestamp": datetime.now().isoformat()
            }
        })
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return jsonify({"error": f"Prediction failed: {str(e)}"}), 500
def batch_predict():
    """Handle a batch prediction request: a JSON list of records (POST).

    Each record is validated and predicted independently; failures are
    collected per-record (with a None placeholder in ``predictions``)
    rather than aborting the whole batch.
    """
    try:
        records = request.get_json()
        if records is None:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(records, list):
            return jsonify({"error": "Data must be a list of records"}), 400
        if not records:
            return jsonify({"error": "Empty data list provided"}), 400

        predictions = []
        errors = []
        for index, record in enumerate(records):
            try:
                ok, message = validate_input_data(record)
                if not ok:
                    errors.append(f"Record {index}: {message}")
                    predictions.append(None)
                    continue

                features, preprocess_error = preprocess_for_prediction(record)
                if preprocess_error:
                    errors.append(f"Record {index}: Preprocessing failed - {preprocess_error}")
                    predictions.append(None)
                    continue

                predictions.append(float(model.predict(features)[0]))
            except Exception as e:
                errors.append(f"Record {index}: {str(e)}")
                predictions.append(None)

        successful = sum(1 for value in predictions if value is not None)
        logger.info(f"Batch prediction completed: {len(predictions)} records processed")

        return jsonify({
            "predictions": predictions,
            "total_records": len(records),
            "successful_predictions": successful,
            "errors": errors or None,
            "model_info": {
                "model_name": model_artifacts['model_name'],
                "prediction_timestamp": datetime.now().isoformat()
            }
        })
    except Exception as e:
        logger.error(f"Batch prediction error: {str(e)}")
        return jsonify({"error": f"Batch prediction failed: {str(e)}"}), 500
# Initialize the model when the app starts (Flask 3.x compatible)
def initialize():
    """Initialize the model on app startup."""
    # Delegates to load_model(), which fills the module-level globals,
    # and logs whether it succeeded.
    logger.info("Initializing SuperKart Sales Forecasting API...")
    success = load_model()
    if success:
        logger.info("API initialization completed successfully")
    else:
        logger.error("API initialization failed - model could not be loaded")

# Call initialization immediately when module loads
# (module-import side effect; presumably replacing the
# @app.before_first_request hook removed in Flask 3.x — confirm).
with app.app_context():
    initialize()
if __name__ == '__main__':
    # Running as a script: load the model, then start the built-in server.
    # Load model
    if load_model():
        print("[SUCCESS] Model loaded successfully")
        print("[STARTING] SuperKart Sales Forecasting API...")
        # 0.0.0.0 binds all interfaces (container/Spaces friendly); debug off.
        app.run(host='0.0.0.0', port=8080, debug=False)
    else:
        print("[ERROR] Failed to load model. Please check model file.")