# titanic-api / main.py
# Author: Myloiose - commit 6331ec5 ("Create main.py"), 605 bytes.
import threading

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from tensorflow.keras.models import load_model
# FastAPI application instance; the /predict route is registered on it.
app = FastAPI()

# Trained Keras model file. Being a relative path, it is resolved against
# the process's current working directory when load_model opens it.
MODEL_PATH = "titanic_model.h5"

# Loaded once, at import time, and shared by every request.
model = load_model(MODEL_PATH)
class InputData(BaseModel):
    """Request body for ``POST /predict``: the features of one passenger.

    Out-of-range numeric values are rejected by FastAPI with a 422 response
    instead of being fed to the model as meaningless input.
    """

    # Passenger ticket class; the Titanic data only has classes 1, 2 and 3.
    pclass: int = Field(..., ge=1, le=3)
    # Passenger sex; /predict compares it case-insensitively against "male".
    sex: str
    # Passenger age (presumably in years, as in the training data - TODO
    # confirm); a negative age is never valid.
    age: float = Field(..., ge=0)
    # Ticket fare; a negative fare is never valid.
    fare: float = Field(..., ge=0)
# Inference runs in worker threads (see predict below). NOTE(review): calling
# predict() concurrently on one shared Keras model is not documented as
# thread-safe, so this lock keeps the original one-call-at-a-time behaviour
# while still keeping the event loop free.
_model_lock = threading.Lock()


def _encode_sex(sex: str) -> int:
    """Map the request's ``sex`` text to the model encoding (male=1, female=0).

    Matching ignores case and surrounding whitespace.

    Raises:
        HTTPException: 422 if the value is neither "male" nor "female", rather
            than silently treating unknown values (e.g. "M") as female.
    """
    normalized = sex.strip().lower()
    if normalized == "male":
        return 1
    if normalized == "female":
        return 0
    raise HTTPException(
        status_code=422,
        detail="sex must be 'male' or 'female'",
    )


def _run_model(features: np.ndarray) -> float:
    """Run the shared model on a single feature row; return its first score."""
    with _model_lock:
        # verbose=0 stops Keras from printing a progress bar on every request.
        scores = model.predict(features, verbose=0)
    return float(scores[0][0])


@app.post("/predict")
async def predict(data: InputData):
    """Predict whether a passenger with the given features survived.

    Args:
        data: Validated request body (pclass, sex, age, fare).

    Returns:
        ``{"data": [label]}`` where ``label`` is "Sobrevivió" ("survived") when
        the model's score is above 0.5, otherwise "No sobrevivió".

    Raises:
        HTTPException: 422 if ``data.sex`` is not "male" or "female".
    """
    sex_num = _encode_sex(data.sex)
    # Column order must match the order the model was trained on - presumably
    # [pclass, sex, age, fare]; TODO confirm against the training script.
    features = np.array([[data.pclass, sex_num, data.age, data.fare]])
    # model.predict is synchronous and CPU-bound; calling it directly inside
    # an async endpoint would stall every other request on the event loop.
    prediction = await run_in_threadpool(_run_model, features)
    result = "Sobrevivió" if prediction > 0.5 else "No sobrevivió"
    return {"data": [result]}