mainak555's picture
Upload folder using huggingface_hub
05458be verified
import hmac
import json
import os
from typing import Any

import joblib
import pandas as pd
from flask import Flask, request, jsonify
from pydantic import BaseModel, Field, ValidationError, model_validator

# NOTE(review): the two selectors are not referenced in this file; presumably
# imported so the pickled model pipelines can resolve them — confirm before removing.
from backend.util import num_features_selector, cat_features_selector, store_age
app = Flask("SuperKart: Revenue Forecasting")


def _load_model(path):
    """Load a joblib-serialized model, returning None if loading fails.

    The prediction endpoints answer 503 when ``clf`` or ``reg`` is None, which
    can only happen if loading failures are turned into None here. Otherwise a
    missing or corrupt model file aborts the whole module import, including
    the liveness probe.

    Args:
        path: Filesystem path of the ``.joblib`` artifact.

    Returns:
        The deserialized model, or None (the failure is logged with traceback).
    """
    try:
        return joblib.load(path)
    except Exception:
        # Broad on purpose: any unpickling/import/IO error means "model unavailable".
        app.logger.exception("Failed to load model from %s", path)
        return None


clf = _load_model("SuperKart_clf_model_v1_0.joblib")
reg = _load_model("SuperKart_reg_model_v1_0.joblib")
@app.get("/")
def liveProbe():
    """Liveness probe for the service root.

    Replies with the Flask application's name, so a health check can confirm
    the process is serving without touching either model.
    """
    service_name = app.name
    return service_name
class ReqSchema(BaseModel):
    """Request body accepted by the single-record prediction endpoint.

    ``Store_Age`` is not expected from the client. It is derived from
    ``Store_Establishment_Year`` by the ``set_store_age`` pre-validator, which
    overwrites any client-supplied value. It is excluded from serialization
    via ``exclude=True``.
    """

    Store_Size: str = Field(..., description="Size of the store")
    Store_Type: str = Field(..., description="Type of the store")
    Product_Type: str = Field(..., description="Type of the product")
    Product_Weight: float = Field(..., description="Weight of the product", gt=0)
    Product_MRP: float = Field(..., description="Maximum Retail Price of the product")
    Store_Establishment_Year: int = Field(..., description="Store Establishment year")
    Product_Allocated_Area: float = Field(..., description="Allocated area for the product")
    Store_Location_City_Type: str = Field(..., description="Location city type of the store")
    Store_Age: int = Field(..., description="Age of the store", exclude=True)

    @model_validator(mode="before")
    @classmethod
    def set_store_age(cls, values: Any) -> Any:
        """Inject ``Store_Age`` computed from ``Store_Establishment_Year``.

        Runs on the raw input before field validation.

        Args:
            values: Raw input. Normally a dict of field values, but "before"
                validators may receive any object.

        Returns:
            A new dict with ``Store_Age`` set, leaving the caller's mapping
            unmodified. Non-dict input and input without an establishment
            year are returned unchanged. Pydantic's field validation then
            reports the problem as a regular ``ValidationError`` instead of
            ``store_age`` being called with ``None``.
        """
        if not isinstance(values, dict):
            return values
        establishment_year = values.get('Store_Establishment_Year')
        if establishment_year is None:
            return values
        return {**values, 'Store_Age': store_age(establishment_year)}
@app.post("/v1/predict")
def predict():
    """Predict sales performance and revenue for a single product/store record.

    The request must carry an ``Authorization`` header equal to the
    ``auth_key`` environment variable and a JSON object body accepted by
    ``ReqSchema``.

    Returns:
        200: ``{"performance": <classifier output>, "revenue": <regressor
        output rounded to 2 decimals>}``.
        401: missing or wrong key, or ``auth_key`` not configured.
        503: a model is unavailable (``clf`` or ``reg`` is None).
        400: body is not a JSON object or fails schema validation.
        500: building features or predicting raised.
    """
    expected_key = os.environ.get('auth_key')
    supplied_key = request.headers.get('Authorization')
    if not expected_key:
        # .get avoids a KeyError (HTML 500) when the variable is unset; with no
        # configured key every request is refused rather than compared to None.
        app.logger.error("auth_key environment variable is not set; rejecting request")
        return jsonify({"error": "Unauthorized"}), 401
    # Constant-time comparison so the key cannot be recovered from response timing.
    if supplied_key is None or not hmac.compare_digest(
        supplied_key.encode(), expected_key.encode()
    ):
        return jsonify({"error": "Unauthorized"}), 401
    if clf is None or reg is None:
        return jsonify({"error": "Prediction service is unavailable. Model(s) failed to load."}), 503

    # silent=True: malformed JSON or a wrong Content-Type yields None instead of raising.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "Invalid request payload: expected a JSON object"}), 400
    try:
        req_data = ReqSchema(**payload)
    except ValidationError as e:
        # e.json() stringifies error context that jsonify cannot serialize
        # (e.g. exception objects raised inside validators).
        return jsonify({"error": json.loads(e.json())}), 400
    except Exception as e:
        # e.g. a non-ValidationError escaping the schema's pre-validator.
        return jsonify({"error": f"Invalid request payload: {e}"}), 400

    try:
        # NOTE(review): presumably these columns/this order match what the
        # models were trained on — keep in sync with /v1/predict/bulk.
        features = pd.DataFrame([{
            'Product_Weight': req_data.Product_Weight,
            'Product_Allocated_Area': req_data.Product_Allocated_Area,
            'Product_Type': req_data.Product_Type,
            'Product_MRP': req_data.Product_MRP,
            'Store_Size': req_data.Store_Size,
            'Store_Type': req_data.Store_Type,
            'Store_Location_City_Type': req_data.Store_Location_City_Type,
            'Store_Age': req_data.Store_Age,
        }])
        pred_label = clf.predict(features).tolist()[0]
        pred_value = reg.predict(features).tolist()[0]
        app.logger.debug("prediction: performance=%s revenue=%s", pred_label, pred_value)
        return jsonify({'performance': pred_label, 'revenue': round(pred_value, 2)})
    except Exception as e:
        app.logger.exception("Single-record prediction failed")
        return jsonify({"error": f"Prediction failed: {e}"}), 500
