# codewithharsha's picture
# Update main.py
# ad5978a verified
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import numpy as np
from skimage import transform
import io
import logging
# Inference settings shared by the preprocessing and prediction code below.
PHOTO_SIZE = 224  # Side length, in pixels, of the square image fed to the model
MODEL_FILENAME = "vgg_model50.h5"  # Keras HDF5 file loaded at import time
# Index -> label lookup; the order must match the label order used in training.
CLASS_NAMES = [
    "Non-Autistic",
    "Autistic",
]

# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the Keras model once at import time. A failure is logged rather than
# raised, leaving ``model`` as None so the module can still be imported.
try:
    model = tf.keras.models.load_model(MODEL_FILENAME)
except Exception as e:
    model = None
    logger.error(f"Error loading model '{MODEL_FILENAME}': {e}", exc_info=True)
    # Depending on the deployment, it may be preferable to re-raise here so
    # the process fails fast instead of serving with ``model`` left as None.
else:
    logger.info(f"Model '{MODEL_FILENAME}' loaded successfully.")
    # Optional warm-up (disabled): run one dummy prediction after loading.
    # dummy_input = np.zeros((1, PHOTO_SIZE, PHOTO_SIZE, 3), dtype=np.float32)
    # model.predict(dummy_input)
    # logger.info("Model warmed up.")
# Image preprocessing function
def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a normalized batch of one image.

    The bytes are decoded with Pillow, converted to RGB, scaled from
    0-255 to 0-1, resized to ``PHOTO_SIZE`` x ``PHOTO_SIZE`` when the
    decoded size differs, and given a leading batch axis.

    Args:
        image_bytes: Contents of an encoded image file.

    Returns:
        A float32 array of shape ``(1, PHOTO_SIZE, PHOTO_SIZE, 3)``.

    Raises:
        ValueError: If the array shape after processing is not
            ``(PHOTO_SIZE, PHOTO_SIZE, 3)``.
        Exception: Any error raised while decoding or resizing is logged
            and re-raised unchanged for the caller to handle.
    """
    try:
        # Close the lazily-loaded source image as soon as the RGB copy exists.
        with Image.open(io.BytesIO(image_bytes)) as source:
            img = source.convert('RGB')  # Ensure 3 channels
        logger.info(f"Image opened successfully. Original size: {img.size}")

        np_image = np.array(img).astype(np.float32) / 255.0  # Normalize first
        logger.info(f"Image converted to numpy array. Shape: {np_image.shape}")

        target_shape = (PHOTO_SIZE, PHOTO_SIZE, 3)
        if np_image.shape[:2] != (PHOTO_SIZE, PHOTO_SIZE):
            np_image = transform.resize(np_image, target_shape)  # Resize using skimage
            # Depending on the scikit-image version, resize may hand back
            # float64; cast so both branches return the same dtype.
            np_image = np_image.astype(np.float32, copy=False)
            logger.info(f"Image resized to: ({PHOTO_SIZE}, {PHOTO_SIZE}, 3)")
        else:
            logger.info("Image already correct size, skipping resize.")

        # Guard against an unexpected layout before handing data to the model.
        if np_image.shape != target_shape:
            raise ValueError(f"Unexpected image shape after processing: {np_image.shape}")

        np_image = np.expand_dims(np_image, axis=0)  # Add batch dimension
        logger.info(f"Batch dimension added. Final shape: {np_image.shape}")
        return np_image
    except Exception as e:
        logger.error(f"Error preprocessing image: {e}", exc_info=True)
        raise  # Re-raise the exception to be caught by the endpoint handler
# Create FastAPI app
app = FastAPI()

# Add CORS middleware to allow requests from your Arduino/browser
origins = [
    "*",  # Allow all origins - Be more restrictive in production!
    # e.g., "http://your-arduino-ip", "null" for local file testing
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Credentials (cookies / Authorization headers) must not be combined with
    # wildcard origins, methods or headers: the CORS rules require explicit
    # values whenever credentials are allowed. This API uses neither cookies
    # nor auth headers, so credentials are disabled. If they are ever needed,
    # list explicit origins, methods and headers instead of "*".
    allow_credentials=False,
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)
@app.get("/")
async def root():
    """Return a static status message pointing clients at ``/predict/``."""
    status_payload = {
        "message": "Autism Classification API is running. POST image to /predict/",
    }
    return status_payload
@app.post("/predict/")
async def predict_image(image: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The uploaded file from the multipart form field ``image``.

    Returns:
        A dict ``{"prediction": <class name>}`` where the class name is the
        ``CLASS_NAMES`` entry at the argmax of the model output.

    Raises:
        HTTPException: 500 if the model is not loaded or an unexpected error
            occurs; 400 if the upload is not declared as an image, cannot be
            decoded, or fails preprocessing.
    """
    # Identity check: only "never loaded" should short-circuit, regardless of
    # how the model object itself evaluates in a boolean context.
    if model is None:
        logger.error("Model not loaded, cannot predict.")
        raise HTTPException(status_code=500, detail="Model is not loaded")

    # content_type is None when the client sends no Content-Type for the part;
    # treat that as "not an image" instead of crashing on None.startswith.
    declared_type = image.content_type or ""
    if not declared_type.startswith("image/"):
        logger.warning(f"Invalid file type received: {image.content_type}")
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")

    try:
        image_bytes = await image.read()
        logger.info(f"Received image file: {image.filename}, size: {len(image_bytes)} bytes")

        try:
            processed_image = preprocess_image(image_bytes)
        except OSError as decode_error:
            # Pillow reports undecodable or truncated data with OSError
            # subclasses (e.g. UnidentifiedImageError). That is a client
            # error, so route it to the 400 handler below.
            raise ValueError(f"cannot decode image: {decode_error}") from decode_error

        # Make prediction
        logger.info("Making prediction...")
        prediction = model.predict(processed_image)
        logger.info(f"Raw prediction output: {prediction}")

        # Get the index of the highest probability.
        # NOTE(review): argmax over axis 1 is always 0 if the model has a
        # single output unit - confirm the model emits one score per class.
        predicted_index = int(np.argmax(prediction, axis=1)[0])

        # Get the corresponding class name
        predicted_class = CLASS_NAMES[predicted_index]
        logger.info(f"Predicted index: {predicted_index}, Predicted class: {predicted_class}")
        return {"prediction": predicted_class}
    except ValueError as ve:  # Catch specific preprocessing errors
        logger.error(f"ValueError during prediction: {ve}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Image processing error: {ve}") from ve
    except Exception as e:
        logger.error(f"An unexpected error occurred during prediction: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e
# If running directly (e.g., locally for testing), use uvicorn
# On Hugging Face Spaces, this part is usually not needed as Spaces handles the server start.
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)