# Origin metadata (was HTML page residue, not valid Python):
# codewithharsha — "Update main.py" — commit b2eb69a (verified)
import logging
from contextlib import asynccontextmanager
import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator # Updated import
from typing import Optional # Added for optional fields
# --- Configuration ---
# Filename of the pickled questionnaire classifier pipeline, loaded at startup.
QNR_MODEL_FILENAME = "asd_classifier_model.pkl" # Make sure this matches your uploaded file
# Human-readable labels indexed by the model's integer prediction:
# 0 maps to "Non-Autistic", 1 maps to "Autistic".
CLASS_NAMES_QNR = ["Non-Autistic", "Autistic"] # 0 maps to No ASD, 1 maps to ASD
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Registry of loaded models, populated/cleared by the lifespan handler.
# A None value signals that loading failed and predictions must be refused.
ml_models = {} # Dictionary to hold loaded models
# --- Model Loading Logic ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the questionnaire model at startup and
    release it at shutdown.

    On load failure the app still starts; the registry entry is set to None
    so the prediction endpoint can refuse requests with an explicit error
    instead of the process crashing at boot.
    """
    # Lazy %-style args so the message is only formatted if the level is enabled.
    logger.info("Attempting to load questionnaire model: %s", QNR_MODEL_FILENAME)
    try:
        ml_models['questionnaire_classifier'] = joblib.load(QNR_MODEL_FILENAME)
        logger.info("Questionnaire model loaded successfully.")
    except Exception:
        # Broad catch is deliberate: a missing/corrupt model file must not
        # prevent the API from starting. logger.exception records the full
        # traceback for diagnosis (logger.error alone dropped it).
        logger.exception("Error loading questionnaire model '%s'", QNR_MODEL_FILENAME)
        ml_models['questionnaire_classifier'] = None # Indicate loading failure
    yield
    # Shutdown: drop model references so their memory can be reclaimed.
    ml_models.clear()
    logger.info("Cleaned up models.")
# --- FastAPI App ---
# The lifespan handler loads/unloads the model around the server's lifetime.
app = FastAPI(lifespan=lifespan)
# --- CORS Middleware ---
# NOTE(review): fully open CORS (any origin, with credentials) is convenient
# for development but should be narrowed to known frontend origins in
# production — confirm the intended deployment before shipping.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # Allows all origins
    allow_credentials=True,
    allow_methods=["*"], # Allows all methods
    allow_headers=["*"], # Allows all headers
)
# --- Input Data Schema (Pydantic Model) ---
# Ensure field names match EXACTLY what the model expects
# Added Field defaults and type hints based on your notebook/previous code
class QuestionnaireData(BaseModel):
    """Input schema for the /predict_questionnaire/ endpoint.

    Field names must match the column names the trained pipeline expects,
    including the dataset's original 'contry_of_res' spelling.
    """
    # AQ-10 screening answers: each is a binary 0/1 score.
    A1_Score: int = Field(..., ge=0, le=1)
    A2_Score: int = Field(..., ge=0, le=1)
    A3_Score: int = Field(..., ge=0, le=1)
    A4_Score: int = Field(..., ge=0, le=1)
    A5_Score: int = Field(..., ge=0, le=1)
    A6_Score: int = Field(..., ge=0, le=1)
    A7_Score: int = Field(..., ge=0, le=1)
    A8_Score: int = Field(..., ge=0, le=1)
    A9_Score: int = Field(..., ge=0, le=1)
    A10_Score: int = Field(..., ge=0, le=1)
    age: Optional[float] = Field(25.0, gt=0, le=120) # Made optional with default
    gender: str = Field("m", pattern="^(m|f)$") # Allow only m or f
    ethnicity: str = Field("White-European")
    jaundice: str = Field("no", pattern="^(yes|no)$") # Allow only yes or no
    # Keep the dataset's 'contry_of_res' typo: the pipeline was trained with
    # that exact column name, so renaming it here would break prediction.
    contry_of_res: str = Field("United States") # Keep original if model expects typo
    used_app_before: str = Field("no", pattern="^(yes|no)$") # Allow only yes or no
    # 'result' is recalculated server-side from the A*_Score sum before
    # prediction, so the client-supplied value is effectively ignored.
    result: Optional[float] = Field(0.0) # This might be recalculated or ignored if it was the target
    age_desc: str = Field("18 and more")
    relation: str = Field("Self")

    # Pydantic v2 validator. Per the Pydantic docs, field_validator should be
    # stacked above @classmethod; v2 treats the bare function as a classmethod
    # anyway, but the explicit decorator keeps type checkers accurate.
    @field_validator('age')
    @classmethod
    def check_age(cls, v):
        """Map an explicit None to the default age; bound-check everything else."""
        if v is None:
            return 25.0 # Return default if None is explicitly passed
        if not (0 < v <= 120):
            raise ValueError('Age must be between 0 and 120')
        return v
    # You might add more validators for other fields if needed
# You might add more validators for other fields if needed
# --- Prediction Endpoint (Questionnaire) ---
@app.post("/predict_questionnaire/")
async def predict_questionnaire(data: QuestionnaireData):
    """Receives questionnaire data, preprocesses using loaded pipeline, returns prediction."""
    # Guard clause: refuse requests when startup failed to load the model.
    model = ml_models.get('questionnaire_classifier')
    if model is None:
        logger.error("Questionnaire model is not loaded.")
        raise HTTPException(status_code=500, detail="Questionnaire model could not be loaded")
    try:
        # Pydantic v2 dict export, then a single-row DataFrame for the pipeline.
        payload = data.model_dump() # Use model_dump() in Pydantic v2
        logger.info(f"Received data: {payload}")
        frame = pd.DataFrame([payload])
        # Recompute 'result' as the sum of the ten A*_Score answers; any
        # client-supplied value is overwritten (matches how the training
        # data's 'result' column was derived).
        score_columns = ["A%d_Score" % n for n in range(1, 11)]
        frame['result'] = frame[score_columns].sum(axis=1)
        logger.info(f"Recalculated result score: {frame['result'].iloc[0]}")
        # The loaded pipeline handles its own preprocessing before predicting.
        label_index = int(model.predict(frame)[0])
        label = CLASS_NAMES_QNR[label_index]
        logger.info(f"Prediction successful: {label}")
        return {"prediction": label}
        # If returning probability: return {"prediction": label, "probability_asd": prob_asd}
    except ValueError as ve:
        # Validation-style failures raised inside preprocessing/prediction.
        logger.error(f"Value error during prediction: {ve}")
        raise HTTPException(status_code=422, detail=f"Invalid input data: {ve}")
    except Exception as e:
        logger.error(f"Error during questionnaire prediction: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
# --- Root Endpoint (Optional) ---
@app.get("/")
async def root():
    """Liveness/landing endpoint identifying the service."""
    greeting = {"message": "Autism Questionnaire Classification API"}
    return greeting