codewithharsha commited on
Commit
de2a195
·
verified ·
1 Parent(s): 9811f65

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +126 -0
main.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from fastapi import FastAPI, File, UploadFile, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from PIL import Image
5
+ import numpy as np
6
+ from skimage import transform
7
+ import io
8
+ import logging
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Configuration
15
+ PHOTO_SIZE = 224
16
+ MODEL_FILENAME = "vgg_model50.h5"
17
+ CLASS_NAMES = ["Non-Autistic", "Autistic"] # Ensure order matches training
18
+
19
+ # Load the model
20
+ model = None
21
+ try:
22
+ model = tf.keras.models.load_model(MODEL_FILENAME)
23
+ logger.info(f"Model '{MODEL_FILENAME}' loaded successfully.")
24
+ # Optional: Warm up the model
25
+ # dummy_input = np.zeros((1, PHOTO_SIZE, PHOTO_SIZE, 3), dtype=np.float32)
26
+ # model.predict(dummy_input)
27
+ # logger.info("Model warmed up.")
28
+ except Exception as e:
29
+ logger.error(f"Error loading model '{MODEL_FILENAME}': {e}", exc_info=True)
30
+ # Depending on deployment, you might want to raise an exception
31
+ # or handle this state so the API returns an error gracefully.
32
+
33
+ # Image preprocessing function
34
+ def preprocess_image(image_bytes: bytes):
35
+ """Loads image from bytes, resizes, normalizes, and adds batch dim."""
36
+ try:
37
+ img = Image.open(io.BytesIO(image_bytes)).convert('RGB') # Ensure 3 channels
38
+ logger.info(f"Image opened successfully. Original size: {img.size}")
39
+
40
+ np_image = np.array(img).astype('float32') / 255.0 # Normalize first
41
+ logger.info(f"Image converted to numpy array. Shape: {np_image.shape}")
42
+
43
+ # Check if resizing is needed and shape is valid before resize
44
+ if np_image.shape[:2] != (PHOTO_SIZE, PHOTO_SIZE):
45
+ np_image = transform.resize(np_image, (PHOTO_SIZE, PHOTO_SIZE, 3)) # Resize using skimage
46
+ logger.info(f"Image resized to: ({PHOTO_SIZE}, {PHOTO_SIZE}, 3)")
47
+ else:
48
+ logger.info("Image already correct size, skipping resize.")
49
+
50
+
51
+ # Ensure the shape is correct after potential resize
52
+ if np_image.shape != (PHOTO_SIZE, PHOTO_SIZE, 3):
53
+ raise ValueError(f"Unexpected image shape after processing: {np_image.shape}")
54
+
55
+ np_image = np.expand_dims(np_image, axis=0) # Add batch dimension
56
+ logger.info(f"Batch dimension added. Final shape: {np_image.shape}")
57
+ return np_image
58
+ except Exception as e:
59
+ logger.error(f"Error preprocessing image: {e}", exc_info=True)
60
+ raise # Re-raise the exception to be caught by the endpoint handler
61
+
62
+ # Create FastAPI app
63
+ app = FastAPI()
64
+
65
+ # Add CORS middleware to allow requests from your Arduino/browser
66
+ origins = [
67
+ "*", # Allow all origins - Be more restrictive in production!
68
+ # e.g., "http://your-arduino-ip", "null" for local file testing
69
+ ]
70
+
71
+ app.add_middleware(
72
+ CORSMiddleware,
73
+ allow_origins=origins,
74
+ allow_credentials=True,
75
+ allow_methods=["*"], # Allow all methods (GET, POST, etc.)
76
+ allow_headers=["*"], # Allow all headers
77
+ )
78
+
79
+ @app.get("/")
80
+ async def root():
81
+ return {"message": "Autism Classification API is running. POST image to /predict/"}
82
+
83
+ @app.post("/predict/")
84
+ async def predict_image(image: UploadFile = File(...)):
85
+ """Receives an image file, preprocesses it, and returns prediction."""
86
+ if not model:
87
+ logger.error("Model not loaded, cannot predict.")
88
+ raise HTTPException(status_code=500, detail="Model is not loaded")
89
+
90
+ if not image.content_type.startswith("image/"):
91
+ logger.warning(f"Invalid file type received: {image.content_type}")
92
+ raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")
93
+
94
+ try:
95
+ image_bytes = await image.read()
96
+ logger.info(f"Received image file: {image.filename}, size: {len(image_bytes)} bytes")
97
+
98
+ processed_image = preprocess_image(image_bytes)
99
+
100
+ # Make prediction
101
+ logger.info("Making prediction...")
102
+ prediction = model.predict(processed_image)
103
+ logger.info(f"Raw prediction output: {prediction}")
104
+
105
+ # Get the index of the highest probability
106
+ predicted_index = np.argmax(prediction, axis=1)[0]
107
+
108
+ # Get the corresponding class name
109
+ predicted_class = CLASS_NAMES[predicted_index]
110
+ logger.info(f"Predicted index: {predicted_index}, Predicted class: {predicted_class}")
111
+
112
+ return {"prediction": predicted_class}
113
+
114
+ except ValueError as ve: # Catch specific preprocessing errors
115
+ logger.error(f"ValueError during prediction: {ve}", exc_info=True)
116
+ raise HTTPException(status_code=400, detail=f"Image processing error: {ve}")
117
+ except Exception as e:
118
+ logger.error(f"An unexpected error occurred during prediction: {e}", exc_info=True)
119
+ raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
120
+
121
+ # If running directly (e.g., locally for testing), use uvicorn
122
+ # On Hugging Face Spaces, this part is usually not needed as Spaces handles the server start.
123
+ # if __name__ == "__main__":
124
+ # import uvicorn
125
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
126
+