mnist-digit-api / app.py
apexherbert200's picture
Second commit
2479cdc
raw
history blame contribute delete
538 Bytes
from pathlib import Path

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI()

# Resolve the model file relative to this module rather than the process's
# current working directory, so the app loads correctly no matter where the
# server is launched from.
MODEL_PATH = Path(__file__).resolve().parent / "model.pkl"

# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only ever point MODEL_PATH at a trusted, version-controlled model artifact.
model = joblib.load(MODEL_PATH)
class inputData(BaseModel):
    """Request body for the ``/predict`` endpoint: one image as a flat list of floats."""

    # Pixel values of a single image, flattened into one list; predict() wraps
    # it as one row before calling the model. The required length depends on
    # the trained model (presumably 784 for 28x28 MNIST -- TODO confirm
    # against model.pkl).
    image_x: list[float]
@app.get("/")
def read_root():
    """Return a static welcome message for the API root."""
    greeting = "Welcome to Digit classifier page ;)"
    payload = {"message": greeting}
    return payload
@app.post("/predict")
def predict(data: inputData):
    """Classify a single flattened digit image.

    Args:
        data: Request body whose ``image_x`` holds the pixel values of one
            image, flattened into a single list.

    Returns:
        ``{"prediction": label}`` where ``label`` is the model's prediction
        for the image, converted to ``int``.

    Raises:
        HTTPException: 422 when the input cannot be scored -- the number of
            values does not match what the model was trained on, or the
            model otherwise rejects it (e.g. an empty list or NaN values).
    """
    # Wrap the flat list in an outer list so the model receives one sample of
    # shape (1, n_features) rather than a 1-D array.
    features = np.array([data.image_x], dtype=float)

    # Fail fast with a clear message when the model records its expected input
    # width (scikit-learn estimators expose ``n_features_in_``); otherwise the
    # mismatch would surface as an unhandled 500 error.
    expected = getattr(model, "n_features_in_", None)
    if expected is not None and features.shape[1] != expected:
        raise HTTPException(
            status_code=422,
            detail=f"Expected {expected} values in image_x, got {features.shape[1]}.",
        )

    try:
        pred = model.predict(features)
    except ValueError as err:
        # NOTE(review): assumes the model signals bad input (NaN/inf, wrong
        # shape) with ValueError, as scikit-learn does -- confirm for model.pkl.
        raise HTTPException(status_code=422, detail=f"Invalid input: {err}") from err

    return {"prediction": int(pred[0])}