Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -24,16 +24,27 @@ def load_model():
|
|
| 24 |
"""Load the trained model"""
|
| 25 |
global model
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
if not os.path.exists(config.BEST_MODEL_PATH):
|
| 28 |
-
print(f"
|
|
|
|
| 29 |
return False
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def preprocess_image(image):
|
|
@@ -157,10 +168,28 @@ def random_sample():
|
|
| 157 |
return jsonify({'error': str(e)}), 500
|
| 158 |
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
if __name__ == '__main__':
|
| 161 |
-
#
|
| 162 |
-
if
|
| 163 |
-
print("Starting Flask
|
| 164 |
app.run(debug=True, host='0.0.0.0', port=5000)
|
| 165 |
else:
|
| 166 |
print("Failed to load model. Please train the model first using train.py")
|
|
|
|
|
|
| 24 |
"""Load the trained model"""
|
| 25 |
global model
|
| 26 |
|
| 27 |
+
print(f"Looking for model at: {config.BEST_MODEL_PATH}")
|
| 28 |
+
print(f"Current working directory: {os.getcwd()}")
|
| 29 |
+
print(f"Files in checkpoints/: {os.listdir('checkpoints') if os.path.exists('checkpoints') else 'Directory not found'}")
|
| 30 |
+
|
| 31 |
if not os.path.exists(config.BEST_MODEL_PATH):
|
| 32 |
+
print(f"ERROR: Model checkpoint not found at {config.BEST_MODEL_PATH}")
|
| 33 |
+
print(f"Please ensure the model file exists in the checkpoints directory")
|
| 34 |
return False
|
| 35 |
|
| 36 |
+
try:
|
| 37 |
+
model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
|
| 38 |
+
epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
|
| 39 |
+
model.eval()
|
| 40 |
+
|
| 41 |
+
print(f"✅ Model loaded successfully from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
|
| 42 |
+
return True
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f"ERROR loading model: {str(e)}")
|
| 45 |
+
import traceback
|
| 46 |
+
traceback.print_exc()
|
| 47 |
+
return False
|
| 48 |
|
| 49 |
|
| 50 |
def preprocess_image(image):
|
|
|
|
| 168 |
return jsonify({'error': str(e)}), 500
|
| 169 |
|
| 170 |
|
| 171 |
+
|
| 172 |
+
# Load model when module is imported (for Gunicorn)
|
| 173 |
+
print("=" * 60)
|
| 174 |
+
print("Initializing CIFAR-10 RNN Classifier")
|
| 175 |
+
print("=" * 60)
|
| 176 |
+
|
| 177 |
+
model_loaded = load_model()
|
| 178 |
+
|
| 179 |
+
if not model_loaded:
|
| 180 |
+
print("⚠️ WARNING: Model not loaded. Application will return errors.")
|
| 181 |
+
print("Please check that checkpoints/best_model.pth exists.")
|
| 182 |
+
else:
|
| 183 |
+
print("✅ Application ready to serve requests!")
|
| 184 |
+
|
| 185 |
+
print("=" * 60)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
if __name__ == '__main__':
|
| 189 |
+
# This runs only when executed directly (not with Gunicorn)
|
| 190 |
+
if model_loaded:
|
| 191 |
+
print("Starting Flask development server...")
|
| 192 |
app.run(debug=True, host='0.0.0.0', port=5000)
|
| 193 |
else:
|
| 194 |
print("Failed to load model. Please train the model first using train.py")
|
| 195 |
+
|