"""Titanic survival prediction API.

Loads a pre-trained model from ``titanic_model.pkl`` at import time and
exposes it over HTTP with FastAPI: ``GET /`` returns a welcome message and
``POST /predict`` returns a prediction for one passenger.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import pandas as pd
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model once at import time. Reading ``feature_names_in_`` here also
# means a model without that attribute fails fast at startup (the
# AttributeError is logged and re-raised) instead of on the first request.
try:
    model = joblib.load("titanic_model.pkl")
    logger.info(
        "Model loaded successfully. Feature names: %s", model.feature_names_in_
    )
except Exception as e:
    logger.error("Error loading model: %s", e)
    raise


# Create the Pydantic model for the input data
class Passenger(BaseModel):
    """Request body for ``POST /predict``: the raw features of one passenger.

    The string fields (``sex``, ``embarked``) are one-hot encoded by the
    endpoint before being passed to the model.
    """

    pclass: int
    sex: str
    age: float
    sibsp: int
    parch: int
    fare: float
    embarked: str


# Example request body:
# {
#     "pclass": 1,
#     "sex": "male",
#     "age": 30,
#     "sibsp": 0,
#     "parch": 0,
#     "fare": 100,
#     "embarked": "S"
# }

# Create the FastAPI instance
app = FastAPI()


# Create the root endpoint
@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Titanic Survival Prediction API"}


# Create the predict endpoint
@app.post("/predict")
def predict(passenger: Passenger):
    """Predict the outcome for a single passenger.

    The passenger is converted to a one-row DataFrame, one-hot encoded with
    ``pd.get_dummies`` and aligned to ``model.feature_names_in_``. Columns the
    model expects but the input lacks are filled with 0; columns the model
    does not know are dropped.

    Returns:
        A dict with ``"prediction"`` (the first predicted value as an int)
        and ``"prediction_probability"`` (the value at index 1 of the first
        ``predict_proba`` row as a float, or ``None`` when the model has no
        ``predict_proba``).

    Raises:
        HTTPException: status 500 if the model lacks ``feature_names_in_``
            or if any step of the prediction fails.
    """
    try:
        # Convert the input data to a DataFrame
        input_dict = passenger.model_dump()
        logger.info("Input data: %s", input_dict)

        input_data = pd.DataFrame([input_dict])
        logger.info(
            "DataFrame created with columns: %s", input_data.columns.tolist()
        )

        # One-Hot Encode the input data
        input_data = pd.get_dummies(input_data)
        logger.info(
            "After one-hot encoding, columns: %s", input_data.columns.tolist()
        )

        # Check if model has feature_names_in_ attribute
        if not hasattr(model, "feature_names_in_"):
            raise HTTPException(
                status_code=500,
                detail="Model does not have feature_names_in_ attribute",
            )

        expected_columns = model.feature_names_in_
        logger.info("Model expects columns: %s", expected_columns)

        # Align the input data columns with the model columns. reindex returns
        # exactly the requested columns, so no separate feature-count check is
        # needed afterwards.
        input_data = input_data.reindex(columns=expected_columns, fill_value=0)
        logger.info("After reindexing, columns: %s", input_data.columns.tolist())

        # Predict the survival of the passenger
        prediction = model.predict(input_data)

        probability = None
        if hasattr(model, "predict_proba"):
            probability = float(model.predict_proba(input_data)[0][1])

        return {
            "prediction": int(prediction[0]),
            "prediction_probability": probability,
        }
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged
        # rather than being re-wrapped by the generic handler below.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {str(e)}"
        ) from e