| import logging |
| from contextlib import asynccontextmanager |
|
|
| import joblib |
| import pandas as pd |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field, field_validator |
| from typing import Optional |
|
|
| |
# Serialized questionnaire model; a relative path, so it is resolved against
# the process's current working directory when loaded.
QNR_MODEL_FILENAME = "asd_classifier_model.pkl"

# Human-readable labels, indexed by the integer class the classifier predicts
# (0 -> "Non-Autistic", 1 -> "Autistic").
CLASS_NAMES_QNR = [
    "Non-Autistic",
    "Autistic",
]
|
|
| |
# Configure root logging at import time so INFO-level messages are emitted
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Process-wide registry of loaded models, keyed by a short model name.
# Filled at application startup and cleared at shutdown by the lifespan hook.
ml_models = dict()
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the questionnaire classifier at startup and drop it at shutdown.

    A load failure does not stop the application: it is logged together with
    its traceback, and the registry entry is set to ``None``. The prediction
    endpoint checks that entry and answers with an error while it is ``None``.

    Args:
        app: The FastAPI application. Required by the lifespan protocol and
            unused here.
    """
    logger.info("Attempting to load questionnaire model: %s", QNR_MODEL_FILENAME)
    try:
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only ever point QNR_MODEL_FILENAME at a trusted artifact.
        ml_models['questionnaire_classifier'] = joblib.load(QNR_MODEL_FILENAME)
        logger.info("Questionnaire model loaded successfully.")
    except Exception as e:
        # Deliberate best-effort: keep serving, but record the full traceback.
        # Unpickling errors (e.g. library version mismatches) are hard to
        # diagnose from the message alone.
        logger.exception(
            "Error loading questionnaire model '%s': %s", QNR_MODEL_FILENAME, e
        )
        ml_models['questionnaire_classifier'] = None
    try:
        yield
    finally:
        # Run cleanup even if shutdown is triggered by an exception at the yield.
        ml_models.clear()
        logger.info("Cleaned up models.")
|
|
| |
# The lifespan hook loads the questionnaire model before requests are served.
app = FastAPI(lifespan=lifespan)

# CORS: any origin may call this API, with any method and header.
# Credentials (cookies / HTTP auth) are deliberately NOT allowed. The CORS spec
# forbids a wildcard origin together with credentials, and Starlette works
# around that by echoing back the caller's Origin. That would let any website
# make credentialed requests on a user's behalf. None of the endpoints in this
# file read cookies or auth headers, so nothing needs credentials.
# If a browser client must send credentials, list its exact origins in
# allow_origins instead of "*" and then re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
| |
| |
class QuestionnaireData(BaseModel):
    """Request body for one completed questionnaire.

    The prediction endpoint dumps this model straight into a one-row DataFrame.
    The field names (and their order) therefore become the column names the
    classifier sees, so do not rename or reorder them casually.
    """

    # Ten required item answers, each constrained to 0 or 1.
    # NOTE(review): presumably the AQ-10 screening items — confirm against the
    # training data.
    A1_Score: int = Field(..., ge=0, le=1)
    A2_Score: int = Field(..., ge=0, le=1)
    A3_Score: int = Field(..., ge=0, le=1)
    A4_Score: int = Field(..., ge=0, le=1)
    A5_Score: int = Field(..., ge=0, le=1)
    A6_Score: int = Field(..., ge=0, le=1)
    A7_Score: int = Field(..., ge=0, le=1)
    A8_Score: int = Field(..., ge=0, le=1)
    A9_Score: int = Field(..., ge=0, le=1)
    A10_Score: int = Field(..., ge=0, le=1)

    # Optional context fields. Each falls back to the default shown when omitted.
    age: Optional[float] = Field(25.0, gt=0, le=120)
    gender: str = Field("m", pattern="^(m|f)$")
    ethnicity: str = Field("White-European")
    jaundice: str = Field("no", pattern="^(yes|no)$")
    # NOTE(review): "contry" spelling presumably matches a training-data column
    # name — do not "fix" it without retraining the model.
    contry_of_res: str = Field("United States")
    used_app_before: str = Field("no", pattern="^(yes|no)$")
    # Accepted for compatibility. The prediction endpoint overwrites it with
    # the sum of the A*_Score answers.
    result: Optional[float] = Field(0.0)
    age_desc: str = Field("18 and more")
    relation: str = Field("Self")

    @field_validator('age')
    @classmethod
    def check_age(cls, v):
        """Replace an explicit null age with 25.0; reject ages outside (0, 120]."""
        if v is not None:
            if 0 < v <= 120:
                return v
            raise ValueError('Age must be between 0 and 120')
        return 25.0
|
|
| |
|
|
| |
def _build_questionnaire_frame(data: QuestionnaireData) -> pd.DataFrame:
    """Turn one validated request into the single-row DataFrame fed to the model.

    The client-supplied ``result`` is overwritten with the sum of the ten
    ``A*_Score`` answers, so it always agrees with the individual items.
    """
    input_data = data.model_dump()
    # Questionnaire answers are health-related data: keep the full payload out
    # of the default INFO-level logs and only emit it when DEBUG is enabled.
    logger.debug("Received data: %s", input_data)
    input_df = pd.DataFrame([input_data])

    a_scores = [f"A{i}_Score" for i in range(1, 11)]
    input_df['result'] = input_df[a_scores].sum(axis=1)
    logger.info("Recalculated result score: %s", input_df['result'].iloc[0])
    return input_df


def _class_name_for(raw_label) -> str:
    """Map one raw classifier output to its entry in ``CLASS_NAMES_QNR``.

    Raises:
        RuntimeError: if the output is not integer-like or falls outside
            ``CLASS_NAMES_QNR``. Either case is a model/server problem, not a
            client-input problem, so it must not surface as ``ValueError``
            (which the endpoint reports as HTTP 422).
    """
    try:
        index = int(raw_label)
    except (TypeError, ValueError) as err:
        raise RuntimeError(
            f"Classifier returned a non-integer label: {raw_label!r}"
        ) from err
    # Explicit guard: a negative index would otherwise wrap around silently
    # (e.g. -1 -> "Autistic") instead of failing.
    if not 0 <= index < len(CLASS_NAMES_QNR):
        raise RuntimeError(f"Classifier returned an out-of-range class index: {index}")
    return CLASS_NAMES_QNR[index]


@app.post("/predict_questionnaire/")
def predict_questionnaire(data: QuestionnaireData):
    """Receive questionnaire data, run the loaded pipeline, return the prediction.

    Declared as a plain ``def`` on purpose. ``predict`` is blocking, CPU-bound
    work, and FastAPI runs sync path operations in its threadpool instead of
    on the event loop. One prediction therefore no longer stalls every other
    request. NOTE(review): this assumes ``predict`` is safe to call from
    several threads at once (true for typical scikit-learn estimators) —
    confirm for the deployed model.

    Returns:
        ``{"prediction": name}``, where ``name`` is an entry of ``CLASS_NAMES_QNR``.

    Raises:
        HTTPException: 500 if the model failed to load at startup, returned an
            unusable label, or prediction failed unexpectedly. 422 if building
            the features or running the model raised ``ValueError``, which is
            treated as bad input.
    """
    model = ml_models.get('questionnaire_classifier')
    if model is None:
        logger.error("Questionnaire model is not loaded.")
        raise HTTPException(status_code=500, detail="Questionnaire model could not be loaded")

    try:
        input_df = _build_questionnaire_frame(data)
        # Assumes the loaded object is a pipeline that does its own
        # preprocessing of these raw columns (per the original docstring).
        prediction = model.predict(input_df)
        predicted_class_name = _class_name_for(prediction[0])
    except ValueError as ve:
        logger.error("Value error during prediction: %s", ve)
        raise HTTPException(status_code=422, detail=f"Invalid input data: {ve}") from ve
    except Exception as e:
        logger.exception("Error during questionnaire prediction: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}") from e

    logger.info("Prediction successful: %s", predicted_class_name)
    return {"prediction": predicted_class_name}
|
|
| |
@app.get("/")
async def root():
    """Return a static banner identifying this API (handy as a liveness probe)."""
    banner = "Autism Questionnaire Classification API"
    return dict(message=banner)
|
|