codewithharsha commited on
Commit
621f87f
·
verified ·
1 Parent(s): 42d25a6

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +110 -0
main.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from PIL import Image
10
+ from skimage import transform
11
+
12
+ # --- Configuration ---
13
+ IMG_SIZE = 224
14
+ IMG_MODEL_FILENAME = "vgg_model50.h5" # Make sure this matches your uploaded file
15
+ CLASS_NAMES_IMG = ["Non-Autistic", "Autistic"] # Adjust if your VGG model output differs
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ ml_models = {} # Dictionary to hold loaded models
22
+
23
+ # --- Model Loading Logic ---
24
+ @asynccontextmanager
25
+ async def lifespan(app: FastAPI):
26
+ # Load the ML model during startup
27
+ logger.info(f"Attempting to load image model: {IMG_MODEL_FILENAME}")
28
+ try:
29
+ ml_models['image_classifier'] = tf.keras.models.load_model(IMG_MODEL_FILENAME)
30
+ logger.info("Image model loaded successfully.")
31
+ except Exception as e:
32
+ logger.error(f"Error loading image model '{IMG_MODEL_FILENAME}': {e}")
33
+ ml_models['image_classifier'] = None # Indicate loading failure
34
+ yield
35
+ # Clean up the ML models and release the resources
36
+ ml_models.clear()
37
+ logger.info("Cleaned up models.")
38
+
39
+ # --- FastAPI App ---
40
+ app = FastAPI(lifespan=lifespan)
41
+
42
+ # --- CORS Middleware ---
43
+ app.add_middleware(
44
+ CORSMiddleware,
45
+ allow_origins=["*"], # Allows all origins
46
+ allow_credentials=True,
47
+ allow_methods=["*"], # Allows all methods
48
+ allow_headers=["*"], # Allows all headers
49
+ )
50
+
51
+ # --- Image Preprocessing ---
52
+ def preprocess_image(image_bytes: bytes):
53
+ """Loads image bytes, preprocesses, and prepares for VGG16."""
54
+ try:
55
+ img = Image.open(io.BytesIO(image_bytes)).convert('RGB') # Ensure 3 channels
56
+ np_image = np.array(img).astype('float32') / 255.0 # Normalize
57
+ np_image = transform.resize(np_image, (IMG_SIZE, IMG_SIZE, 3))
58
+ np_image = np.expand_dims(np_image, axis=0) # Add batch dimension
59
+ return np_image
60
+ except Exception as e:
61
+ logger.error(f"Error preprocessing image: {e}")
62
+ raise HTTPException(status_code=400, detail=f"Error processing image file: {e}")
63
+
64
+ # --- Prediction Endpoint (Image) ---
65
+ @app.post("/predict/")
66
+ async def predict_image(image: UploadFile = File(...)):
67
+ """Receives an image file, preprocesses it, and returns the VGG16 prediction."""
68
+ if ml_models.get('image_classifier') is None:
69
+ logger.error("Image model is not loaded.")
70
+ raise HTTPException(status_code=500, detail="Image model could not be loaded")
71
+
72
+ logger.info(f"Received image file: {image.filename}")
73
+ image_bytes = await image.read()
74
+ if not image_bytes:
75
+ raise HTTPException(status_code=400, detail="No image data received")
76
+
77
+ # Preprocess the image
78
+ processed_image = preprocess_image(image_bytes)
79
+
80
+ # Make prediction
81
+ try:
82
+ predictions = ml_models['image_classifier'].predict(processed_image)
83
+ predicted_class_index = np.argmax(predictions[0]) # Get index of highest probability
84
+
85
+ # --- IMPORTANT ADJUSTMENT ---
86
+ # Your VGG notebook used sparse_categorical_crossentropy and flow_from_directory
87
+ # with classes=['non_autistic','autistic']. This means index 0 is 'non_autistic' and 1 is 'autistic'.
88
+ # However, the final Dense layer had 95 units (output = Dense(95, activation='softmax')(class1)).
89
+ # This seems like a mismatch. Assuming the binary classification was intended:
90
+ if predicted_class_index < len(CLASS_NAMES_IMG):
91
+ predicted_class_name = CLASS_NAMES_IMG[predicted_class_index]
92
+ # For binary classification, maybe return probability too?
93
+ # probability = float(predictions[0][predicted_class_index])
94
+ else:
95
+ # Handle unexpected index if the model output isn't binary as expected
96
+ predicted_class_name = "Unknown Prediction"
97
+ logger.warning(f"Predicted index {predicted_class_index} is out of bounds for CLASS_NAMES_IMG.")
98
+
99
+ logger.info(f"Prediction successful: {predicted_class_name}")
100
+ return {"prediction": predicted_class_name}
101
+ # If you want probability: return {"prediction": predicted_class_name, "probability": probability}
102
+
103
+ except Exception as e:
104
+ logger.error(f"Error during prediction: {e}")
105
+ raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
106
+
107
+ # --- Root Endpoint (Optional - for health check/info) ---
108
+ @app.get("/")
109
+ async def root():
110
+ return {"message": "Autism Image Classification API"}