Update main.py
Browse files
main.py
CHANGED
|
@@ -1,18 +1,16 @@
|
|
| 1 |
-
import io
|
| 2 |
import logging
|
| 3 |
from contextlib import asynccontextmanager
|
| 4 |
|
| 5 |
-
import
|
| 6 |
-
import
|
| 7 |
-
from fastapi import FastAPI,
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
|
| 12 |
# --- Configuration ---
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
CLASS_NAMES_IMG = ["Non-Autistic", "Autistic"] # Adjust if your VGG model output differs
|
| 16 |
|
| 17 |
# Setup logging
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -24,13 +22,13 @@ ml_models = {} # Dictionary to hold loaded models
|
|
| 24 |
@asynccontextmanager
|
| 25 |
async def lifespan(app: FastAPI):
|
| 26 |
# Load the ML model during startup
|
| 27 |
-
logger.info(f"Attempting to load
|
| 28 |
try:
|
| 29 |
-
ml_models['
|
| 30 |
-
logger.info("
|
| 31 |
except Exception as e:
|
| 32 |
-
logger.error(f"Error loading
|
| 33 |
-
ml_models['
|
| 34 |
yield
|
| 35 |
# Clean up the ML models and release the resources
|
| 36 |
ml_models.clear()
|
|
@@ -48,63 +46,86 @@ app.add_middleware(
|
|
| 48 |
allow_headers=["*"], # Allows all headers
|
| 49 |
)
|
| 50 |
|
| 51 |
-
# ---
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
try:
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
#
|
| 88 |
-
#
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
# For binary classification, maybe return probability too?
|
| 93 |
-
# probability = float(predictions[0][predicted_class_index])
|
| 94 |
-
else:
|
| 95 |
-
# Handle unexpected index if the model output isn't binary as expected
|
| 96 |
-
predicted_class_name = "Unknown Prediction"
|
| 97 |
-
logger.warning(f"Predicted index {predicted_class_index} is out of bounds for CLASS_NAMES_IMG.")
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
logger.info(f"Prediction successful: {predicted_class_name}")
|
|
|
|
| 100 |
return {"prediction": predicted_class_name}
|
| 101 |
-
# If
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
except Exception as e:
|
| 104 |
-
logger.error(f"Error during prediction: {e}")
|
| 105 |
raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
|
| 106 |
|
| 107 |
-
# --- Root Endpoint (Optional
|
| 108 |
@app.get("/")
|
| 109 |
async def root():
|
| 110 |
-
return {"message": "Autism
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
from contextlib import asynccontextmanager
|
| 3 |
|
| 4 |
+
import joblib
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from fastapi import FastAPI, HTTPException
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from pydantic import BaseModel, Field, field_validator # Updated import
|
| 9 |
+
from typing import Optional # Added for optional fields
|
| 10 |
|
| 11 |
# --- Configuration ---
|
| 12 |
+
QNR_MODEL_FILENAME = "asd_classifier_model.pkl" # Make sure this matches your uploaded file
|
| 13 |
+
CLASS_NAMES_QNR = ["Non-Autistic", "Autistic"] # 0 maps to No ASD, 1 maps to ASD
|
|
|
|
| 14 |
|
| 15 |
# Setup logging
|
| 16 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 22 |
@asynccontextmanager
|
| 23 |
async def lifespan(app: FastAPI):
|
| 24 |
# Load the ML model during startup
|
| 25 |
+
logger.info(f"Attempting to load questionnaire model: {QNR_MODEL_FILENAME}")
|
| 26 |
try:
|
| 27 |
+
ml_models['questionnaire_classifier'] = joblib.load(QNR_MODEL_FILENAME)
|
| 28 |
+
logger.info("Questionnaire model loaded successfully.")
|
| 29 |
except Exception as e:
|
| 30 |
+
logger.error(f"Error loading questionnaire model '{QNR_MODEL_FILENAME}': {e}")
|
| 31 |
+
ml_models['questionnaire_classifier'] = None # Indicate loading failure
|
| 32 |
yield
|
| 33 |
# Clean up the ML models and release the resources
|
| 34 |
ml_models.clear()
|
|
|
|
| 46 |
allow_headers=["*"], # Allows all headers
|
| 47 |
)
|
| 48 |
|
| 49 |
+
# --- Input Data Schema (Pydantic Model) ---
|
| 50 |
+
# Ensure field names match EXACTLY what the model expects
|
| 51 |
+
# Added Field defaults and type hints based on your notebook/previous code
|
| 52 |
+
class QuestionnaireData(BaseModel):
|
| 53 |
+
A1_Score: int = Field(..., ge=0, le=1)
|
| 54 |
+
A2_Score: int = Field(..., ge=0, le=1)
|
| 55 |
+
A3_Score: int = Field(..., ge=0, le=1)
|
| 56 |
+
A4_Score: int = Field(..., ge=0, le=1)
|
| 57 |
+
A5_Score: int = Field(..., ge=0, le=1)
|
| 58 |
+
A6_Score: int = Field(..., ge=0, le=1)
|
| 59 |
+
A7_Score: int = Field(..., ge=0, le=1)
|
| 60 |
+
A8_Score: int = Field(..., ge=0, le=1)
|
| 61 |
+
A9_Score: int = Field(..., ge=0, le=1)
|
| 62 |
+
A10_Score: int = Field(..., ge=0, le=1)
|
| 63 |
+
age: Optional[float] = Field(25.0, gt=0, le=120) # Made optional with default
|
| 64 |
+
gender: str = Field("m", pattern="^(m|f)$") # Allow only m or f
|
| 65 |
+
ethnicity: str = Field("White-European")
|
| 66 |
+
jaundice: str = Field("no", pattern="^(yes|no)$") # Allow only yes or no
|
| 67 |
+
# Corrected field name typo from 'contry_of_res' to 'country_of_res' if needed
|
| 68 |
+
contry_of_res: str = Field("United States") # Keep original if model expects typo
|
| 69 |
+
used_app_before: str = Field("no", pattern="^(yes|no)$") # Allow only yes or no
|
| 70 |
+
result: Optional[float] = Field(0.0) # This might be recalculated or ignored if it was the target
|
| 71 |
+
age_desc: str = Field("18 and more")
|
| 72 |
+
relation: str = Field("Self")
|
| 73 |
+
|
| 74 |
+
# Pydantic v2 validator
|
| 75 |
+
@field_validator('age')
|
| 76 |
+
def check_age(cls, v):
|
| 77 |
+
if v is None:
|
| 78 |
+
return 25.0 # Return default if None is explicitly passed
|
| 79 |
+
if not (0 < v <= 120):
|
| 80 |
+
raise ValueError('Age must be between 0 and 120')
|
| 81 |
+
return v
|
| 82 |
+
|
| 83 |
+
# You might add more validators for other fields if needed
|
| 84 |
+
|
| 85 |
+
# --- Prediction Endpoint (Questionnaire) ---
|
| 86 |
+
@app.post("/predict_questionnaire/")
|
| 87 |
+
async def predict_questionnaire(data: QuestionnaireData):
|
| 88 |
+
"""Receives questionnaire data, preprocesses using loaded pipeline, returns prediction."""
|
| 89 |
+
if ml_models.get('questionnaire_classifier') is None:
|
| 90 |
+
logger.error("Questionnaire model is not loaded.")
|
| 91 |
+
raise HTTPException(status_code=500, detail="Questionnaire model could not be loaded")
|
| 92 |
+
|
| 93 |
try:
|
| 94 |
+
# Convert Pydantic model to dictionary, then to DataFrame
|
| 95 |
+
input_data = data.model_dump() # Use model_dump() in Pydantic v2
|
| 96 |
+
logger.info(f"Received data: {input_data}")
|
| 97 |
+
input_df = pd.DataFrame([input_data])
|
| 98 |
+
|
| 99 |
+
# Recalculate 'result' based on A_Scores if needed by the model pipeline
|
| 100 |
+
# (Assuming 'result' column in training was sum of A*_Score)
|
| 101 |
+
a_scores = [f"A{i}_Score" for i in range(1, 11)]
|
| 102 |
+
input_df['result'] = input_df[a_scores].sum(axis=1)
|
| 103 |
+
logger.info(f"Recalculated result score: {input_df['result'].iloc[0]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
+
|
| 106 |
+
# Predict using the loaded pipeline (handles preprocessing)
|
| 107 |
+
prediction = ml_models['questionnaire_classifier'].predict(input_df)
|
| 108 |
+
predicted_class_index = int(prediction[0])
|
| 109 |
+
|
| 110 |
+
# Get probability (optional)
|
| 111 |
+
# probability = ml_models['questionnaire_classifier'].predict_proba(input_df)[0]
|
| 112 |
+
# prob_asd = float(probability[1]) # Probability of class 1 (ASD)
|
| 113 |
+
|
| 114 |
+
predicted_class_name = CLASS_NAMES_QNR[predicted_class_index]
|
| 115 |
logger.info(f"Prediction successful: {predicted_class_name}")
|
| 116 |
+
|
| 117 |
return {"prediction": predicted_class_name}
|
| 118 |
+
# If returning probability: return {"prediction": predicted_class_name, "probability_asd": prob_asd}
|
| 119 |
|
| 120 |
+
except ValueError as ve:
|
| 121 |
+
# Catch potential validation errors not caught by Pydantic (e.g., during predict)
|
| 122 |
+
logger.error(f"Value error during prediction: {ve}")
|
| 123 |
+
raise HTTPException(status_code=422, detail=f"Invalid input data: {ve}")
|
| 124 |
except Exception as e:
|
| 125 |
+
logger.error(f"Error during questionnaire prediction: {e}", exc_info=True)
|
| 126 |
raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
|
| 127 |
|
| 128 |
+
# --- Root Endpoint (Optional) ---
|
| 129 |
@app.get("/")
|
| 130 |
async def root():
|
| 131 |
+
return {"message": "Autism Questionnaire Classification API"}
|