codewithharsha commited on
Commit
b2eb69a
·
verified ·
1 Parent(s): a9cce6e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +85 -64
main.py CHANGED
@@ -1,18 +1,16 @@
1
- import io
2
  import logging
3
  from contextlib import asynccontextmanager
4
 
5
- import numpy as np
6
- import tensorflow as tf
7
- from fastapi import FastAPI, File, UploadFile, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
- from PIL import Image
10
- from skimage import transform
11
 
12
  # --- Configuration ---
13
- IMG_SIZE = 224
14
- IMG_MODEL_FILENAME = "vgg_model50.h5" # Make sure this matches your uploaded file
15
- CLASS_NAMES_IMG = ["Non-Autistic", "Autistic"] # Adjust if your VGG model output differs
16
 
17
  # Setup logging
18
  logging.basicConfig(level=logging.INFO)
@@ -24,13 +22,13 @@ ml_models = {} # Dictionary to hold loaded models
24
  @asynccontextmanager
25
  async def lifespan(app: FastAPI):
26
  # Load the ML model during startup
27
- logger.info(f"Attempting to load image model: {IMG_MODEL_FILENAME}")
28
  try:
29
- ml_models['image_classifier'] = tf.keras.models.load_model(IMG_MODEL_FILENAME)
30
- logger.info("Image model loaded successfully.")
31
  except Exception as e:
32
- logger.error(f"Error loading image model '{IMG_MODEL_FILENAME}': {e}")
33
- ml_models['image_classifier'] = None # Indicate loading failure
34
  yield
35
  # Clean up the ML models and release the resources
36
  ml_models.clear()
@@ -48,63 +46,86 @@ app.add_middleware(
48
  allow_headers=["*"], # Allows all headers
49
  )
50
 
51
- # --- Image Preprocessing ---
52
- def preprocess_image(image_bytes: bytes):
53
- """Loads image bytes, preprocesses, and prepares for VGG16."""
54
- try:
55
- img = Image.open(io.BytesIO(image_bytes)).convert('RGB') # Ensure 3 channels
56
- np_image = np.array(img).astype('float32') / 255.0 # Normalize
57
- np_image = transform.resize(np_image, (IMG_SIZE, IMG_SIZE, 3))
58
- np_image = np.expand_dims(np_image, axis=0) # Add batch dimension
59
- return np_image
60
- except Exception as e:
61
- logger.error(f"Error preprocessing image: {e}")
62
- raise HTTPException(status_code=400, detail=f"Error processing image file: {e}")
63
-
64
- # --- Prediction Endpoint (Image) ---
65
- @app.post("/predict/")
66
- async def predict_image(image: UploadFile = File(...)):
67
- """Receives an image file, preprocesses it, and returns the VGG16 prediction."""
68
- if ml_models.get('image_classifier') is None:
69
- logger.error("Image model is not loaded.")
70
- raise HTTPException(status_code=500, detail="Image model could not be loaded")
71
-
72
- logger.info(f"Received image file: {image.filename}")
73
- image_bytes = await image.read()
74
- if not image_bytes:
75
- raise HTTPException(status_code=400, detail="No image data received")
76
-
77
- # Preprocess the image
78
- processed_image = preprocess_image(image_bytes)
79
-
80
- # Make prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  try:
82
- predictions = ml_models['image_classifier'].predict(processed_image)
83
- predicted_class_index = np.argmax(predictions[0]) # Get index of highest probability
84
-
85
- # --- IMPORTANT ADJUSTMENT ---
86
- # Your VGG notebook used sparse_categorical_crossentropy and flow_from_directory
87
- # with classes=['non_autistic','autistic']. This means index 0 is 'non_autistic' and 1 is 'autistic'.
88
- # However, the final Dense layer had 95 units (output = Dense(95, activation='softmax')(class1)).
89
- # This seems like a mismatch. Assuming the binary classification was intended:
90
- if predicted_class_index < len(CLASS_NAMES_IMG):
91
- predicted_class_name = CLASS_NAMES_IMG[predicted_class_index]
92
- # For binary classification, maybe return probability too?
93
- # probability = float(predictions[0][predicted_class_index])
94
- else:
95
- # Handle unexpected index if the model output isn't binary as expected
96
- predicted_class_name = "Unknown Prediction"
97
- logger.warning(f"Predicted index {predicted_class_index} is out of bounds for CLASS_NAMES_IMG.")
98
 
 
 
 
 
 
 
 
 
 
 
99
  logger.info(f"Prediction successful: {predicted_class_name}")
 
100
  return {"prediction": predicted_class_name}
101
- # If you want probability: return {"prediction": predicted_class_name, "probability": probability}
102
 
 
 
 
 
103
  except Exception as e:
104
- logger.error(f"Error during prediction: {e}")
105
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
106
 
107
- # --- Root Endpoint (Optional - for health check/info) ---
108
  @app.get("/")
109
  async def root():
110
- return {"message": "Autism Image Classification API"}
 
 
1
  import logging
2
  from contextlib import asynccontextmanager
3
 
4
+ import joblib
5
+ import pandas as pd
6
+ from fastapi import FastAPI, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel, Field, field_validator # Updated import
9
+ from typing import Optional # Added for optional fields
10
 
11
  # --- Configuration ---
12
+ QNR_MODEL_FILENAME = "asd_classifier_model.pkl" # Make sure this matches your uploaded file
13
+ CLASS_NAMES_QNR = ["Non-Autistic", "Autistic"] # 0 maps to No ASD, 1 maps to ASD
 
14
 
