# Author: Himadri1102
# Commit dd28508 (verified): "Upload folder using huggingface_hub"
# Import necessary libraries
import numpy as np
import joblib
import pandas as pd
import os
from flask import Flask, request, jsonify
# Initialize the Flask application.
sales_predictor_api = Flask("SuperKart Sales Predictor API")

# Load the trained machine-learning pipeline once, at import time, so every
# request reuses the same in-memory model.
# Path resolution:
#   1. SUPERKART_MODEL_PATH environment variable, if set and non-empty;
#   2. otherwise <current working directory>/best_xgb_pipeline.joblib;
#   3. if that file does not exist, the same file name next to this script,
#      so the app still works when launched from a different directory.
_MODEL_FILENAME = "best_xgb_pipeline.joblib"
_env_model_path = os.environ.get("SUPERKART_MODEL_PATH")
if _env_model_path:
    MODEL_PATH = _env_model_path
else:
    MODEL_PATH = os.path.join(os.getcwd(), _MODEL_FILENAME)
    _beside_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_FILENAME)
    if not os.path.isfile(MODEL_PATH) and os.path.isfile(_beside_script):
        MODEL_PATH = _beside_script

# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only point MODEL_PATH at a model artifact you trust.
try:
    model = joblib.load(MODEL_PATH)
    # flush=True so the message shows up promptly in container logs.
    print(f"Model loaded successfully from: {MODEL_PATH}", flush=True)
except Exception as e:
    # Deliberate best-effort: keep the server running with model = None; the
    # prediction routes check for None and answer with an explicit 500.
    print(f"FATAL ERROR: Could not load model: {e}", flush=True)
    model = None
@sales_predictor_api.get('/')
def home():
    """Serve a fixed plain-text welcome message at the root URL (GET /)."""
    welcome_text = "Welcome to the SuperKart Sales Prediction API!"
    return welcome_text
@sales_predictor_api.post('/v1/sales')
def predict_single_sale():
    """Return a sales forecast for a single record sent as a JSON object.

    The JSON object's keys become the column names of a one-row DataFrame
    that is passed to ``model.predict``. Which feature names the pipeline
    expects is defined by the trained model, not by this route
    (TODO confirm against the training notebook).

    Returns:
        200 with ``{'Predicted Total Sales': <float rounded to 2 dp>}``.
        400 with ``{'error': ..., 'message': ...}`` when the body is not a
        JSON object or the model rejects the input.
        500 when the model failed to load at startup.
    """
    if model is None:
        return jsonify({'error': 'Internal server error: Model failed to load at startup.'}), 500

    # silent=True returns None (instead of raising) for a missing body, a
    # non-JSON content type or malformed JSON, so every bad request gets the
    # same error payload shape below.
    property_data = request.get_json(silent=True)
    if not isinstance(property_data, dict):
        msg = 'Request body must be a JSON object mapping feature names to values.'
        return jsonify({'error': msg, 'message': f'Prediction failed: {msg}'}), 400

    try:
        # Wrapping the dict in a list yields a single-row DataFrame.
        input_data = pd.DataFrame([property_data])
        # The model predicts the total sales value directly.
        predicted_sales = model.predict(input_data)[0]
        predicted_sales = round(float(predicted_sales), 2)
    except Exception as e:
        # Request boundary: report the failure to the client, but keep the
        # traceback in the server log instead of discarding it.
        sales_predictor_api.logger.exception('Single sales prediction failed')
        return jsonify({'error': str(e), 'message': f'Prediction failed: {str(e)}'}), 400

    return jsonify({'Predicted Total Sales': predicted_sales})
@sales_predictor_api.post('/v1/salesbatch')
def predict_sales_batch():
    """Return sales forecasts for every row of an uploaded CSV file.

    Expects a multipart/form-data upload in the form field ``file``. The
    whole parsed CSV is passed to ``model.predict``. Which columns the
    pipeline needs is defined by the trained model (TODO confirm).

    Returns:
        200 with a JSON object mapping each row's key to its prediction,
        rounded to 2 decimal places. The key is the row's ``id`` value when
        the CSV has an ``id`` column; otherwise it is the 0-based row
        position.
        400 with ``{'error': ..., 'message': ...}`` when no file is sent or
        reading/predicting fails.
        500 when the model failed to load at startup.
    """
    if model is None:
        return jsonify({'error': 'Internal server error: Model failed to load at startup.'}), 500

    # .get() rather than ['file']: a missing field would otherwise surface as
    # Werkzeug's generic "400 Bad Request" text that doesn't say what's wrong.
    uploaded_file = request.files.get('file')
    if uploaded_file is None:
        msg = "No file uploaded: send the CSV in the multipart form field 'file'."
        return jsonify({'error': msg, 'message': f'Batch prediction failed: {msg}'}), 400

    try:
        input_data = pd.read_csv(uploaded_file)
        predicted_sales = model.predict(input_data).tolist()
        final_predictions = [round(float(sale), 2) for sale in predicted_sales]

        if 'id' in input_data.columns:
            id_column = input_data['id']
            if id_column.duplicated().any():
                # dict() keeps only the last prediction per key, so earlier
                # rows sharing an id are absent from the response.
                sales_predictor_api.logger.warning(
                    'Batch CSV has duplicate ids; only the last prediction per id is returned'
                )
            row_keys = id_column.tolist()
        else:
            # No id column: key predictions by their 0-based row position.
            row_keys = list(range(len(input_data)))

        output_dict = dict(zip(row_keys, final_predictions))
    except Exception as e:
        # Request boundary: report the failure to the client, but keep the
        # traceback in the server log instead of discarding it.
        sales_predictor_api.logger.exception('Batch sales prediction failed')
        return jsonify({'error': str(e), 'message': f'Batch prediction failed: {str(e)}'}), 400

    return jsonify(output_dict)
if __name__ == '__main__':
    # Never enable Flask's interactive debugger by default while listening on
    # all interfaces: anyone who can reach the port could execute arbitrary
    # Python code through it. Opt in explicitly with FLASK_DEBUG=1 for local
    # development only.
    _debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    # 7860 stays the default port; PORT allows overriding it.
    _port = int(os.environ.get('PORT', '7860'))
    sales_predictor_api.run(debug=_debug, host='0.0.0.0', port=_port)