"""FastAPI service exposing an autism-questionnaire classification endpoint.

Loads a pickled scikit-learn pipeline (via joblib) during application startup
and serves predictions over HTTP. The pipeline is expected to perform its own
preprocessing of the raw questionnaire columns.
"""

import logging
from contextlib import asynccontextmanager
from typing import Optional

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

# --- Configuration ---
# Filename of the pickled questionnaire pipeline; must be present at runtime.
QNR_MODEL_FILENAME = "asd_classifier_model.pkl"
# Index 0 maps to "no ASD", index 1 maps to "ASD" (the model's integer labels).
CLASS_NAMES_QNR = ["Non-Autistic", "Autistic"]

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Models loaded during the lifespan startup phase live here.
ml_models = {}


# --- Model Loading Logic ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the questionnaire model at startup; release resources on shutdown.

    On load failure the app still starts, but the model slot is set to None so
    the prediction endpoint can report the problem instead of crashing.
    """
    logger.info("Attempting to load questionnaire model: %s", QNR_MODEL_FILENAME)
    try:
        ml_models['questionnaire_classifier'] = joblib.load(QNR_MODEL_FILENAME)
        logger.info("Questionnaire model loaded successfully.")
    except Exception:
        # logger.exception records the full traceback, unlike the bare message.
        logger.exception(
            "Error loading questionnaire model '%s'", QNR_MODEL_FILENAME
        )
        ml_models['questionnaire_classifier'] = None  # Indicate loading failure
    yield
    # Clean up the ML models and release the resources
    ml_models.clear()
    logger.info("Cleaned up models.")


# --- FastAPI App ---
app = FastAPI(lifespan=lifespan)

# --- CORS Middleware ---
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory under the CORS spec (browsers reject a literal "*" for
# credentialed requests). Starlette works around this by echoing the request's
# Origin header, which effectively lets ANY site make credentialed requests.
# Restrict allow_origins to known frontends before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)


# --- Input Data Schema (Pydantic Model) ---
class QuestionnaireData(BaseModel):
    """Input schema for /predict_questionnaire/.

    Field names must match EXACTLY the column names the trained pipeline
    expects — including the 'contry_of_res' typo, which is preserved on
    purpose because the training data used it.
    """

    A1_Score: int = Field(..., ge=0, le=1)
    A2_Score: int = Field(..., ge=0, le=1)
    A3_Score: int = Field(..., ge=0, le=1)
    A4_Score: int = Field(..., ge=0, le=1)
    A5_Score: int = Field(..., ge=0, le=1)
    A6_Score: int = Field(..., ge=0, le=1)
    A7_Score: int = Field(..., ge=0, le=1)
    A8_Score: int = Field(..., ge=0, le=1)
    A9_Score: int = Field(..., ge=0, le=1)
    A10_Score: int = Field(..., ge=0, le=1)
    # Optional with a default so clients may omit it entirely.
    age: Optional[float] = Field(25.0, gt=0, le=120)
    gender: str = Field("m", pattern="^(m|f)$")  # Allow only m or f
    ethnicity: str = Field("White-European")
    jaundice: str = Field("no", pattern="^(yes|no)$")  # Allow only yes or no
    # Typo kept deliberately: the model pipeline expects this column name.
    contry_of_res: str = Field("United States")
    used_app_before: str = Field("no", pattern="^(yes|no)$")
    # Recalculated server-side from the A*_Score answers; client value ignored.
    result: Optional[float] = Field(0.0)
    age_desc: str = Field("18 and more")
    relation: str = Field("Self")

    # Pydantic v2 validator. @classmethod is the documented pairing for
    # field_validator and makes the implicit conversion explicit.
    @field_validator('age')
    @classmethod
    def check_age(cls, v):
        """Substitute the default when age is explicitly null; re-check range."""
        if v is None:
            return 25.0  # Return default if None is explicitly passed
        if not (0 < v <= 120):
            raise ValueError('Age must be between 0 and 120')
        return v


# --- Prediction Endpoint (Questionnaire) ---
@app.post("/predict_questionnaire/")
async def predict_questionnaire(data: QuestionnaireData):
    """Receives questionnaire data, preprocesses using loaded pipeline, returns prediction."""
    if ml_models.get('questionnaire_classifier') is None:
        logger.error("Questionnaire model is not loaded.")
        raise HTTPException(status_code=500, detail="Questionnaire model could not be loaded")
    try:
        # Convert Pydantic model to dict, then to a single-row DataFrame.
        input_data = data.model_dump()  # Pydantic v2 API
        logger.info("Received data: %s", input_data)
        input_df = pd.DataFrame([input_data])

        # The 'result' column in training was the sum of the ten A*_Score
        # answers, so recompute it here rather than trusting the client value.
        a_scores = [f"A{i}_Score" for i in range(1, 11)]
        input_df['result'] = input_df[a_scores].sum(axis=1)
        logger.info("Recalculated result score: %s", input_df['result'].iloc[0])

        # Predict using the loaded pipeline (handles preprocessing).
        prediction = ml_models['questionnaire_classifier'].predict(input_df)
        predicted_class_index = int(prediction[0])

        # Get probability (optional)
        # probability = ml_models['questionnaire_classifier'].predict_proba(input_df)[0]
        # prob_asd = float(probability[1])  # Probability of class 1 (ASD)

        predicted_class_name = CLASS_NAMES_QNR[predicted_class_index]
        logger.info("Prediction successful: %s", predicted_class_name)
        return {"prediction": predicted_class_name}
        # If returning probability:
        # return {"prediction": predicted_class_name, "probability_asd": prob_asd}
    except ValueError as ve:
        # Validation-style failures surfacing from predict() map to 422.
        logger.error("Value error during prediction: %s", ve)
        raise HTTPException(status_code=422, detail=f"Invalid input data: {ve}")
    except Exception as e:
        # logger.exception records the traceback (replaces exc_info=True).
        logger.exception("Error during questionnaire prediction")
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}")


# --- Root Endpoint (Optional) ---
@app.get("/")
async def root():
    """Simple liveness/identification endpoint."""
    return {"message": "Autism Questionnaire Classification API"}