Penthes committed on
Commit
f8c57f5
·
verified ·
1 Parent(s): 5d01458

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -289
app.py CHANGED
@@ -1,290 +1,307 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- import tensorflow as tf
5
- import numpy as np
6
- from PIL import Image
7
- import io
8
- import logging
9
- import uvicorn
10
- import os
11
-
12
- # Set up logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
-
16
- # Initialize FastAPI app
17
- app = FastAPI(
18
- title="Waste Classification API",
19
- description="API for classifying waste into categories: Glass, Metal, Organic, Paper, Plastic",
20
- version="1.0.0",
21
- docs_url="/", # Swagger UI at root for easy access
22
- )
23
-
24
- # Add CORS middleware for web access
25
- app.add_middleware(
26
- CORSMiddleware,
27
- allow_origins=["*"],
28
- allow_credentials=True,
29
- allow_methods=["*"],
30
- allow_headers=["*"],
31
- )
32
-
33
- # Global variables - match your training exactly
34
- model = None
35
- # IMPORTANT: Your class order from training (alphabetical from image_dataset_from_directory)
36
- class_labels = ["glass", "metal", "organic", "paper", "plastic"]
37
-
38
- def load_model():
39
- """Load the trained TensorFlow/Keras model"""
40
- try:
41
- # Try loading different formats in order of preference
42
- model_files = [
43
- 'waste_model.keras', # Keras format (recommended)
44
- 'waste_model.h5', # H5 format
45
- 'best_model.keras' # Checkpoint from training
46
- ]
47
-
48
- model = None
49
- for model_file in model_files:
50
- if os.path.exists(model_file):
51
- try:
52
- model = tf.keras.models.load_model(model_file)
53
- logger.info(f"Model loaded successfully from {model_file}")
54
- break
55
- except Exception as e:
56
- logger.warning(f"Failed to load {model_file}: {e}")
57
- continue
58
-
59
- if model is None:
60
- logger.error("No model file found. Creating dummy model for testing.")
61
- # Create dummy model with same architecture for testing
62
- model = tf.keras.Sequential([
63
- tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
64
- tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet'),
65
- tf.keras.layers.GlobalAveragePooling2D(),
66
- tf.keras.layers.Dense(128, activation='relu'),
67
- tf.keras.layers.Dropout(0.2),
68
- tf.keras.layers.Dense(5, activation='softmax')
69
- ])
70
- logger.warning("Using dummy model - predictions will be random!")
71
-
72
- return model
73
-
74
- except Exception as e:
75
- logger.error(f"Critical error loading model: {e}")
76
- raise Exception(f"Model loading failed: {e}")
77
-
78
- def preprocess_image(image_data):
79
- """
80
- Preprocess image to match your training pipeline
81
- """
82
- try:
83
- # Load image
84
- image = Image.open(io.BytesIO(image_data)).convert('RGB')
85
-
86
- # Resize to match training (224, 224)
87
- image = image.resize((224, 224), Image.BICUBIC) # Match your training interpolation
88
-
89
- # Convert to numpy array
90
- image_array = np.array(image, dtype=np.float32)
91
-
92
- # Add batch dimension
93
- image_array = np.expand_dims(image_array, axis=0)
94
-
95
- # NOTE: Your model has Rescaling(1./255) as first layer, so no need to normalize here
96
- # The model will handle normalization internally
97
-
98
- return image_array
99
-
100
- except Exception as e:
101
- logger.error(f"Image preprocessing error: {e}")
102
- raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}")
103
-
104
- @app.on_event("startup")
105
- async def startup_event():
106
- """Load model on startup"""
107
- global model
108
- try:
109
- model = load_model()
110
- logger.info("API startup complete")
111
-
112
- # Test model with dummy input
113
- dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
114
- _ = model.predict(dummy_input, verbose=0)
115
- logger.info("Model test prediction successful")
116
-
117
- except Exception as e:
118
- logger.error(f"Startup failed: {e}")
119
- raise
120
-
121
- @app.get("/health")
122
- async def health_check():
123
- """Health check endpoint"""
124
- try:
125
- # Quick model test
126
- dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
127
- prediction = model.predict(dummy_input, verbose=0)
128
- model_working = prediction is not None
129
-
130
- return {
131
- "status": "healthy",
132
- "model_loaded": model is not None,
133
- "model_working": model_working,
134
- "classes": class_labels,
135
- "input_shape": "(224, 224, 3)",
136
- "model_type": "TensorFlow/Keras MobileNetV2"
137
- }
138
- except Exception as e:
139
- return {
140
- "status": "unhealthy",
141
- "error": str(e),
142
- "model_loaded": model is not None,
143
- "classes": class_labels
144
- }
145
-
146
- @app.post("/classify")
147
- async def classify_image(file: UploadFile = File(...)):
148
- """
149
- Main classification endpoint for ESP32
150
-
151
- Expected usage:
152
- curl -X POST -F "file=@image.jpg" https://your-space-url.hf.space/classify
153
-
154
- Returns:
155
- JSON: {"label": "plastic"} or {"error": "message"}
156
- """
157
- try:
158
- # Validate file type
159
- if not file.content_type or not file.content_type.startswith('image/'):
160
- logger.warning(f"Invalid file type: {file.content_type}")
161
- raise HTTPException(status_code=400, detail="File must be an image")
162
-
163
- # Read image data
164
- image_data = await file.read()
165
- if len(image_data) == 0:
166
- raise HTTPException(status_code=400, detail="Empty image file")
167
-
168
- logger.info(f"Processing image: {file.filename}, size: {len(image_data)} bytes")
169
-
170
- # Preprocess image
171
- processed_image = preprocess_image(image_data)
172
-
173
- # Make prediction
174
- predictions = model.predict(processed_image, verbose=0)
175
- predicted_class_index = np.argmax(predictions[0])
176
- predicted_class = class_labels[predicted_class_index]
177
- confidence = float(predictions[0][predicted_class_index])
178
-
179
- logger.info(f"Prediction: {predicted_class} (confidence: {confidence:.3f})")
180
-
181
- # Return simple response for ESP32 - match your ESP32 expectation exactly
182
- return {"label": predicted_class.capitalize()} # Capitalize to match your ESP32 labels
183
-
184
- except HTTPException:
185
- raise
186
- except Exception as e:
187
- logger.error(f"Classification error: {str(e)}")
188
- return JSONResponse(
189
- status_code=500,
190
- content={"error": f"Classification failed: {str(e)}"}
191
- )
192
-
193
- @app.post("/classify/detailed")
194
- async def classify_detailed(file: UploadFile = File(...)):
195
- """
196
- Detailed classification endpoint with confidence scores
197
- """
198
- try:
199
- # Validate file type
200
- if not file.content_type or not file.content_type.startswith('image/'):
201
- raise HTTPException(status_code=400, detail="File must be an image")
202
-
203
- # Read and process image
204
- image_data = await file.read()
205
- processed_image = preprocess_image(image_data)
206
-
207
- # Make prediction with full details
208
- predictions = model.predict(processed_image, verbose=0)
209
- predicted_class_index = np.argmax(predictions[0])
210
- predicted_class = class_labels[predicted_class_index]
211
- confidence = float(predictions[0][predicted_class_index])
212
-
213
- # Get all class probabilities
214
- all_probs = {
215
- class_labels[i].capitalize(): round(float(predictions[0][i]) * 100, 2)
216
- for i in range(len(class_labels))
217
- }
218
-
219
- return {
220
- "label": predicted_class.capitalize(),
221
- "confidence": round(confidence * 100, 2),
222
- "all_probabilities": all_probs,
223
- "model_info": {
224
- "architecture": "MobileNetV2",
225
- "input_size": "224x224",
226
- "classes": len(class_labels)
227
- },
228
- "status": "success"
229
- }
230
-
231
- except Exception as e:
232
- logger.error(f"Detailed classification error: {str(e)}")
233
- raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")
234
-
235
- @app.get("/info")
236
- async def get_info():
237
- """API information endpoint"""
238
- return {
239
- "api_name": "Waste Classification API",
240
- "version": "1.0.0",
241
- "model": {
242
- "architecture": "MobileNetV2 + Custom Head",
243
- "framework": "TensorFlow/Keras",
244
- "input_size": "224x224x3",
245
- "preprocessing": "RGB, Resize, Rescaling (internal)"
246
- },
247
- "classes": [label.capitalize() for label in class_labels],
248
- "endpoints": {
249
- "/classify": "POST - Main classification endpoint (returns simple label)",
250
- "/classify/detailed": "POST - Detailed classification with confidence",
251
- "/health": "GET - Health check",
252
- "/info": "GET - API information"
253
- },
254
- "usage": {
255
- "esp32": "POST image to /classify endpoint",
256
- "curl_example": "curl -X POST -F 'file=@image.jpg' https://your-space-url.hf.space/classify"
257
- }
258
- }
259
-
260
- @app.post("/test")
261
- async def test_with_dummy():
262
- """Test endpoint with dummy data for debugging"""
263
- try:
264
- # Create dummy image (random noise)
265
- dummy_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
266
- dummy_input = np.expand_dims(dummy_image.astype(np.float32), axis=0)
267
-
268
- # Make prediction
269
- predictions = model.predict(dummy_input, verbose=0)
270
- predicted_class_index = np.argmax(predictions[0])
271
- predicted_class = class_labels[predicted_class_index]
272
-
273
- return {
274
- "test_status": "success",
275
- "predicted_class": predicted_class.capitalize(),
276
- "confidence": float(predictions[0][predicted_class_index]),
277
- "all_predictions": [float(p) for p in predictions[0]]
278
- }
279
- except Exception as e:
280
- return {"test_status": "failed", "error": str(e)}
281
-
282
- if __name__ == "__main__":
283
- # Run the FastAPI app
284
- port = int(os.environ.get("PORT", 7860))
285
- uvicorn.run(
286
- app,
287
- host="0.0.0.0",
288
- port=port,
289
- log_level="info"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  )
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import tensorflow as tf
5
+ import numpy as np
6
+ from PIL import Image
7
+ import io
8
+ import logging
9
+ import uvicorn
10
+ import os
11
+
12
+ # Set up logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Initialize FastAPI app
17
+ app = FastAPI(
18
+ title="Waste Classification API",
19
+ description="API for classifying waste into categories: Glass, Metal, Organic, Paper, Plastic",
20
+ version="1.0.0",
21
+ docs_url="/", # Swagger UI at root for easy access
22
+ )
23
+
24
+ # Add CORS middleware for web access
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+ # Global variables - match your training exactly
34
+ model = None
35
+ # IMPORTANT: Your class order from training (alphabetical from image_dataset_from_directory)
36
+ class_labels = ["glass", "metal", "organic", "paper", "plastic"]
37
+
38
+ def load_model():
39
+ """Load the trained TensorFlow/Keras model"""
40
+ try:
41
+ # Try loading different formats in order of preference
42
+ model_files = [
43
+ 'waste_model.keras', # Keras format (recommended)
44
+ 'waste_model.h5', # H5 format
45
+ 'best_model.keras' # Checkpoint from training
46
+ ]
47
+
48
+ model = None
49
+ for model_file in model_files:
50
+ if os.path.exists(model_file):
51
+ try:
52
+ model = tf.keras.models.load_model(model_file)
53
+ logger.info(f"Model loaded successfully from {model_file}")
54
+ break
55
+ except Exception as e:
56
+ logger.warning(f"Failed to load {model_file}: {e}")
57
+ continue
58
+
59
+ if model is None:
60
+ logger.error("No model file found. Creating dummy model for testing.")
61
+ # Create dummy model with same architecture for testing
62
+ model = tf.keras.Sequential([
63
+ tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
64
+ tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet'),
65
+ tf.keras.layers.GlobalAveragePooling2D(),
66
+ tf.keras.layers.Dense(128, activation='relu'),
67
+ tf.keras.layers.Dropout(0.2),
68
+ tf.keras.layers.Dense(5, activation='softmax')
69
+ ])
70
+ logger.warning("Using dummy model - predictions will be random!")
71
+
72
+ return model
73
+
74
+ except Exception as e:
75
+ logger.error(f"Critical error loading model: {e}")
76
+ raise Exception(f"Model loading failed: {e}")
77
+
78
+ def preprocess_image(image_data):
79
+ """
80
+ Preprocess image to match training pipeline:
81
+ - Crop ROI from ESP32 frame (400x296)
82
+ - Resize to 224x224
83
+ - Convert to numpy array, add batch dim
84
+ """
85
+ try:
86
+ # Load image
87
+ image = Image.open(io.BytesIO(image_data)).convert('RGB')
88
+ w, h = image.size # (width=400, height=296 expected)
89
+
90
+ # Define ROI (x1:x2, y1:y2), inclusive indices
91
+ X1, X2 = 90, 280
92
+ Y1, Y2 = 5, 205
93
+
94
+ # Clamp to actual image bounds
95
+ X1 = max(0, min(w - 1, X1))
96
+ X2 = max(0, min(w - 1, X2))
97
+ Y1 = max(0, min(h - 1, Y1))
98
+ Y2 = max(0, min(h - 1, Y2))
99
+
100
+ # Pillow crop box is (left, upper, right, lower) with right/lower EXCLUSIVE
101
+ crop_box = (X1, Y1, X2 + 1, Y2 + 1)
102
+ image = image.crop(crop_box)
103
+
104
+ # Resize to model input size
105
+ image = image.resize((224, 224), Image.BICUBIC)
106
+
107
+ # Convert to numpy array
108
+ image_array = np.array(image, dtype=np.float32)
109
+
110
+ # Add batch dimension
111
+ image_array = np.expand_dims(image_array, axis=0)
112
+
113
+ # Model has Rescaling(1./255) layer, so no manual normalization
114
+ return image_array
115
+
116
+ except Exception as e:
117
+ logger.error(f"Image preprocessing error: {e}")
118
+ raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}")
119
+
120
+
121
+ @app.on_event("startup")
122
+ async def startup_event():
123
+ """Load model on startup"""
124
+ global model
125
+ try:
126
+ model = load_model()
127
+ logger.info("API startup complete")
128
+
129
+ # Test model with dummy input
130
+ dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
131
+ _ = model.predict(dummy_input, verbose=0)
132
+ logger.info("Model test prediction successful")
133
+
134
+ except Exception as e:
135
+ logger.error(f"Startup failed: {e}")
136
+ raise
137
+
138
+ @app.get("/health")
139
+ async def health_check():
140
+ """Health check endpoint"""
141
+ try:
142
+ # Quick model test
143
+ dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
144
+ prediction = model.predict(dummy_input, verbose=0)
145
+ model_working = prediction is not None
146
+
147
+ return {
148
+ "status": "healthy",
149
+ "model_loaded": model is not None,
150
+ "model_working": model_working,
151
+ "classes": class_labels,
152
+ "input_shape": "(224, 224, 3)",
153
+ "model_type": "TensorFlow/Keras MobileNetV2"
154
+ }
155
+ except Exception as e:
156
+ return {
157
+ "status": "unhealthy",
158
+ "error": str(e),
159
+ "model_loaded": model is not None,
160
+ "classes": class_labels
161
+ }
162
+
163
+ @app.post("/classify")
164
+ async def classify_image(file: UploadFile = File(...)):
165
+ """
166
+ Main classification endpoint for ESP32
167
+
168
+ Expected usage:
169
+ curl -X POST -F "file=@image.jpg" https://your-space-url.hf.space/classify
170
+
171
+ Returns:
172
+ JSON: {"label": "plastic"} or {"error": "message"}
173
+ """
174
+ try:
175
+ # Validate file type
176
+ if not file.content_type or not file.content_type.startswith('image/'):
177
+ logger.warning(f"Invalid file type: {file.content_type}")
178
+ raise HTTPException(status_code=400, detail="File must be an image")
179
+
180
+ # Read image data
181
+ image_data = await file.read()
182
+ if len(image_data) == 0:
183
+ raise HTTPException(status_code=400, detail="Empty image file")
184
+
185
+ logger.info(f"Processing image: {file.filename}, size: {len(image_data)} bytes")
186
+
187
+ # Preprocess image
188
+ processed_image = preprocess_image(image_data)
189
+
190
+ # Make prediction
191
+ predictions = model.predict(processed_image, verbose=0)
192
+ predicted_class_index = np.argmax(predictions[0])
193
+ predicted_class = class_labels[predicted_class_index]
194
+ confidence = float(predictions[0][predicted_class_index])
195
+
196
+ logger.info(f"Prediction: {predicted_class} (confidence: {confidence:.3f})")
197
+
198
+ # Return simple response for ESP32 - match your ESP32 expectation exactly
199
+ return {"label": predicted_class.capitalize()} # Capitalize to match your ESP32 labels
200
+
201
+ except HTTPException:
202
+ raise
203
+ except Exception as e:
204
+ logger.error(f"Classification error: {str(e)}")
205
+ return JSONResponse(
206
+ status_code=500,
207
+ content={"error": f"Classification failed: {str(e)}"}
208
+ )
209
+
210
+ @app.post("/classify/detailed")
211
+ async def classify_detailed(file: UploadFile = File(...)):
212
+ """
213
+ Detailed classification endpoint with confidence scores
214
+ """
215
+ try:
216
+ # Validate file type
217
+ if not file.content_type or not file.content_type.startswith('image/'):
218
+ raise HTTPException(status_code=400, detail="File must be an image")
219
+
220
+ # Read and process image
221
+ image_data = await file.read()
222
+ processed_image = preprocess_image(image_data)
223
+
224
+ # Make prediction with full details
225
+ predictions = model.predict(processed_image, verbose=0)
226
+ predicted_class_index = np.argmax(predictions[0])
227
+ predicted_class = class_labels[predicted_class_index]
228
+ confidence = float(predictions[0][predicted_class_index])
229
+
230
+ # Get all class probabilities
231
+ all_probs = {
232
+ class_labels[i].capitalize(): round(float(predictions[0][i]) * 100, 2)
233
+ for i in range(len(class_labels))
234
+ }
235
+
236
+ return {
237
+ "label": predicted_class.capitalize(),
238
+ "confidence": round(confidence * 100, 2),
239
+ "all_probabilities": all_probs,
240
+ "model_info": {
241
+ "architecture": "MobileNetV2",
242
+ "input_size": "224x224",
243
+ "classes": len(class_labels)
244
+ },
245
+ "status": "success"
246
+ }
247
+
248
+ except Exception as e:
249
+ logger.error(f"Detailed classification error: {str(e)}")
250
+ raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")
251
+
252
+ @app.get("/info")
253
+ async def get_info():
254
+ """API information endpoint"""
255
+ return {
256
+ "api_name": "Waste Classification API",
257
+ "version": "1.0.0",
258
+ "model": {
259
+ "architecture": "MobileNetV2 + Custom Head",
260
+ "framework": "TensorFlow/Keras",
261
+ "input_size": "224x224x3",
262
+ "preprocessing": "RGB, Resize, Rescaling (internal)"
263
+ },
264
+ "classes": [label.capitalize() for label in class_labels],
265
+ "endpoints": {
266
+ "/classify": "POST - Main classification endpoint (returns simple label)",
267
+ "/classify/detailed": "POST - Detailed classification with confidence",
268
+ "/health": "GET - Health check",
269
+ "/info": "GET - API information"
270
+ },
271
+ "usage": {
272
+ "esp32": "POST image to /classify endpoint",
273
+ "curl_example": "curl -X POST -F 'file=@image.jpg' https://your-space-url.hf.space/classify"
274
+ }
275
+ }
276
+
277
+ @app.post("/test")
278
+ async def test_with_dummy():
279
+ """Test endpoint with dummy data for debugging"""
280
+ try:
281
+ # Create dummy image (random noise)
282
+ dummy_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
283
+ dummy_input = np.expand_dims(dummy_image.astype(np.float32), axis=0)
284
+
285
+ # Make prediction
286
+ predictions = model.predict(dummy_input, verbose=0)
287
+ predicted_class_index = np.argmax(predictions[0])
288
+ predicted_class = class_labels[predicted_class_index]
289
+
290
+ return {
291
+ "test_status": "success",
292
+ "predicted_class": predicted_class.capitalize(),
293
+ "confidence": float(predictions[0][predicted_class_index]),
294
+ "all_predictions": [float(p) for p in predictions[0]]
295
+ }
296
+ except Exception as e:
297
+ return {"test_status": "failed", "error": str(e)}
298
+
299
+ if __name__ == "__main__":
300
+ # Run the FastAPI app
301
+ port = int(os.environ.get("PORT", 7860))
302
+ uvicorn.run(
303
+ app,
304
+ host="0.0.0.0",
305
+ port=port,
306
+ log_level="info"
307
  )