Penthes commited on
Commit
50e9542
·
verified ·
1 Parent(s): bde32a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -8
app.py CHANGED
@@ -35,21 +35,27 @@ model = None
35
  # IMPORTANT: Your class order from training (alphabetical from image_dataset_from_directory)
36
  class_labels = ["glass", "metal", "organic", "paper", "plastic"]
37
 
 
38
  def load_model():
39
- """Load the trained TensorFlow/Keras model"""
40
  try:
41
- # Try loading different formats in order of preference
42
  model_files = [
43
  'model/waste_model.keras', # Keras format (recommended)
44
  'model/waste_model.h5', # H5 format
45
- 'model/best_model.keras' # Checkpoint from training
 
 
46
  ]
47
 
48
  model = None
 
 
49
  for model_file in model_files:
50
  if os.path.exists(model_file):
51
  try:
52
  model = tf.keras.models.load_model(model_file)
 
53
  logger.info(f"Model loaded successfully from {model_file}")
54
  break
55
  except Exception as e:
@@ -57,7 +63,7 @@ def load_model():
57
  continue
58
 
59
  if model is None:
60
- logger.error("No model file found. Creating dummy model for testing.")
61
  # Create dummy model with same architecture for testing
62
  model = tf.keras.Sequential([
63
  tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
@@ -68,6 +74,8 @@ def load_model():
68
  tf.keras.layers.Dense(5, activation='softmax')
69
  ])
70
  logger.warning("Using dummy model - predictions will be random!")
 
 
71
 
72
  return model
73
 
@@ -142,15 +150,19 @@ async def health_check():
142
  # Quick model test
143
  dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
144
  prediction = model.predict(dummy_input, verbose=0)
145
- model_working = prediction is not None
 
 
 
146
 
147
  return {
148
  "status": "healthy",
 
149
  "model_loaded": model is not None,
150
- "model_working": model_working,
151
  "classes": class_labels,
152
  "input_shape": "(224, 224, 3)",
153
- "model_type": "TensorFlow/Keras MobileNetV2"
 
154
  }
155
  except Exception as e:
156
  return {
@@ -159,7 +171,33 @@ async def health_check():
159
  "model_loaded": model is not None,
160
  "classes": class_labels
161
  }
162
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  @app.post("/classify")
164
  async def classify_image(file: UploadFile = File(...)):
165
  """
 
35
  # IMPORTANT: Your class order from training (alphabetical from image_dataset_from_directory)
36
  class_labels = ["glass", "metal", "organic", "paper", "plastic"]
37
 
38
+
39
  def load_model():
40
+ """Load the trained TensorFlow/Keras model from model/ directory"""
41
  try:
42
+ # Try loading different formats from model/ directory
43
  model_files = [
44
  'model/waste_model.keras', # Keras format (recommended)
45
  'model/waste_model.h5', # H5 format
46
+ 'model/best_model.keras', # Checkpoint from training
47
+ 'waste_model.keras', # Fallback to root (original)
48
+ 'waste_model.h5' # Fallback to root
49
  ]
50
 
51
  model = None
52
+ loaded_from = None
53
+
54
  for model_file in model_files:
55
  if os.path.exists(model_file):
56
  try:
57
  model = tf.keras.models.load_model(model_file)
58
+ loaded_from = model_file
59
  logger.info(f"Model loaded successfully from {model_file}")
60
  break
61
  except Exception as e:
 
63
  continue
64
 
65
  if model is None:
66
+ logger.error("No model file found in any location. Creating dummy model for testing.")
67
  # Create dummy model with same architecture for testing
68
  model = tf.keras.Sequential([
69
  tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
 
74
  tf.keras.layers.Dense(5, activation='softmax')
75
  ])
76
  logger.warning("Using dummy model - predictions will be random!")
77
+ else:
78
+ logger.info(f"Successfully loaded model from: {loaded_from}")
79
 
80
  return model
81
 
 
150
  # Quick model test
151
  dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
152
  prediction = model.predict(dummy_input, verbose=0)
153
+
154
+ # Check if we're using the dummy model (random predictions)
155
+ is_dummy_model = np.allclose(prediction.sum(), 1.0) # Should sum to ~1
156
+ model_status = "real_model" if not is_dummy_model else "dummy_model"
157
 
158
  return {
159
  "status": "healthy",
160
+ "model_status": model_status,
161
  "model_loaded": model is not None,
 
162
  "classes": class_labels,
163
  "input_shape": "(224, 224, 3)",
164
+ "model_type": "TensorFlow/Keras MobileNetV2",
165
+ "prediction_sample": prediction[0].tolist() # Show first prediction
166
  }
167
  except Exception as e:
168
  return {
 
171
  "model_loaded": model is not None,
172
  "classes": class_labels
173
  }
174
+ @app.on_event("startup")
175
+ async def startup_event():
176
+ """Load model on startup"""
177
+ global model
178
+ try:
179
+ # Log available files for debugging
180
+ logger.info("Available files in root:")
181
+ for file in os.listdir('.'):
182
+ logger.info(f" {file}")
183
+
184
+ if os.path.exists('model'):
185
+ logger.info("Available files in model/ directory:")
186
+ for file in os.listdir('model'):
187
+ logger.info(f" model/{file}")
188
+
189
+ model = load_model()
190
+ logger.info("API startup complete")
191
+
192
+ # Test model with dummy input
193
+ dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
194
+ _ = model.predict(dummy_input, verbose=0)
195
+ logger.info("Model test prediction successful")
196
+
197
+ except Exception as e:
198
+ logger.error(f"Startup failed: {e}")
199
+ raise
200
+
201
  @app.post("/classify")
202
  async def classify_image(file: UploadFile = File(...)):
203
  """