Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| import numpy as np | |
| from skimage import transform | |
| import io | |
| import logging | |
# --- Model / input configuration -------------------------------------------
# Side length, in pixels, of the square RGB input the model is fed.
PHOTO_SIZE = 224
# Keras model file loaded at start-up.
MODEL_FILENAME = "vgg_model50.h5"
# Output index i of the model is reported as CLASS_NAMES[i];
# this order must match the label order used during training.
CLASS_NAMES = ["Non-Autistic", "Autistic"]

# --- Logging ----------------------------------------------------------------
# INFO level so per-request preprocessing/prediction traces are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Load the trained Keras model once, at import time.
# A failed load is logged (with traceback) rather than raised, leaving `model`
# as None so the module still imports; request handlers must check for None.
try:
    model = tf.keras.models.load_model(MODEL_FILENAME)
except Exception as e:
    model = None
    logger.exception(f"Error loading model '{MODEL_FILENAME}': {e}")
    # Deployment choice: re-raise here instead if the service should refuse to
    # start without a usable model.
else:
    logger.info(f"Model '{MODEL_FILENAME}' loaded successfully.")
    # Optional warm-up (currently disabled): call model.predict once on
    # np.zeros((1, PHOTO_SIZE, PHOTO_SIZE, 3), dtype=np.float32).
# Image preprocessing


def _decode_rgb(image_bytes: bytes) -> Image.Image:
    """Decode raw bytes into a 3-channel RGB PIL image.

    Raises:
        ValueError: if the bytes are not a decodable image. PIL reports this
            with OSError subclasses (e.g. UnidentifiedImageError, truncated
            files) or DecompressionBombError; they are re-raised as ValueError
            so callers that treat ValueError as a client error handle bad
            uploads uniformly.
    """
    try:
        with Image.open(io.BytesIO(image_bytes)) as raw:
            # convert() returns a new, fully loaded image, so it remains valid
            # after `raw` (and its underlying buffer) is closed.
            return raw.convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        raise ValueError(f"Cannot decode image: {err}") from err


def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Turn raw image bytes into a model-ready batch containing one image.

    Steps: decode to RGB, scale pixel values to [0, 1], resize to
    PHOTO_SIZE x PHOTO_SIZE when needed, then add a leading batch axis.

    Args:
        image_bytes: encoded image file contents (any format PIL can read).

    Returns:
        Array of shape (1, PHOTO_SIZE, PHOTO_SIZE, 3).

    Raises:
        ValueError: if the bytes cannot be decoded as an image, or the
            processed array does not have the expected shape.
    """
    try:
        img = _decode_rgb(image_bytes)  # always 3 channels
        logger.info(f"Image opened successfully. Original size: {img.size}")
        np_image = np.array(img).astype('float32') / 255.0  # Normalize first
        logger.info(f"Image converted to numpy array. Shape: {np_image.shape}")
        # Resize only when the spatial size differs from the model input size.
        if np_image.shape[:2] != (PHOTO_SIZE, PHOTO_SIZE):
            np_image = transform.resize(np_image, (PHOTO_SIZE, PHOTO_SIZE, 3))
            logger.info(f"Image resized to: ({PHOTO_SIZE}, {PHOTO_SIZE}, 3)")
        else:
            logger.info("Image already correct size, skipping resize.")
        # Defensive check that the array matches what the model expects.
        if np_image.shape != (PHOTO_SIZE, PHOTO_SIZE, 3):
            raise ValueError(f"Unexpected image shape after processing: {np_image.shape}")
        np_image = np.expand_dims(np_image, axis=0)  # Add batch dimension
        logger.info(f"Batch dimension added. Final shape: {np_image.shape}")
        return np_image
    except Exception as e:
        logger.error(f"Error preprocessing image: {e}", exc_info=True)
        raise  # Let the endpoint handler map the error to an HTTP status
# Create FastAPI app
app = FastAPI()

# CORS: let browser pages on any origin call this API.
origins = [
    "*",  # Allow all origins - restrict to known hosts in production,
    # e.g. "http://your-arduino-ip", or "null" for local file testing
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Wildcard origins must not be combined with credentials (FastAPI CORS
    # docs); with credentials on, Starlette reflects the caller's Origin and
    # lets any site make credentialed requests. The endpoints here read no
    # cookies or auth headers, so credentials are not needed.
    allow_credentials=False,
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)
@app.get("/")
async def root():
    """Health/info endpoint: confirms the API is up and where to send images."""
    return {"message": "Autism Classification API is running. POST image to /predict/"}
def _select_class_index(prediction) -> int:
    """Return the CLASS_NAMES index for the first row of a model.predict() result.

    A multi-unit output (e.g. one score per class) uses argmax. A single-unit
    output is treated as a score for class index 1, because argmax over one
    unit would always pick index 0.
    """
    scores = np.asarray(prediction)[0]
    if scores.size == 1:
        # NOTE(review): assumes a sigmoid head where score > 0.5 means
        # CLASS_NAMES[1]; confirm against how the model was trained.
        return int(scores.reshape(-1)[0] > 0.5)
    return int(np.argmax(scores))


@app.post("/predict/")
async def predict_image(image: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: multipart file upload whose declared content type must start
            with "image/".

    Returns:
        ``{"prediction": <one of CLASS_NAMES>}``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded/processed; 500 if the model is not loaded or prediction
            fails unexpectedly.
    """
    if model is None:
        logger.error("Model not loaded, cannot predict.")
        raise HTTPException(status_code=500, detail="Model is not loaded")
    # content_type is None when the multipart part carries no Content-Type header.
    content_type = image.content_type or ""
    if not content_type.startswith("image/"):
        logger.warning(f"Invalid file type received: {image.content_type}")
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
    try:
        image_bytes = await image.read()
        logger.info(f"Received image file: {image.filename}, size: {len(image_bytes)} bytes")
        processed_image = preprocess_image(image_bytes)
        # NOTE(review): model.predict is synchronous and runs on the event loop,
        # so concurrent requests are serialized here.
        logger.info("Making prediction...")
        prediction = model.predict(processed_image)
        logger.info(f"Raw prediction output: {prediction}")
        predicted_index = _select_class_index(prediction)
        predicted_class = CLASS_NAMES[predicted_index]
        logger.info(f"Predicted index: {predicted_index}, Predicted class: {predicted_class}")
        return {"prediction": predicted_class}
    except (ValueError, OSError) as ve:
        # ValueError: shape/processing problems; OSError: PIL's
        # UnidentifiedImageError or truncated-file errors for undecodable uploads.
        logger.error(f"Invalid image during prediction: {ve}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Image processing error: {ve}") from ve
    except Exception as e:
        logger.error(f"An unexpected error occurred during prediction: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e
| # If running directly (e.g., locally for testing), use uvicorn | |
| # On Hugging Face Spaces, this part is usually not needed as Spaces handles the server start. | |
| # if __name__ == "__main__": | |
| # import uvicorn | |
| # uvicorn.run(app, host="0.0.0.0", port=8000) |