Spaces:
Sleeping
Sleeping
# Import necessary libraries
import os  # For reading runtime configuration from the environment

import joblib  # For loading the serialized model
import numpy as np
import pandas as pd  # For data manipulation
from flask import Flask, request, jsonify  # For creating the Flask API
# Initialize the Flask application.
# NOTE(review): Flask documents its first positional argument as the
# application's import name (conventionally ``__name__``); a display string is
# passed here instead -- confirm that is intentional.
superkart_sales_api = Flask("SuperKart Sales Predictor")
# Load the trained machine learning model once, at import time, so that every
# request reuses the same in-memory object. The path is relative, so it is
# resolved against the process's current working directory.
# NOTE(review): joblib files are pickle-based -- only load artifacts from a
# trusted source.
model = joblib.load("superkart_sales_prediction_model_v1_0.joblib")
# Home page view (GET request).
# NOTE(review): no route decorator is visible on this function in the reviewed
# source -- confirm it is actually registered on the Flask app for '/'.
def home():
    """
    Return the welcome message for the API's home page.

    Intended to answer GET requests to the root URL ('/') with a short
    plain-text greeting.
    """
    welcome_message = "Welcome to the SuperKart Sales Prediction API!"
    return welcome_message
# -------------------------------
# Endpoint for Single Prediction
# -------------------------------
# NOTE(review): no route decorator is visible on this function in the reviewed
# source -- confirm it is registered on the Flask app (path/method unknown here).
def predict_sales():
    """
    Predict sales for a single product-outlet combination.

    Reads a JSON object from the request body, picks out the feature fields
    listed below, wraps them in a one-row DataFrame and passes that to the
    module-level ``model``.

    Returns:
        200 -- ``{"Predicted Sales": <float rounded to 2 decimals>}``
        400 -- ``{"error": ...}`` when the body is not a JSON object or when
               one or more required fields are missing (client error).
        500 -- ``{"error": ...}`` when building the frame or running the
               model fails.
    """
    # Fields copied from the request into the model input, in this order.
    feature_names = (
        'Product_Weight',
        'Product_Allocated_Area',
        'Product_MRP',
        'Store_Establishment_Year',
        'Product_Sugar_Content',
        'Store_Size',
        'Store_Location_City_Type',
        'Store_Type',
        'Product_Type',
    )

    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so malformed requests can be answered with a clear 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    # Report every absent field at once rather than failing on the first one.
    missing = [name for name in feature_names if name not in data]
    if missing:
        message = "Missing required field(s): " + ", ".join(missing)
        return jsonify({"error": message}), 400

    try:
        # Single-row frame containing only the expected feature columns.
        input_df = pd.DataFrame([{name: data[name] for name in feature_names}])
        # Convert to a plain float so the value is JSON-serializable, then round.
        prediction = round(float(model.predict(input_df)[0]), 2)
    except Exception as e:
        # Top-level boundary: surface unexpected failures as a JSON 500.
        return jsonify({"error": str(e)}), 500

    return jsonify({"Predicted Sales": prediction})
# -------------------------------
# Endpoint for Batch Prediction
# -------------------------------
# NOTE(review): no route decorator is visible on this function in the reviewed
# source -- confirm it is registered on the Flask app (path/method unknown here).
def predict_sales_batch():
    """
    Predict sales for multiple rows from an uploaded CSV file.

    Expects the CSV under the multipart form field ``file``. The parsed frame
    is passed to the module-level ``model`` unchanged.
    NOTE(review): presumably the CSV columns must match the features the model
    was trained on -- no column validation happens here; confirm.

    Returns:
        200 -- ``{"0": <float>, "1": <float>, ...}`` mapping each row's
               positional index (as a string) to its prediction rounded to
               2 decimals.
        400 -- ``{"error": ...}`` when no file was uploaded or the upload
               cannot be parsed as CSV (client error).
        500 -- ``{"error": ...}`` on any other failure.
    """
    # .get() avoids the key-lookup exception for a missing part, so the
    # client receives an explicit 400 instead of a generic 500.
    upload = request.files.get('file')
    if upload is None:
        return jsonify({"error": "No file uploaded under form field 'file'."}), 400

    try:
        input_df = pd.read_csv(upload)
        predictions = model.predict(input_df).tolist()
        # Plain floats keep the payload JSON-serializable; keys are row positions.
        output_dict = {
            str(row): round(float(value), 2)
            for row, value in enumerate(predictions)
        }
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # The upload itself is unusable -- that is the caller's error.
        return jsonify({"error": "Could not parse uploaded CSV: " + str(e)}), 400
    except Exception as e:
        # Top-level boundary: surface unexpected failures as a JSON 500.
        return jsonify({"error": str(e)}), 500

    return jsonify(output_dict)
# -------------------------------
# Run App
# -------------------------------
if __name__ == '__main__':
    # The interactive debugger lets anyone who can reach the server execute
    # arbitrary code, and this server listens on all interfaces (0.0.0.0).
    # Debug mode is therefore off unless explicitly enabled, e.g. FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    superkart_sales_api.run(debug=debug_enabled, host="0.0.0.0", port=7860)