# NOTE: The original upload carried Hugging Face page residue here
# ("MLbySush's picture / Upload folder using huggingface_hub / b8c339f verified");
# it was not Python code and has been converted to this comment.
import os
import joblib
import pandas as pd
# must import Flask, request, jsonify before using them
from flask import Flask, request, jsonify
# ----------------------------
# Config / Model path
# ----------------------------
# Serialized model artifact (loaded below with joblib); expected to live
# next to this script. Presumably an sklearn pipeline that bundles its own
# preprocessing — TODO confirm against the training code.
MODEL_PATH = "superKart_price_prediction_model_v1_0.joblib"
# ----------------------------
# Initialize app and load model
# ----------------------------
app = Flask("SuperKart Sales Predictor")
# Load model — fail fast at import time with a clear message if the
# artifact is missing, rather than erroring on the first request.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. ")
model = joblib.load(MODEL_PATH)
# Raw input feature names expected before preprocessing; the loaded
# pipeline is assumed to handle encoding/scaling itself.
NUMERIC_COLS = ['Product_Weight', 'Product_Allocated_Area', 'Product_MRP', 'Store_Age']
CATEGORICAL_COLS = ['Product_Sugar_Content', 'Product_Type', 'Store_Size',
                    'Store_Location_City_Type', 'Store_Type']
EXPECTED_COLUMNS = NUMERIC_COLS + CATEGORICAL_COLS


# ----------------------------
# Utility function
# ----------------------------
def validate_and_prepare_input(df: pd.DataFrame, reference_year: int = 2025):
    """Validate raw input columns and derive Store_Age when possible.

    If ``Store_Age`` is absent but ``Store_Establishment_Year`` is present,
    ``Store_Age`` is computed as ``reference_year - Store_Establishment_Year``.

    Args:
        df: Raw input records, one row per product/store observation.
        reference_year: Year used to convert an establishment year into a
            store age. Defaults to 2025, matching the previously hard-coded
            value (kept as the default for backward compatibility).

    Returns:
        Tuple of ``(prepared_df, missing)`` where ``prepared_df`` is a copy
        of the input (the caller's frame is never mutated) and ``missing``
        is the list of still-absent expected columns (empty when complete).
    """
    df = df.copy()
    missing = [c for c in EXPECTED_COLUMNS if c not in df.columns]
    # Accept Store_Establishment_Year as a substitute for Store_Age.
    if 'Store_Establishment_Year' in df.columns and 'Store_Age' in missing:
        df['Store_Age'] = reference_year - df['Store_Establishment_Year']
        # Recompute after derivation so Store_Age no longer counts as missing.
        missing = [c for c in EXPECTED_COLUMNS if c not in df.columns]
    return df, missing
# ----------------------------
# Routes
# ----------------------------
@app.get("/")
def home():
"""Health check / Landing page"""
return jsonify({
"service": "SuperKart Sales Predictor",
"status": "running"
})
@app.post("/v1/predict")
def predict_single():
"""
Predict sales for a single product-store record.
Expected JSON schema (example):
{
"Product_Weight": 12.5,
"Product_Allocated_Area": 0.056,
"Product_MRP": 149.0,
"Store_Age": 16,
"Product_Sugar_Content": "Low Sugar",
"Product_Type": "Dairy",
"Store_Size": "High",
"Store_Location_City_Type": "Tier 1",
"Store_Type": "Supermarket Type 1"
}
"""
try:
data = request.get_json(force=True)
if not isinstance(data, dict):
return jsonify({"error": "Input JSON must be an object/dict"}), 400
# Convert to DataFrame
input_df = pd.DataFrame([data])
# Validate and prepare
input_df, missing = validate_and_prepare_input(input_df)
if missing:
return jsonify({"error": "Missing required columns", "missing_columns": missing}), 400
# Keep only expected columns (ignore extra fields)
input_df = input_df[EXPECTED_COLUMNS]
# Predict using pipeline (pipeline will apply preprocessors)
pred = model.predict(input_df)
prediction_value = float(pred[0])
return jsonify({"prediction": prediction_value}), 200
except Exception as e:
return jsonify({"error": "Exception during prediction", "details": str(e)}), 500
@app.post("/v1/predict_batch")
def predict_batch():
"""
Predict sales for a batch of records supplied as a CSV file upload.
The CSV should contain the expected columns (or Store_Establishment_Year
instead of Store_Age which will be converted automatically).
"""
try:
if 'file' not in request.files:
return jsonify({"error": "No file part in the request. Upload a CSV file with key 'file'."}), 400
file = request.files['file']
if file.filename == "":
return jsonify({"error": "Empty filename. Please upload a CSV file."}), 400
# Read CSV
input_df = pd.read_csv(file)
input_df, missing = validate_and_prepare_input(input_df)
if missing:
return jsonify({"error": "Missing required columns in uploaded CSV", "missing_columns": missing}), 400
# Keep only expected columns and predict
input_df = input_df[EXPECTED_COLUMNS]
preds = model.predict(input_df)
# Return predictions aligned with original input index
output = input_df.copy()
output['predicted_Product_Store_Sales_Total'] = preds.astype(float)
# Convert to records for JSON response (limit size if necessary)
results = output.reset_index().to_dict(orient='records')
return jsonify({"predictions_count": len(results), "predictions": results}), 200
except Exception as e:
return jsonify({"error": "Exception during batch prediction", "details": str(e)}), 500
# ----------------------------
# Run app
# ----------------------------
if __name__ == "__main__":
    # Listen on 0.0.0.0 for containerized environments. In dev, use port 7860 or 5000 as required.
    # debug=False: no auto-reload / interactive debugger in deployment.
    app.run(host="0.0.0.0", port=7860, debug=False)