Spaces:
Sleeping
Sleeping
Update python_utils/get_model.py
Browse files
python_utils/get_model.py
CHANGED
|
@@ -24,11 +24,20 @@ def load_model():
|
|
| 24 |
|
| 25 |
## define relevant parameters
|
| 26 |
cfg = get_cfg()
|
|
|
|
| 27 |
cfg.merge_from_file("./configs/test_model_config.yaml")
|
| 28 |
if not torch.cuda.is_available():
|
| 29 |
cfg.MODEL.DEVICE = "cpu"
|
| 30 |
else:
|
| 31 |
cfg.MODEL.DEVICE = 'cuda'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
predictor = DefaultPredictor(cfg)
|
| 33 |
|
| 34 |
return predictor
|
|
|
|
| 24 |
|
| 25 |
## define relevant parameters
|
| 26 |
cfg = get_cfg()
|
| 27 |
+
|
| 28 |
cfg.merge_from_file("./configs/test_model_config.yaml")
|
| 29 |
if not torch.cuda.is_available():
|
| 30 |
cfg.MODEL.DEVICE = "cpu"
|
| 31 |
else:
|
| 32 |
cfg.MODEL.DEVICE = 'cuda'
|
| 33 |
+
|
| 34 |
+
## when rerouting to use the final model (final_tz_segmentor) USE_FED_LOSS has to be set to false
|
| 35 |
+
## this setting requires the training data to calculate class imbalance that the app will not have access to
|
| 36 |
+
## some messages will appear when using the model that certain weights are not being used
|
| 37 |
+
## but these are used during training and not inference and shouldn't affect the model performance
|
| 38 |
+
## code below
|
| 39 |
+
## cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = false
|
| 40 |
+
|
| 41 |
predictor = DefaultPredictor(cfg)
|
| 42 |
|
| 43 |
return predictor
|