Penthes committed on
Commit
5d01458
·
verified ·
1 Parent(s): 62cba99

upload docker files

Browse files
Files changed (3) hide show
  1. Dockerfile +25 -0
  2. app.py +290 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.9-slim

# Unbuffered stdout/stderr so container logs appear immediately,
# and no .pyc files cluttering the image.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# Set working directory
WORKDIR /app

# Install compilers needed to build some Python wheels.
# --no-install-recommends and the apt-cache cleanup keep the layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better layer caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose the port that the app runs on
EXPOSE 7860

# Command to run the application
CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import logging
import uvicorn
import os

# Set up logging; all endpoints log through this module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Waste Classification API",
    description="API for classifying waste into categories: Glass, Metal, Organic, Paper, Plastic",
    version="1.0.0",
    docs_url="/",  # Swagger UI at root for easy access
)

# Add CORS middleware for web access.
# NOTE(review): allow_origins=["*"] leaves the API open to any origin —
# acceptable for a public demo Space, tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model handle; populated once by the startup event handler.
model = None
# IMPORTANT: Your class order from training (alphabetical from image_dataset_from_directory).
# Index into this list with argmax over the model's softmax output.
class_labels = ["glass", "metal", "organic", "paper", "plastic"]
37
+
38
def load_model():
    """Load the trained TensorFlow/Keras waste-classification model.

    Tries the known artifact filenames in order of preference and returns
    the first one that loads. If no artifact is present, falls back to an
    untrained model with the same architecture so the API can still boot
    for smoke testing.

    Returns:
        A ``tf.keras.Model`` ready for ``predict``.

    Raises:
        RuntimeError: If model construction fails entirely.
    """
    try:
        # Candidate artifacts, most preferred first.
        model_files = [
            'waste_model.keras',  # Keras format (recommended)
            'waste_model.h5',     # H5 format
            'best_model.keras'    # Checkpoint from training
        ]

        model = None
        for model_file in model_files:
            if os.path.exists(model_file):
                try:
                    model = tf.keras.models.load_model(model_file)
                    logger.info(f"Model loaded successfully from {model_file}")
                    break
                except Exception as e:
                    # A corrupt/incompatible file should not prevent trying
                    # the next candidate.
                    logger.warning(f"Failed to load {model_file}: {e}")
                    continue

        if model is None:
            logger.error("No model file found. Creating dummy model for testing.")
            # Same topology as training. weights=None (instead of
            # 'imagenet') avoids a network download at startup — the head
            # is untrained either way, so predictions are random regardless.
            model = tf.keras.Sequential([
                tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
                tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights=None),
                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Dense(128, activation='relu'),
                tf.keras.layers.Dropout(0.2),
                tf.keras.layers.Dense(5, activation='softmax')
            ])
            logger.warning("Using dummy model - predictions will be random!")

        return model

    except Exception as e:
        logger.error(f"Critical error loading model: {e}")
        # RuntimeError (a subclass of Exception) keeps existing except
        # clauses working while being more specific than bare Exception.
        raise RuntimeError(f"Model loading failed: {e}") from e
77
+
78
def preprocess_image(image_data):
    """Decode raw upload bytes into a (1, 224, 224, 3) float32 batch.

    Mirrors the training pipeline: RGB conversion plus a bicubic resize
    to 224x224. Pixel values stay in [0, 255] — the model's own
    Rescaling layer handles the 1/255 normalization internally.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded or processed.
    """
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            # Match the training interpolation (bicubic).
            resized = img.convert('RGB').resize((224, 224), Image.BICUBIC)
            # float32 pixels plus a leading batch dimension.
            batch = np.expand_dims(np.asarray(resized, dtype=np.float32), axis=0)
        return batch

    except Exception as e:
        logger.error(f"Image preprocessing error: {e}")
        raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}")
103
+
104
@app.on_event("startup")
async def startup_event():
    """Load the classifier once at process start and smoke-test it."""
    global model
    try:
        model = load_model()
        logger.info("API startup complete")

        # Warm-up: a single throwaway prediction proves the graph runs.
        warmup = np.random.random((1, 224, 224, 3)).astype(np.float32)
        _ = model.predict(warmup, verbose=0)
        logger.info("Model test prediction successful")

    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise
