# suka / app.py
# Uploaded by grkavi0912 via huggingface_hub (commit 3c58142, verified)
# -------------------------------------------------------
# Flask Web Framework for Product Store Sales Prediction
# -------------------------------------------------------
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import joblib
from flask import Flask, request, jsonify

# Initialize the Flask application
product_sales_api = Flask("SuperKart Product Sales Predictor")

# Define the path to the model file - it will be at the root of the Space
model_path_in_space = "random_forest_pipeline.joblib"

# Load the trained RandomForest model pipeline once at import time.
# If loading fails the app still starts, but `model` is left as None and
# every endpoint checks for that and answers with a 500.
try:
    model = joblib.load(model_path_in_space)
    print(f"Model loaded successfully from {model_path_in_space}")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None  # Set model to None to indicate loading failure
# -------------------------------------------------------
# Define a route for the home page (GET request)
# -------------------------------------------------------
@product_sales_api.route('/')
def home():
    """Serve the API welcome banner for GET requests to '/'.

    Responds with a 500 and a plain-text error when the model failed to
    load at startup; otherwise returns a simple greeting string.
    """
    # Guard clause: the model is loaded at import time and may be None.
    if model is None:
        return "Error: Model could not be loaded. Please check the logs.", 500
    return "Welcome to the SuperKart Product Store Sales Prediction API!"
# -------------------------------------------------------
# Define an endpoint for single product prediction (POST request)
# -------------------------------------------------------
@product_sales_api.route('/v1/sales', methods=['POST'])
def predict_sales():
    """
    Handle POST requests to the '/v1/sales' endpoint.

    Expects a JSON payload whose keys match the ORIGINAL (pre-encoding)
    training feature names — the loaded pipeline performs its own
    preprocessing — and returns the predicted Product_Store_Sales_Total.

    Returns:
        200 with {'Predicted_Product_Store_Sales_Total': float} on success,
        400 when the request body is missing or not valid JSON,
        500 when the model is unavailable or prediction fails.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    # silent=True makes get_json() return None for a missing/malformed body
    # instead of raising, so a client mistake yields a clean JSON 400 rather
    # than an AttributeError-driven 500.
    product_data = request.get_json(silent=True)
    if product_data is None:
        return jsonify({'error': 'Request body must be valid JSON'}), 400

    # The incoming JSON keys must match these original training columns.
    # Missing keys are passed through as None — presumably the pipeline's
    # preprocessor/imputer handles them; verify against the training code.
    original_feature_cols = [
        'Product_Weight', 'Product_Allocated_Area', 'Product_MRP',
        'Store_Establishment_Year', 'Product_Sugar_Content', 'Product_Type',
        'Store_Id', 'Store_Size', 'Store_Location_City_Type', 'Store_Type'
    ]
    categorical_cols = [
        'Product_Sugar_Content', 'Product_Type', 'Store_Id',
        'Store_Size', 'Store_Location_City_Type', 'Store_Type'
    ]

    try:
        # Build a one-row DataFrame; .get() tolerates absent keys.
        input_df = pd.DataFrame(
            [{col: product_data.get(col) for col in original_feature_cols}]
        )
        # Cast categoricals so the pipeline's encoder sees the expected dtype.
        for col in categorical_cols:
            input_df[col] = input_df[col].astype('category')

        # The pipeline handles preprocessing internally.
        prediction = model.predict(input_df)[0]
        return jsonify({'Predicted_Product_Store_Sales_Total': float(prediction)})
    except Exception as e:
        # Log the error for debugging; report a generic server failure.
        print(f"Error during single prediction: {e}")
        return jsonify({'error': str(e)}), 500
# -------------------------------------------------------
# Define an endpoint for batch predictions (CSV upload)
# -------------------------------------------------------
@product_sales_api.route('/v1/salesbatch', methods=['POST'])
def predict_sales_batch():
    """
    Handle POST requests to the '/v1/salesbatch' endpoint.

    Expects a multipart upload under the 'file' field containing a CSV whose
    columns match the original (pre-encoding) training features, and returns
    the input records with an added 'Predicted_Product_Store_Sales_Total'
    column as a JSON array.

    Returns:
        200 with a JSON array of records on success,
        400 when no file part / no selected file is provided,
        500 when the model is unavailable or prediction fails.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    # Guard clauses for the upload; an empty filename means the client
    # submitted the form without choosing a file.
    if 'file' not in request.files:
        return jsonify({'error': 'No file part in the request'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    try:
        # CSV columns are assumed to match the original training features;
        # the pipeline performs its own preprocessing.
        data = pd.read_csv(file)
        categorical_cols = ['Product_Sugar_Content', 'Product_Type', 'Store_Id',
                            'Store_Size', 'Store_Location_City_Type', 'Store_Type']
        for col in categorical_cols:
            if col in data.columns:
                data[col] = data[col].astype('category')

        # Make batch predictions and attach them to the frame.
        data['Predicted_Product_Store_Sales_Total'] = model.predict(data)

        # Returning the raw to_json() string would be served as text/html;
        # set the Content-Type explicitly so clients parse it as JSON.
        return data.to_json(orient='records'), 200, {'Content-Type': 'application/json'}
    except Exception as e:
        # Log the error for debugging; report a generic server failure.
        print(f"Error during batch prediction: {e}")
        return jsonify({'error': str(e)}), 500
# -------------------------------------------------------
# Run the Flask API (typically not run in deployment, Gunicorn handles this)
# -------------------------------------------------------
# This part is mainly for local testing. In a Docker deployment with Gunicorn,
# Gunicorn will call the 'product_sales_api' application directly.
# if __name__ == '__main__':
# product_sales_api.run(host='0.0.0.0', port=5000, debug=True)