@app.post("/v1/predict/bulk")
def predictBulk():
    """Score every row of an uploaded CSV or Excel file with both models.

    Expects the same ``Authorization`` header as ``/v1/predict`` and a
    multipart upload in the ``file`` field.

    The response is a JSON list with one object per uploaded row. Each object
    contains all original columns plus ``Sales_Performance`` (classifier
    output) and ``Sales_Revenue`` (regressor output rounded to 2 decimals,
    as in ``/v1/predict``). Empty cells are emitted as ``null``.

    Returns:
        401/503 exactly like ``/v1/predict``.
        400 for a missing, unsupported, unreadable or empty file, or missing
        required columns.
        500 if feature construction or prediction fails.
    """
    expected_key = os.environ.get('auth_key')
    supplied_key = request.headers.get('Authorization')
    if not expected_key:
        # .get avoids a KeyError (HTML 500) when the variable is unset; with no
        # configured key every request is refused rather than compared to None.
        app.logger.error("auth_key environment variable is not set; rejecting request")
        return jsonify({"error": "Unauthorized"}), 401
    # Constant-time comparison so the key cannot be recovered from response timing.
    if supplied_key is None or not hmac.compare_digest(
        supplied_key.encode(), expected_key.encode()
    ):
        return jsonify({"error": "Unauthorized"}), 401
    if clf is None or reg is None:
        return jsonify({"error": "Prediction service is unavailable. Model(s) failed to load."}), 503
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files['file']
    # Case-insensitive extension check; the client may omit the filename.
    filename = (file.filename or "").lower()
    try:
        if filename.endswith(".csv"):
            df = pd.read_csv(file)
        elif filename.endswith((".xls", ".xlsx")):
            df = pd.read_excel(file)
        else:
            return jsonify({"error": "Unsupported file format! Upload CSV or Excel"}), 400
    except ValueError as e:
        # pandas ParserError/EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses: the client sent a malformed file.
        return jsonify({"error": f"Could not read uploaded file: {e}"}), 400
    except Exception as e:
        # Anything else (e.g. a missing Excel engine) is a server-side problem.
        app.logger.exception("Reading uploaded file failed")
        return jsonify({"error": f"Bulk prediction failed: {e}"}), 500

    # Columns the upload must provide; Store_Establishment_Year feeds Store_Age.
    req_cols = ['Product_Weight', 'Product_Allocated_Area', 'Product_Type', 'Product_MRP', 'Store_Size', 'Store_Location_City_Type', 'Store_Type', 'Store_Establishment_Year']
    missing_cols = [col for col in req_cols if col not in df.columns]
    if missing_cols:
        return jsonify({"error": f"Missing columns: {missing_cols}"}), 400
    if df.empty:
        return jsonify({"error": "Uploaded file contains no rows"}), 400

    # Same feature columns, in the same order, as /v1/predict passes to the models
    # (the raw establishment year is only used to derive Store_Age).
    feature_cols = [
        'Product_Weight', 'Product_Allocated_Area', 'Product_Type', 'Product_MRP',
        'Store_Size', 'Store_Type', 'Store_Location_City_Type', 'Store_Age',
    ]
    try:
        features = df[req_cols].copy()
        features['Store_Age'] = features['Store_Establishment_Year'].apply(store_age)
        features = features[feature_cols]
        df['Sales_Performance'] = clf.predict(features)
        df['Sales_Revenue'] = reg.predict(features)
        df['Sales_Revenue'] = df['Sales_Revenue'].round(2)
        # NaN is not valid JSON; emit missing cells as null instead.
        records = df.astype(object).where(df.notna(), None).to_dict(orient="records")
        return jsonify(records)
    except Exception as e:
        app.logger.exception("Bulk prediction failed")
        return jsonify({"error": f"Bulk prediction failed: {e}"}), 500
if __name__ == '__main__':
    # Development server only. Debug mode (the interactive Werkzeug debugger
    # permits arbitrary code execution) is no longer forced on: Flask's
    # app.run() enables it when the FLASK_DEBUG environment variable is set,
    # e.g. `FLASK_DEBUG=1 python <this file>`.
    app.run()