Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,7 +16,7 @@ app = Flask(__name__)
|
|
| 16 |
CORS(app)
|
| 17 |
|
| 18 |
# CONFIG
|
| 19 |
-
DEVICE = torch.device('cpu')
|
| 20 |
MODEL_PATH = "best_model.pth"
|
| 21 |
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
|
| 22 |
NUM_CLASSES = 6
|
|
@@ -26,8 +26,12 @@ print("Loading model...")
|
|
| 26 |
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 27 |
MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
|
| 28 |
)
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
state_dict = checkpoint['model_state_dict']
|
|
|
|
|
|
|
| 31 |
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
| 32 |
model.load_state_dict(new_state_dict)
|
| 33 |
model.to(DEVICE)
|
|
@@ -83,7 +87,5 @@ def predict():
|
|
| 83 |
except Exception as e:
|
| 84 |
return jsonify({'error': str(e)}), 500
|
| 85 |
|
| 86 |
-
# --- CRITICAL CHANGE FOR HUGGING FACE ---
|
| 87 |
if __name__ == '__main__':
|
| 88 |
-
# Hugging Face runs on port 7860
|
| 89 |
app.run(host='0.0.0.0', port=7860)
|
|
|
|
| 16 |
CORS(app)
|
| 17 |
|
| 18 |
# CONFIG
|
| 19 |
+
DEVICE = torch.device('cpu') # Hugging Face Free Tier is CPU
|
| 20 |
MODEL_PATH = "best_model.pth"
|
| 21 |
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
|
| 22 |
NUM_CLASSES = 6
|
|
|
|
| 26 |
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 27 |
MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
|
| 28 |
)
|
| 29 |
+
|
| 30 |
+
# --- FIX IS HERE ---
|
| 31 |
+
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
|
| 32 |
state_dict = checkpoint['model_state_dict']
|
| 33 |
+
|
| 34 |
+
# Fix key names (remove 'module.' if trained on multi-GPU)
|
| 35 |
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
| 36 |
model.load_state_dict(new_state_dict)
|
| 37 |
model.to(DEVICE)
|
|
|
|
| 87 |
except Exception as e:
|
| 88 |
return jsonify({'error': str(e)}), 500
|
| 89 |
|
|
|
|
| 90 |
if __name__ == '__main__':
|
|
|
|
| 91 |
app.run(host='0.0.0.0', port=7860)
|