| |
| import numpy as np |
| import joblib |
| import pandas as pd |
| import os |
| from flask import Flask, request, jsonify |
|
|
| |
# Flask application object; the route handlers in this module register on it.
sales_predictor_api = Flask("SuperKart Sales Predictor API")

# Location of the serialized prediction pipeline. Defaults to
# best_xgb_pipeline.joblib in the current working directory; set the
# MODEL_PATH environment variable to load it from elsewhere (e.g. when the
# server is not launched from the directory that contains the model file).
MODEL_PATH = os.environ.get(
    "MODEL_PATH", os.path.join(os.getcwd(), "best_xgb_pipeline.joblib")
)

# Load the model once at import time so every request reuses the same object.
# A load failure does not stop the server: `model` is left as None and the
# prediction routes check for that and answer with HTTP 500.
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code;
# only point MODEL_PATH at trusted artifacts.
try:
    model = joblib.load(MODEL_PATH)
    print(f"Model loaded successfully from: {MODEL_PATH}")
except Exception as e:  # startup boundary: any load error disables prediction
    print(f"FATAL ERROR: Could not load model: {e}")
    model = None
|
|
@sales_predictor_api.get('/')
def home():
    """Serve the API root.

    Returns:
        A plain-text greeting identifying the SuperKart prediction service.
    """
    greeting = "Welcome to the SuperKart Sales Prediction API!"
    return greeting
|
|
@sales_predictor_api.post('/v1/sales')
def predict_single_sale():
    """Predict total sales for one record sent as a JSON object.

    The request body must be a JSON object whose keys become the columns of a
    one-row DataFrame passed to ``model.predict`` (presumably the feature
    columns the pipeline was trained on -- the schema is not defined here).

    Returns:
        200 with ``{'Predicted Total Sales': <float rounded to 2 dp>}``;
        400 with ``{'error': ..., 'message': ...}`` when the body is not a JSON
        object or the model rejects the input;
        500 with ``{'error': ...}`` when the model failed to load at startup.
    """
    if model is None:
        return jsonify({'error': 'Internal server error: Model failed to load at startup.'}), 500

    # silent=True yields None (instead of raising) for a missing body, malformed
    # JSON or a non-JSON content type, so every bad request hits the check below.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Without this check `null`, lists or scalars would be wrapped into a
        # meaningless one-row frame and fail opaquely inside the model.
        detail = 'Request body must be a JSON object mapping feature names to values.'
        return jsonify({'error': detail, 'message': f'Prediction failed: {detail}'}), 400

    try:
        # Wrapping the dict in a list produces exactly one row; keys -> columns.
        input_data = pd.DataFrame([payload])
        predicted_sales = round(float(model.predict(input_data)[0]), 2)
        return jsonify({'Predicted Total Sales': predicted_sales})
    except Exception as e:
        # Request boundary: any prediction failure is reported to the client.
        # NOTE(review): genuine server-side faults are also reported as 400.
        return jsonify({'error': str(e), 'message': f'Prediction failed: {str(e)}'}), 400
|
|
@sales_predictor_api.post('/v1/salesbatch')
def predict_sales_batch():
    """Predict total sales for every row of an uploaded CSV file.

    Expects a multipart upload in the ``file`` form field. The parsed CSV is
    passed unchanged to ``model.predict`` and must contain an ``id`` column,
    which keys the returned predictions.

    Returns:
        200 with ``{<id>: <prediction rounded to 2 dp>, ...}``;
        400 with ``{'error': ..., 'message': ...}`` when no file is uploaded,
        the CSV cannot be parsed, the ``id`` column is missing, or the model
        rejects the data;
        500 with ``{'error': ...}`` when the model failed to load at startup.
    """
    if model is None:
        return jsonify({'error': 'Internal server error: Model failed to load at startup.'}), 500

    def fail(detail):
        # Uniform 400 payload shared by every client-error path below.
        return jsonify({'error': detail, 'message': f'Batch prediction failed: {detail}'}), 400

    # .get() avoids Werkzeug's opaque BadRequestKeyError text for a missing field.
    upload = request.files.get('file')
    if upload is None:
        return fail("No CSV file was uploaded in the 'file' form field.")

    try:
        input_data = pd.read_csv(upload)
        # Validate the key column before running the (potentially costly)
        # prediction, so a bad upload fails fast with a clear message.
        if 'id' not in input_data.columns:
            return fail("Uploaded CSV has no 'id' column to key the predictions by.")
        predictions = [round(float(sale), 2) for sale in model.predict(input_data)]
        # NOTE(review): duplicate ids collapse here -- later rows overwrite
        # earlier ones in the response dict.
        return jsonify(dict(zip(input_data['id'].tolist(), predictions)))
    except Exception as e:
        # Request boundary: parsing/prediction failures are reported to the client.
        return fail(str(e))
|
|
|
|
if __name__ == '__main__':
    # Built-in development server. It listens on all interfaces, so debug
    # mode (whose interactive debugger can execute code) is opt-in: debug is
    # not passed, and Flask's app.run() then honours FLASK_DEBUG=1 from the
    # environment. The port defaults to 7860 and can be overridden with PORT.
    port = int(os.environ.get('PORT', 7860))
    sales_predictor_api.run(host='0.0.0.0', port=port)
|
|