15
  # Setup logging
16
  logging.basicConfig(level=logging.INFO)
 
22
  @asynccontextmanager
23
  async def lifespan(app: FastAPI):
24
  # Load the ML model during startup
25
+ logger.info(f"Attempting to load questionnaire model: {QNR_MODEL_FILENAME}")
26
  try:
27
+ ml_models['questionnaire_classifier'] = joblib.load(QNR_MODEL_FILENAME)
28
+ logger.info("Questionnaire model loaded successfully.")
29
  except Exception as e:
30
+ logger.error(f"Error loading questionnaire model '{QNR_MODEL_FILENAME}': {e}")
31
+ ml_models['questionnaire_classifier'] = None # Indicate loading failure
32
  yield
33
  # Clean up the ML models and release the resources
34
  ml_models.clear()
 
46
  allow_headers=["*"], # Allows all headers
47
  )
48
 
49
+ # --- Input Data Schema (Pydantic Model) ---
50
+ # Ensure field names match EXACTLY what the model expects
51
+ # Added Field defaults and type hints based on your notebook/previous code
52
+ class QuestionnaireData(BaseModel):
53
+ A1_Score: int = Field(..., ge=0, le=1)
54
+ A2_Score: int = Field(..., ge=0, le=1)
55
+ A3_Score: int = Field(..., ge=0, le=1)
56
+ A4_Score: int = Field(..., ge=0, le=1)
57
+ A5_Score: int = Field(..., ge=0, le=1)
58
+ A6_Score: int = Field(..., ge=0, le=1)
59
+ A7_Score: int = Field(..., ge=0, le=1)
60
+ A8_Score: int = Field(..., ge=0, le=1)
61
+ A9_Score: int = Field(..., ge=0, le=1)
62
+ A10_Score: int = Field(..., ge=0, le=1)
63
+ age: Optional[float] = Field(25.0, gt=0, le=120) # Made optional with default
64
+ gender: str = Field("m", pattern="^(m|f)$") # Allow only m or f
65
+ ethnicity: str = Field("White-European")
66
+ jaundice: str = Field("no", pattern="^(yes|no)$") # Allow only yes or no
67
+ # Corrected field name typo from 'contry_of_res' to 'country_of_res' if needed
68
+ contry_of_res: str = Field("United States") # Keep original if model expects typo
69
+ used_app_before: str = Field("no", pattern="^(yes|no)$") # Allow only yes or no
70
+ result: Optional[float] = Field(0.0) # This might be recalculated or ignored if it was the target
71
+ age_desc: str = Field("18 and more")
72
+ relation: str = Field("Self")
73
+
74
+ # Pydantic v2 validator
75
+ @field_validator('age')
76
+ def check_age(cls, v):
77
+ if v is None:
78
+ return 25.0 # Return default if None is explicitly passed
79
+ if not (0 < v <= 120):
80
+ raise ValueError('Age must be between 0 and 120')
81
+ return v
82
+
83
+ # You might add more validators for other fields if needed
84
+
85
+ # --- Prediction Endpoint (Questionnaire) ---
86
+ @app.post("/predict_questionnaire/")
87
+ async def predict_questionnaire(data: QuestionnaireData):
88
+ """Receives questionnaire data, preprocesses using loaded pipeline, returns prediction."""
89
+ if ml_models.get('questionnaire_classifier') is None:
90
+ logger.error("Questionnaire model is not loaded.")
91
+ raise HTTPException(status_code=500, detail="Questionnaire model could not be loaded")
92
+
93
  try:
94
+ # Convert Pydantic model to dictionary, then to DataFrame
95
+ input_data = data.model_dump() # Use model_dump() in Pydantic v2
96
+ logger.info(f"Received data: {input_data}")
97
+ input_df = pd.DataFrame([input_data])
98
+
99
+ # Recalculate 'result' based on A_Scores if needed by the model pipeline
100
+ # (Assuming 'result' column in training was sum of A*_Score)
101
+ a_scores = [f"A{i}_Score" for i in range(1, 11)]
102
+ input_df['result'] = input_df[a_scores].sum(axis=1)
103
+ logger.info(f"Recalculated result score: {input_df['result'].iloc[0]}")
 
 
 
 
 
 
104
 
105
+
106
+ # Predict using the loaded pipeline (handles preprocessing)
107
+ prediction = ml_models['questionnaire_classifier'].predict(input_df)
108
+ predicted_class_index = int(prediction[0])
109
+
110
+ # Get probability (optional)
111
+ # probability = ml_models['questionnaire_classifier'].predict_proba(input_df)[0]
112
+ # prob_asd = float(probability[1]) # Probability of class 1 (ASD)
113
+
114
+ predicted_class_name = CLASS_NAMES_QNR[predicted_class_index]
115
  logger.info(f"Prediction successful: {predicted_class_name}")
116
+
117
  return {"prediction": predicted_class_name}
118
+ # If returning probability: return {"prediction": predicted_class_name, "probability_asd": prob_asd}
119
 
120
+ except ValueError as ve:
121
+ # Catch potential validation errors not caught by Pydantic (e.g., during predict)
122
+ logger.error(f"Value error during prediction: {ve}")
123
+ raise HTTPException(status_code=422, detail=f"Invalid input data: {ve}")
124
  except Exception as e:
125
+ logger.error(f"Error during questionnaire prediction: {e}", exc_info=True)
126
  raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
127
 
128
+ # --- Root Endpoint (Optional) ---
129
  @app.get("/")
130
  async def root():
131
+ return {"message": "Autism Questionnaire Classification API"}