# Uploaded by harasar via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 324825a (verified).
from pathlib import Path

import joblib
import pandas as pd
from flask import Flask, request, jsonify
# Initialize the Flask app with a name
app = Flask("Store Sales Predictor")

# Load the trained store-sales prediction model. The path is resolved relative
# to this file (not the current working directory) so the app starts correctly
# no matter which directory the server process is launched from.
MODEL_PATH = Path(__file__).resolve().parent / "superkart_prediction_model_v1_0.joblib"
model = joblib.load(MODEL_PATH)
# Root route: answers GET / with a plain-text welcome banner.
@app.get('/')
def home():
    """Return the plain-text welcome message for the API root."""
    banner = "Welcome to Store Sales Prediction API"
    return banner
# Endpoint: predict sales for a single product/store record posted as JSON.
# (The route and function names say "customer"/"churn" for historical reasons;
# they are kept unchanged so existing clients keep working.)
@app.post('/v1/customer')
def predict_churn():
    """Predict sales for one product/store record.

    The request body must be a non-empty JSON object whose keys are the
    model's input features, for example::

        {
            "Product_Weight": ..., "Product_Allocated_Area": ...,
            "Product_MRP": ..., "Product_Type": ...,
            "Product_Sugar_Content": ..., "Store_Size": ...,
            "Store_Type": ..., "Store_Establishment_Year": ...,
            "Store_Location_City_Type": ...
        }

    Identifier/target columns (``Product_Id``, ``Store_Id``,
    ``Product_Store_Sales_Total``) are dropped if present, matching what the
    batch endpoint passes to the model.

    Returns:
        ``{"predicted_sales": <float>}`` with status 200 on success;
        ``{"error": ...}`` with status 400 if the body is not a non-empty
        JSON object; ``{"error": ...}`` with status 500 if the model raises
        while predicting.
    """
    # silent=True yields None (instead of Flask raising an HTML error
    # response) when the body is missing or is not valid JSON.
    customer_data = request.get_json(silent=True)
    if not isinstance(customer_data, dict) or not customer_data:
        return jsonify({'error': 'Request body must be a non-empty JSON object'}), 400

    # Build a one-row DataFrame and strip non-feature columns so the model
    # sees the same kind of input as in the batch endpoint.
    input_data = pd.DataFrame([customer_data]).drop(
        columns=["Product_Id", "Store_Id", "Product_Store_Sales_Total"],
        errors='ignore',
    )

    try:
        prediction = model.predict(input_data)
    except Exception as e:  # API boundary: report model/feature errors as JSON
        app.logger.exception("Single prediction failed")
        return jsonify({'error': str(e)}), 500

    return jsonify({'predicted_sales': float(prediction[0])})
# Helper for the batch endpoint: one response key per CSV row.
def _batch_row_keys(frame):
    """Return one string key per row of *frame*, in row order.

    The key joins the row's ``Product_Id`` and ``Store_Id`` values with
    ``"_"`` when both columns exist, uses whichever one exists otherwise,
    and falls back to the 0-based row position when neither is present.
    """
    id_columns = [col for col in ("Product_Id", "Store_Id") if col in frame.columns]
    if not id_columns:
        return [str(position) for position in range(len(frame))]
    return [
        "_".join(str(value) for value in row)
        for row in frame[id_columns].itertuples(index=False, name=None)
    ]


# Endpoint: predict sales for every row of an uploaded CSV file.
# (The route and function names say "customer"/"churn" for historical reasons;
# they are kept unchanged so existing clients keep working.)
@app.post('/v1/customerbatch')
def predict_churn_batch():
    """Predict sales for each row of a CSV uploaded as form field ``file``.

    ``Product_Id``, ``Store_Id`` and ``Product_Store_Sales_Total`` columns
    are removed before prediction (if present); the ID columns are used to
    label each prediction in the response.

    Returns:
        200 with ``{row_key: predicted_sales}`` (predictions rounded to 2
        decimals; ``{}`` for a CSV that has a header but no data rows);
        400 with ``{"error": ...}`` if no file is uploaded, the CSV cannot
        be parsed, or two rows share the same key; 500 with
        ``{"error": ...}`` if prediction fails.
    """
    # .get() returns None instead of raising KeyError (previously reported
    # as a 500) when the form has no "file" field.
    file = request.files.get('file')
    if not file:
        return jsonify({'error': 'No file uploaded'}), 400

    try:
        input_data = pd.read_csv(file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        return jsonify({'error': f'Could not parse CSV file: {e}'}), 400

    # A header-only CSV has nothing to predict; the model would reject 0 rows.
    if input_data.empty:
        return jsonify({})

    # Keying by Store_Id alone silently overwrote every prediction for a store
    # except the last one, because a store appears on many rows. Use a per-row
    # key and refuse (rather than silently drop) rows that would collide.
    keys = _batch_row_keys(input_data)
    if len(set(keys)) != len(keys):
        return jsonify({
            'error': 'Rows must have unique Product_Id/Store_Id identifiers; '
                     'duplicates would overwrite each other in the response'
        }), 400

    try:
        predictions = model.predict(
            input_data.drop(
                columns=["Product_Id", "Store_Id", "Product_Store_Sales_Total"],
                errors='ignore',
            )
        )
        output = {key: float(pred) for key, pred in zip(keys, predictions.round(2))}
    except Exception as e:  # API boundary: report model/feature errors as JSON
        app.logger.exception("Batch prediction failed")
        return jsonify({'error': str(e)}), 500

    return jsonify(output)
# Run Flask's development server when this file is executed directly.
# Debug mode is no longer hard-coded on: the Werkzeug interactive debugger it
# enables can execute arbitrary code. app.run() without a debug argument
# leaves it off unless enabled via the FLASK_DEBUG environment variable
# (e.g. FLASK_DEBUG=1 for local development).
if __name__ == '__main__':
    app.run()