Spaces:
Sleeping
Sleeping
File size: 3,892 Bytes
018c9c2 1c8e3bc 66d1041 1c8e3bc 66d1041 1c8e3bc 3700905 1c8e3bc 3700905 018c9c2 3700905 1c8e3bc 018c9c2 1c8e3bc 018c9c2 910402e 1c8e3bc 018c9c2 1c8e3bc 3700905 1c8e3bc 3700905 1c8e3bc 6fc00e3 79f9bb7 eaf4232 3700905 018c9c2 3700905 018c9c2 3700905 018c9c2 3700905 66d1041 3700905 66d1041 3700905 66d1041 3700905 66d1041 018c9c2 66d1041 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
# Import necessary libraries
import numpy as np
import joblib
import pandas as pd
from flask import Flask, request, jsonify
import traceback
import math
# Filename of the serialized model artifact. It is opened as a relative path,
# i.e. resolved against the process's current working directory.
model_file_name = "SuperKart_v1_0.joblib"

# Load the trained model once, at import time. On any failure `model` stays
# None so the endpoints can answer "Model not loaded" instead of the whole
# module failing to import.
model = None
try:
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code —
    # only ever point this at a trusted, locally built artifact.
    model = joblib.load(model_file_name)
except FileNotFoundError:
    print(f"Error: Model file not found at {model_file_name}")
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    traceback.print_exc()

# WSGI application object that the route decorators below register against.
app = Flask(__name__)
@app.route('/')
def home():
    """Serve the API root with a plain-text welcome message."""
    welcome_text = "Welcome to the Super Kart Product Sales Price Prediction API!"
    return welcome_text
# ---------------- Single Prediction Endpoint ----------------
@app.route('/v1/salesprice', methods=['POST'])
def predict_sales_price():
    """Predict the sales price for a single product/store record.

    Expects a JSON object in the request body that contains every feature
    named in ``expected_keys``; any additional keys are ignored.

    Returns:
        200 with ``{"Predicted Price": <float rounded to 2 decimals>}``.
        400 if the body is not a JSON object, required keys are missing, or
            the model produced a NaN/infinite value.
        500 if the model is not loaded or prediction raised unexpectedly.
    """
    if model is None:
        return jsonify({"error": "Model not loaded. Cannot make predictions."}), 500

    # force=True keeps accepting bodies sent without a JSON Content-Type;
    # silent=True makes a malformed body yield None instead of raising, so the
    # client gets a 400 below rather than the generic 500 handler.
    product_data = request.get_json(force=True, silent=True)
    if not isinstance(product_data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    expected_keys = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Size',
        'Store_Location_City_Type', 'Store_Type', 'Store_Age'
    ]
    missing_keys = [key for key in expected_keys if key not in product_data]
    if missing_keys:
        return jsonify({"error": f"Missing keys in input data: {missing_keys}"}), 400

    try:
        # One-row DataFrame restricted to the expected feature columns.
        sample = {key: product_data[key] for key in expected_keys}
        input_data = pd.DataFrame([sample])
        raw_prediction = float(model.predict(input_data)[0])
        if not math.isfinite(raw_prediction):
            return jsonify({"error": "Prediction resulted in an invalid value."}), 400
        predicted_price = round(raw_prediction, 2)
        return jsonify({'Predicted Price': predicted_price}), 200
    except Exception as e:
        print(f"Error during single prediction: {e}")
        traceback.print_exc()
        return jsonify({"error": "Internal server error", "details": str(e)}), 500
# ---------------- Batch Prediction Endpoint ----------------
def _finite_price_or_none(value):
    """Round a model output to 2 decimals, or return None if it is NaN/infinite."""
    as_float = float(value)
    # NaN/Infinity are not valid JSON, so non-finite outputs become null.
    return round(as_float, 2) if math.isfinite(as_float) else None


@app.route('/v1/salespricebatch', methods=['POST'])
def predict_sales_price_batch():
    """Predict sales prices for every row of an uploaded CSV file.

    Expects a multipart upload with the CSV under the form field ``file``,
    one product per row. The CSV must contain every column named in
    ``expected_columns``; extra columns are ignored.

    Returns:
        200 with a JSON list of dicts, one per CSV data row, in file order:
            ``{"row_id": <0-based row position>, "Predicted Price": <float|null>}``.
            The price is rounded to 2 decimals, or null when the model output
            for that row is NaN or infinite. A header-only CSV yields ``[]``.
        400 if no file was uploaded, the CSV cannot be parsed, or required
            columns are missing.
        500 if the model is not loaded or prediction raised unexpectedly.
    """
    if model is None:
        return jsonify({"error": "Model not loaded. Cannot make predictions."}), 500
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    try:
        input_data = pd.read_csv(request.files['file'])
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        # An empty, malformed, or non-text upload is a client error, not a 500.
        return jsonify({"error": "Could not parse uploaded CSV file.", "details": str(e)}), 400

    expected_columns = [
        'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
        'Product_Type', 'Product_MRP', 'Store_Size',
        'Store_Location_City_Type', 'Store_Type', 'Store_Age'
    ]
    missing_columns = [col for col in expected_columns if col not in input_data.columns]
    if missing_columns:
        return jsonify({"error": f"Missing required columns: {missing_columns}"}), 400

    # Header-only CSV: nothing to score, so avoid calling predict on zero rows.
    if input_data.empty:
        return jsonify([]), 200

    try:
        predictions = model.predict(input_data[expected_columns])
        # row_id is the row's position in the uploaded file. It is taken from
        # enumerate rather than a column, so a user-supplied 'index' or
        # 'row_id' column cannot collide with it.
        results = [
            {"row_id": row_id, "Predicted Price": _finite_price_or_none(prediction)}
            for row_id, prediction in enumerate(predictions)
        ]
        return jsonify(results), 200
    except Exception as e:
        print(f"Error during batch prediction: {e}")
        traceback.print_exc()
        return jsonify({"error": "Internal server error during batch prediction.", "details": str(e)}), 500
if __name__ == '__main__':
    # Runs only when this file is executed directly (e.g. `python app.py`);
    # when a WSGI server imports the module, this branch is skipped.
    import os

    # Bind to all interfaces so the server is reachable from outside a
    # container. The 7860 default is an assumption about the hosting platform
    # (Hugging Face Spaces' usual port) — confirm; override via the PORT env var.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|