# Spaces: Sleeping
# (status banner from the hosting page, captured along with this file; not part of the program)
| import io | |
| import json | |
| import joblib | |
| import pandas as pd | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
# Serialized model artifact. This is a relative path, so it is opened relative to
# the process's current working directory.
MODEL_PATH = "SuperKart_prediction_model_v1_0.joblib"

# Feature columns required from every request, in the exact order they are handed
# to model.predict(). Keep this in sync with the fields of the request model.
EXPECTED_COLS = [
    "Product_Weight",
    "Product_Allocated_Area",
    "Product_MRP",
    "Store_Age",
    "Product_Sugar_Content",
    "Product_Type",
    "Store_Type",
    "Store_Size",
    "Store_Location_City_Type",
]
app = FastAPI(title="SuperKart Backend", version="1.0")

# Accept requests from browser clients on any origin. Credentialed CORS is disabled:
# the FastAPI CORS documentation states that allow_origins=["*"] cannot be combined
# with allowing credentials, and no cookie- or auth-header-based access is visible
# in this module.
# NOTE(review): if a browser client must send cookies or Authorization headers, list
# its exact origin(s) in allow_origins and set allow_credentials=True instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# The fitted estimator; stays None until load_model() has run.
model = None


# NOTE(review): no call to load_model() (e.g. a startup hook) is visible in the
# reviewed source — confirm it runs before predictions are served.
def load_model():
    """Load the serialized model from MODEL_PATH into the module-level ``model``.

    Returns:
        The loaded model object, which is also stored in the global ``model``.
        Returning it is additive; callers that ignore the result are unaffected.
    """
    global model
    # joblib.load unpickles the file, which can execute arbitrary code. Only point
    # MODEL_PATH at a trusted, locally shipped artifact — never at user uploads.
    model = joblib.load(MODEL_PATH)
    return model
class Payload(BaseModel):
    """Request body for a single prediction: one field per model feature.

    The field names are exactly the entries of EXPECTED_COLS, so the parsed body can
    be turned directly into a one-row DataFrame.
    """

    # Numeric inputs
    Product_Weight: float
    Product_Allocated_Area: float
    Product_MRP: float
    Store_Age: int
    # Text-valued inputs
    Product_Sugar_Content: str
    Product_Type: str
    Store_Type: str
    # Integer-typed store attributes
    # NOTE(review): declared as int; presumably these were integer-encoded before
    # training — confirm against the model's preprocessing.
    Store_Size: int
    Store_Location_City_Type: int
def validate_and_order(df: pd.DataFrame) -> pd.DataFrame:
    """Return an independent copy of *df* restricted to EXPECTED_COLS, in that order.

    Columns of *df* that are not in EXPECTED_COLS are dropped from the result.

    Raises:
        HTTPException: status 422 whose detail lists every absent column, in
            EXPECTED_COLS order.
    """
    available = df.columns
    absent = [name for name in EXPECTED_COLS if name not in available]
    if not absent:
        return df[EXPECTED_COLS].copy()
    raise HTTPException(status_code=422, detail=f"Missing columns: {absent}")
# NOTE(review): no route decorator is visible above this handler in the reviewed
# source — confirm it is registered on `app` (e.g. with @app.get(...)).
def health():
    """Return a fixed status payload showing the service is responding."""
    body = dict(status="ok")
    return body
# NOTE(review): no route decorator is visible above this handler in the reviewed
# source — confirm it is registered on `app` (e.g. with @app.post(...)).
def predict_single(payload: Payload):
    """Return the model's prediction for one validated record.

    Returns ``{"Predicted Price": <float>}``. Validation errors raised as
    HTTPException (status 422) propagate unchanged; any other failure is reported as
    an HTTPException with status 500.
    """
    try:
        if model is None:
            # No call to load_model() is visible in this module, so load on first
            # use instead of failing with "'NoneType' object has no attribute
            # 'predict'".
            load_model()
        # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
        to_dict = getattr(payload, "model_dump", None) or payload.dict
        df = pd.DataFrame([to_dict()])
        X = validate_and_order(df)
        y = model.predict(X)
        return {"Predicted Price": float(y[0])}
    except HTTPException:
        # Keep deliberate client errors (e.g. 422) instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible above this handler in the reviewed
# source — confirm it is registered on `app` (e.g. with @app.post(...)).
def predict_batch(file: UploadFile = File(...)):
    """Return a prediction for every row of an uploaded CSV file.

    Returns a list of row dicts: all columns from the uploaded CSV plus a
    "Predicted Price" column. A CSV that has the required header but no rows
    yields ``[]``.

    Errors are reported as HTTPException: 400 when the upload cannot be parsed as
    CSV (empty, malformed, or not UTF-8), 422 when required columns are missing,
    and 500 for any other failure.
    """
    try:
        content = file.file.read()
        try:
            df = pd.read_csv(io.BytesIO(content))
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
            # A bad upload is a client error, not a server failure.
            raise HTTPException(status_code=400, detail=f"Could not parse CSV: {e}") from e
        X = validate_and_order(df)
        if X.empty:
            # Nothing to score; this matches what serializing a zero-row frame
            # produces, and avoids the model rejecting an empty input.
            return []
        if model is None:
            # No call to load_model() is visible in this module, so load on first use.
            load_model()
        y = model.predict(X)
        df["Predicted Price"] = y
        # Round-trip through pandas' JSON writer so NaN becomes null and numpy
        # scalars become plain JSON values.
        return json.loads(df.to_json(orient="records"))
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e