N-I-M-I commited on
Commit
20618e7
·
verified ·
1 Parent(s): 1e88f11

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +39 -10
app.py CHANGED
@@ -24,16 +24,27 @@ def load_model():
24
  """Load the trained model"""
25
  global model
26
 
 
 
 
 
27
  if not os.path.exists(config.BEST_MODEL_PATH):
28
- print(f"Warning: Model checkpoint not found at {config.BEST_MODEL_PATH}")
 
29
  return False
30
 
31
- model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
32
- epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
33
- model.eval()
34
-
35
- print(f"Model loaded from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
36
- return True
 
 
 
 
 
 
37
 
38
 
39
  def preprocess_image(image):
@@ -157,10 +168,28 @@ def random_sample():
157
  return jsonify({'error': str(e)}), 500
158
 
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  if __name__ == '__main__':
161
- # Load model
162
- if load_model():
163
- print("Starting Flask application...")
164
  app.run(debug=True, host='0.0.0.0', port=5000)
165
  else:
166
  print("Failed to load model. Please train the model first using train.py")
 
 
24
  """Load the trained model"""
25
  global model
26
 
27
+ print(f"Looking for model at: {config.BEST_MODEL_PATH}")
28
+ print(f"Current working directory: {os.getcwd()}")
29
+ print(f"Files in checkpoints/: {os.listdir('checkpoints') if os.path.exists('checkpoints') else 'Directory not found'}")
30
+
31
  if not os.path.exists(config.BEST_MODEL_PATH):
32
+ print(f"ERROR: Model checkpoint not found at {config.BEST_MODEL_PATH}")
33
+ print(f"Please ensure the model file exists in the checkpoints directory")
34
  return False
35
 
36
+ try:
37
+ model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
38
+ epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
39
+ model.eval()
40
+
41
+ print(f"✅ Model loaded successfully from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
42
+ return True
43
+ except Exception as e:
44
+ print(f"ERROR loading model: {str(e)}")
45
+ import traceback
46
+ traceback.print_exc()
47
+ return False
48
 
49
 
50
  def preprocess_image(image):
 
168
  return jsonify({'error': str(e)}), 500
169
 
170
 
171
+
172
+ # Load model when module is imported (for Gunicorn)
173
+ print("=" * 60)
174
+ print("Initializing CIFAR-10 RNN Classifier")
175
+ print("=" * 60)
176
+
177
+ model_loaded = load_model()
178
+
179
+ if not model_loaded:
180
+ print("⚠️ WARNING: Model not loaded. Application will return errors.")
181
+ print("Please check that checkpoints/best_model.pth exists.")
182
+ else:
183
+ print("✅ Application ready to serve requests!")
184
+
185
+ print("=" * 60)
186
+
187
+
188
  if __name__ == '__main__':
189
+ # This runs only when executed directly (not with Gunicorn)
190
+ if model_loaded:
191
+ print("Starting Flask development server...")
192
  app.run(debug=True, host='0.0.0.0', port=5000)
193
  else:
194
  print("Failed to load model. Please train the model first using train.py")
195
+