# Last commit: "Update app.py" by JunYuanNYP (9cc6ce3, verified).
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import xgboost as xgb
import pandas as pd
# Load the trained XGBoost regressor once, at import time, so every request
# reuses the same in-memory model. The path is relative to the process's
# current working directory.
_MODEL_PATH = "timePrediction.json"
model = xgb.XGBRegressor()
model.load_model(_MODEL_PATH)
class InputData(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    The attribute names are the JSON keys clients send, so they must not be
    renamed.
    """

    # Used directly as the model's ``Quantity`` feature.
    Quantity: int
    # Product name; one-hot encoded by ``encode_product`` (exact,
    # case-sensitive match).
    Product_Type: str
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation says allow_origins,
# allow_methods and allow_headers should not be "*" when
# allow_credentials=True. No route visible in this file reads cookies or
# auth headers, so credentials may not be needed -- confirm with the
# frontend before tightening this.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Product names, in the exact order of the one-hot feature columns.
_PRODUCT_TYPES = (
    "Lemon Scent Dishwashing Liquid",
    "Antibacterial Dishwashing Gel",
    "Unbleached Baking Paper",
    "Silicone-coated baking sheet",
    "Disposable plastic bag",
    "Lavender air freshener sachet",
    "Mothballs",
    "Air Fryer Paper",
)

# name -> one-hot vector, built once at import instead of on every call.
# Tuples keep the shared table immutable.
_PRODUCT_ONE_HOT = {
    name: tuple(int(i == position) for i in range(len(_PRODUCT_TYPES)))
    for position, name in enumerate(_PRODUCT_TYPES)
}


def encode_product(product_type: str) -> list:
    """One-hot encode a product name.

    Args:
        product_type: Product name; must match one of ``_PRODUCT_TYPES``
            exactly (the match is case-sensitive).

    Returns:
        A new list of 8 ints with a single ``1`` at the product's position.
        An unrecognised name yields all zeros instead of raising.
    """
    vector = _PRODUCT_ONE_HOT.get(product_type)
    if vector is None:
        # Deliberate fallback: unknown products encode as "no product".
        return [0] * len(_PRODUCT_TYPES)
    # Hand out a fresh list so callers cannot alter the shared table.
    return list(vector)
@app.get("/")
def start():
    """Root endpoint: reply with a fixed greeting string."""
    greeting = "Hello World"
    return greeting
@app.post("/predict")
def predict(data: InputData):
    """Return the model's predicted time for a quantity of a given product.

    The request becomes a single-row DataFrame: ``Quantity`` followed by the
    eight one-hot product columns produced by ``encode_product``. An
    unrecognised ``Product_Type`` encodes as all zeros, so it still gets a
    prediction rather than an error.

    Returns:
        ``{"predicted_time": float}``, rounded to 8 decimal places. The unit
        is whatever the model was trained on (not visible here).
    """
    # NOTE(review): presumably this must match the feature order used when
    # the model was trained -- confirm before reordering.
    columns = [
        "Quantity",
        "Lemon Scent Dishwashing Liquid",
        "Antibacterial Dishwashing Gel",
        "Unbleached Baking Paper",
        "Silicone-coated baking sheet",
        "Disposable plastic bag",
        "Lavender air freshener sachet",
        "Mothballs",
        "Air Fryer Paper",
    ]
    features = [data.Quantity] + encode_product(data.Product_Type)
    frame = pd.DataFrame([features], columns=columns)
    pred = model.predict(frame)[0]
    # Widen to a Python float BEFORE rounding. XGBoost predictions are
    # typically NumPy float32, and rounding in float32 then converting leaves
    # representation noise in the response (12.3456 -> 12.345600128173828).
    # NOTE(review): the original also computed an unused ``pred * 0.3``;
    # confirm whether the response was meant to be scaled.
    return {"predicted_time": round(float(pred), 8)}