Myloiose commited on
Commit
dccd1a0
verified
1 Parent(s): cdb9e42

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -0
main.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import numpy as np
4
+ from tensorflow.keras.models import load_model
5
+
6
+ app = FastAPI()
7
+ model = load_model("titanic_model.h5")
8
+
9
+ class InputData(BaseModel):
10
+ pclass: int
11
+ sex: str
12
+ age: float
13
+ fare: float
14
+
15
+ @app.post("/predict")
16
+ async def predict(data: InputData):
17
+ sex_num = 1 if data.sex.lower() == "male" else 0
18
+ input_array = np.array([[data.pclass, sex_num, data.age, data.fare]])
19
+ prediction = model.predict(input_array)[0][0]
20
+ result = "Sobrevivi贸" if prediction > 0.5 else "No sobrevivi贸"
21
+ return {"data": [result]}