120
+
121
+ @app.get("/health")
122
+ async def health_check():
123
+ """Health check endpoint"""
124
+ try:
125
+ # Quick model test
126
+ dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
127
+ prediction = model.predict(dummy_input, verbose=0)
128
+ model_working = prediction is not None
129
+
130
+ return {
131
+ "status": "healthy",
132
+ "model_loaded": model is not None,
133
+ "model_working": model_working,
134
+ "classes": class_labels,
135
+ "input_shape": "(224, 224, 3)",
136
+ "model_type": "TensorFlow/Keras MobileNetV2"
137
+ }
138
+ except Exception as e:
139
+ return {
140
+ "status": "unhealthy",
141
+ "error": str(e),
142
+ "model_loaded": model is not None,
143
+ "classes": class_labels
144
+ }
145
+
146
+ @app.post("/classify")
147
+ async def classify_image(file: UploadFile = File(...)):
148
+ """
149
+ Main classification endpoint for ESP32
150
+
151
+ Expected usage:
152
+ curl -X POST -F "file=@image.jpg" https://your-space-url.hf.space/classify
153
+
154
+ Returns:
155
+ JSON: {"label": "plastic"} or {"error": "message"}
156
+ """
157
+ try:
158
+ # Validate file type
159
+ if not file.content_type or not file.content_type.startswith('image/'):
160
+ logger.warning(f"Invalid file type: {file.content_type}")
161
+ raise HTTPException(status_code=400, detail="File must be an image")
162
+
163
+ # Read image data
164
+ image_data = await file.read()
165
+ if len(image_data) == 0:
166
+ raise HTTPException(status_code=400, detail="Empty image file")
167
+
168
+ logger.info(f"Processing image: {file.filename}, size: {len(image_data)} bytes")
169
+
170
+ # Preprocess image
171
+ processed_image = preprocess_image(image_data)
172
+
173
+ # Make prediction
174
+ predictions = model.predict(processed_image, verbose=0)
175
+ predicted_class_index = np.argmax(predictions[0])
176
+ predicted_class = class_labels[predicted_class_index]
177
+ confidence = float(predictions[0][predicted_class_index])
178
+
179
+ logger.info(f"Prediction: {predicted_class} (confidence: {confidence:.3f})")
180
+
181
+ # Return simple response for ESP32 - match your ESP32 expectation exactly
182
+ return {"label": predicted_class.capitalize()} # Capitalize to match your ESP32 labels
183
+
184
+ except HTTPException:
185
+ raise
186
+ except Exception as e:
187
+ logger.error(f"Classification error: {str(e)}")
188
+ return JSONResponse(
189
+ status_code=500,
190
+ content={"error": f"Classification failed: {str(e)}"}
191
+ )
192
+
193
+ @app.post("/classify/detailed")
194
+ async def classify_detailed(file: UploadFile = File(...)):
195
+ """
196
+ Detailed classification endpoint with confidence scores
197
+ """
198
+ try:
199
+ # Validate file type
200
+ if not file.content_type or not file.content_type.startswith('image/'):
201
+ raise HTTPException(status_code=400, detail="File must be an image")
202
+
203
+ # Read and process image
204
+ image_data = await file.read()
205
+ processed_image = preprocess_image(image_data)
206
+
207
+ # Make prediction with full details
208
+ predictions = model.predict(processed_image, verbose=0)
209
+ predicted_class_index = np.argmax(predictions[0])
210
+ predicted_class = class_labels[predicted_class_index]
211
+ confidence = float(predictions[0][predicted_class_index])
212
+
213
+ # Get all class probabilities
214
+ all_probs = {
215
+ class_labels[i].capitalize(): round(float(predictions[0][i]) * 100, 2)
216
+ for i in range(len(class_labels))
217
+ }
218
+
219
+ return {
220
+ "label": predicted_class.capitalize(),
221
+ "confidence": round(confidence * 100, 2),
222
+ "all_probabilities": all_probs,
223
+ "model_info": {
224
+ "architecture": "MobileNetV2",
225
+ "input_size": "224x224",
226
+ "classes": len(class_labels)
227
+ },
228
+ "status": "success"
229
+ }
230
+
231
+ except Exception as e:
232
+ logger.error(f"Detailed classification error: {str(e)}")
233
+ raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")
234
+
235
+ @app.get("/info")
236
+ async def get_info():
237
+ """API information endpoint"""
238
+ return {
239
+ "api_name": "Waste Classification API",
240
+ "version": "1.0.0",
241
+ "model": {
242
+ "architecture": "MobileNetV2 + Custom Head",
243
+ "framework": "TensorFlow/Keras",
244
+ "input_size": "224x224x3",
245
+ "preprocessing": "RGB, Resize, Rescaling (internal)"
246
+ },
247
+ "classes": [label.capitalize() for label in class_labels],
248
+ "endpoints": {
249
+ "/classify": "POST - Main classification endpoint (returns simple label)",
250
+ "/classify/detailed": "POST - Detailed classification with confidence",
251
+ "/health": "GET - Health check",
252
+ "/info": "GET - API information"
253
+ },
254
+ "usage": {
255
+ "esp32": "POST image to /classify endpoint",
256
+ "curl_example": "curl -X POST -F 'file=@image.jpg' https://your-space-url.hf.space/classify"
257
+ }
258
+ }
259
+
260
+ @app.post("/test")
261
+ async def test_with_dummy():
262
+ """Test endpoint with dummy data for debugging"""
263
+ try:
264
+ # Create dummy image (random noise)
265
+ dummy_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
266
+ dummy_input = np.expand_dims(dummy_image.astype(np.float32), axis=0)
267
+
268
+ # Make prediction
269
+ predictions = model.predict(dummy_input, verbose=0)
270
+ predicted_class_index = np.argmax(predictions[0])
271
+ predicted_class = class_labels[predicted_class_index]
272
+
273
+ return {
274
+ "test_status": "success",
275
+ "predicted_class": predicted_class.capitalize(),
276
+ "confidence": float(predictions[0][predicted_class_index]),
277
+ "all_predictions": [float(p) for p in predictions[0]]
278
+ }
279
+ except Exception as e:
280
+ return {"test_status": "failed", "error": str(e)}
281
+
282
+ if __name__ == "__main__":
283
+ # Run the FastAPI app
284
+ port = int(os.environ.get("PORT", 7860))
285
+ uvicorn.run(
286
+ app,
287
+ host="0.0.0.0",
288
+ port=port,
289
+ log_level="info"
290
+ )
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# For Docker FastAPI version:
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
Pillow==10.0.1
numpy==1.24.3
requests==2.31.0
tensorflow==2.15.0