Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import joblib | |
| import pandas as pd | |
| import logging | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model. A failure to load is fatal: it is logged with its traceback
# and re-raised so the app never starts without a model.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code - only ever point this at a trusted model artifact.
# NOTE(review): the path is resolved relative to the process's working
# directory, not this module's location - confirm that matches deployment.
try:
    model = joblib.load("titanic_model.pkl")
except Exception as e:
    logger.exception(f"Error loading model: {e}")
    raise

# Logged outside the try: a model without feature_names_in_ (e.g. trained on a
# plain array) loaded fine and must not be reported as a load failure.
# The predict endpoint handles that case itself with an explicit 500.
logger.info(
    f"Model loaded successfully. Feature names: {getattr(model, 'feature_names_in_', None)}"
)
# Create the Pydantic model for the input data
class Passenger(BaseModel):
    """Request body for the predict endpoint: the features of one passenger.

    ``sex`` and ``embarked`` are unconstrained strings; they are not checked
    against the categories the model was trained on.
    """

    pclass: int
    sex: str
    age: float
    sibsp: int
    parch: int
    fare: float
    embarked: str

    # Example payload, published in the generated JSON schema / OpenAPI docs
    # instead of living in a comment.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "pclass": 1,
                    "sex": "male",
                    "age": 30,
                    "sibsp": 0,
                    "parch": 0,
                    "fare": 100,
                    "embarked": "S",
                }
            ]
        }
    }
# Create the FastAPI instance
app = FastAPI()


# Create the root endpoint. Without the route decorator the function is never
# registered with the app and GET / would return 404.
@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Titanic Survival Prediction API"}
# Create the predict endpoint
@app.post("/predict")
def predict(passenger: Passenger):
    """Predict whether a single passenger survived.

    The request body is one-hot encoded with ``pd.get_dummies`` and then
    aligned to the columns the model was fitted on (``model.feature_names_in_``).
    Expected columns absent from the input are filled with 0.

    Args:
        passenger: The validated request body.

    Returns:
        A dict with ``prediction`` (the first predicted label cast to int) and
        ``prediction_probability`` (column 1 of ``predict_proba`` for the row,
        or ``None`` when the model has no ``predict_proba``).

    Raises:
        HTTPException: 500 if the model lacks ``feature_names_in_``, or if any
            encoding/prediction step fails.
    """
    try:
        # Convert the input data to a single-row DataFrame
        input_dict = passenger.model_dump()
        logger.info(f"Input data: {input_dict}")
        input_data = pd.DataFrame([input_dict])
        logger.info(f"DataFrame created with columns: {input_data.columns.tolist()}")

        # One-hot encode the string columns (pandas names them "<col>_<value>")
        input_data = pd.get_dummies(input_data)
        logger.info(f"After one-hot encoding, columns: {input_data.columns.tolist()}")

        # Column alignment below needs the model's training-time column names
        if not hasattr(model, "feature_names_in_"):
            raise HTTPException(
                status_code=500,
                detail="Model does not have feature_names_in_ attribute",
            )
        expected_columns = list(model.feature_names_in_)
        logger.info(f"Model expects columns: {model.feature_names_in_}")

        # reindex() silently drops input columns the model was not fitted with
        # (e.g. a category value it never saw). Log them so such input is
        # visible rather than quietly encoded as all zeros.
        expected_set = set(expected_columns)
        unknown_columns = [c for c in input_data.columns if c not in expected_set]
        if unknown_columns:
            logger.warning(f"Ignoring input columns unknown to the model: {unknown_columns}")

        # Align the input columns with the model's: unknown columns are dropped
        # and missing ones are filled with 0. After this the column count always
        # equals len(feature_names_in_), so no separate count check is needed.
        input_data = input_data.reindex(columns=expected_columns, fill_value=0)
        logger.info(f"After reindexing, columns: {input_data.columns.tolist()}")

        # Predict the survival of the passenger
        prediction = model.predict(input_data)
        probability = (
            float(model.predict_proba(input_data)[0][1])
            if hasattr(model, "predict_proba")
            else None
        )
        return {
            "prediction": int(prediction[0]),
            "prediction_probability": probability,
        }
    except HTTPException as e:
        # Already a well-formed HTTP error raised above; re-raise it unchanged
        # instead of wrapping it in a second "Prediction failed" 500.
        logger.error(f"Prediction error: {e.detail}")
        raise
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Prediction error: {e}")
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if internals should not be exposed.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e