mriz committed on
Commit
73f27fb
·
verified ·
1 Parent(s): 3c3db05

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +90 -0
main.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import joblib
4
+ import pandas as pd
5
+ import logging
6
+
7
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the trained model once at import time so every request reuses it.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only load model files that come from a trusted source.
try:
    model = joblib.load("titanic_model.pkl")
except Exception as e:
    # Log for operators, then re-raise: the service cannot work without a model.
    logger.error(f"Error loading model: {e}")
    raise
else:
    # feature_names_in_ may be absent on the loaded object; the /predict
    # endpoint checks for it and answers with a 500 in that case. Reading it
    # defensively here keeps a successful load from being reported as a
    # load failure and aborting startup.
    logger.info(
        "Model loaded successfully. Feature names: %s",
        getattr(model, "feature_names_in_", None),
    )
18
+
19
# Create the Pydantic model for the input data
class Passenger(BaseModel):
    """Request body for the /predict endpoint: the features of one passenger.

    The string fields are one-hot encoded by the endpoint before the model
    is called; the numeric fields are passed through as given.
    """

    pclass: int
    sex: str  # categorical; example value "male"
    age: float
    sibsp: int
    parch: int
    fare: float
    embarked: str  # categorical; example value "S"

    # Example payload published in the generated OpenAPI schema so that the
    # interactive docs show a ready-to-send request body.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "pclass": 1,
                    "sex": "male",
                    "age": 30,
                    "sibsp": 0,
                    "parch": 0,
                    "fare": 100,
                    "embarked": "S",
                }
            ]
        }
    }
39
+
40
# Application object; the route decorators below register handlers on it.
app = FastAPI()


# Root endpoint
@app.get("/")
def read_root():
    """Return a static welcome message as a one-key JSON object."""
    welcome_text = "Welcome to the Titanic Survival Prediction API"
    return {"message": welcome_text}
47
+
48
# Create the predict endpoint
@app.post("/predict")
def predict(passenger: Passenger):
    """Score a single passenger with the loaded model.

    The request body is turned into a one-row DataFrame, one-hot encoded with
    ``pd.get_dummies`` and re-indexed to ``model.feature_names_in_`` (columns
    the model expects but the input lacks are filled with 0).

    Args:
        passenger: Validated request body.

    Returns:
        A dict with ``prediction`` (the first predicted value cast to int) and
        ``prediction_probability`` (column 1 of ``predict_proba`` for the
        row, or ``None`` when the model has no ``predict_proba``).

    Raises:
        HTTPException: status 500 when the model has no ``feature_names_in_``
            attribute, or when any step of encoding/prediction fails.
    """
    # Guard first, outside the try block, so this specific error reaches the
    # client as written instead of being re-wrapped by the generic handler.
    if not hasattr(model, 'feature_names_in_'):
        raise HTTPException(status_code=500, detail="Model does not have feature_names_in_ attribute")

    try:
        # Convert the input data to a DataFrame
        input_dict = passenger.model_dump()
        logger.info("Input data: %s", input_dict)

        input_data = pd.DataFrame([input_dict])
        logger.info("DataFrame created with columns: %s", input_data.columns.tolist())

        # One-Hot Encode the input data
        input_data = pd.get_dummies(input_data)
        encoded_columns = input_data.columns.tolist()
        logger.info("After one-hot encoding, columns: %s", encoded_columns)

        expected_columns = list(model.feature_names_in_)
        logger.info(f"Model expects columns: {model.feature_names_in_}")

        # reindex() silently discards columns the model does not know about
        # (for instance a one-hot column for a category value that is not
        # among the model's features). Surface that in the log, because the
        # prediction is then made without that information.
        known = set(expected_columns)
        dropped = [col for col in encoded_columns if col not in known]
        if dropped:
            logger.warning("Input columns not used by the model (dropped): %s", dropped)

        # Align the input data columns with the model columns; after this the
        # frame has exactly the model's columns, in the model's order.
        input_data = input_data.reindex(columns=expected_columns, fill_value=0)
        logger.info("After reindexing, columns: %s", input_data.columns.tolist())

        # Predict the survival of the passenger
        prediction = model.predict(input_data)

        probability = None
        if hasattr(model, 'predict_proba'):
            # Column 1 of the first row; presumably the positive class of a
            # binary classifier (confirm against the trained model).
            probability = float(model.predict_proba(input_data)[0][1])

        return {
            "prediction": int(prediction[0]),
            "prediction_probability": probability,
        }

    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e