jatamura commited on
Commit
9858610
·
verified ·
1 Parent(s): db32c07

Update python_utils/get_model.py

Browse files
Files changed (1) hide show
  1. python_utils/get_model.py +9 -0
python_utils/get_model.py CHANGED
@@ -24,11 +24,20 @@ def load_model():
24
 
25
  ## define relevant parameters
26
  cfg = get_cfg()
 
27
  cfg.merge_from_file("./configs/test_model_config.yaml")
28
  if not torch.cuda.is_available():
29
  cfg.MODEL.DEVICE = "cpu"
30
  else:
31
  cfg.MODEL.DEVICE = 'cuda'
 
 
 
 
 
 
 
 
32
  predictor = DefaultPredictor(cfg)
33
 
34
  return predictor
 
24
 
25
  ## define relevant parameters
26
  cfg = get_cfg()
27
+
28
  cfg.merge_from_file("./configs/test_model_config.yaml")
29
  if not torch.cuda.is_available():
30
  cfg.MODEL.DEVICE = "cpu"
31
  else:
32
  cfg.MODEL.DEVICE = 'cuda'
33
+
34
+ ## when rerouting to use the final model (final_tz_segmentor) USE_FED_LOSS has to be set to false
35
+ ## this setting requires the training data to calculate class imbalance that the app will not have access to
36
+ ## some messages will appear when using the model that certain weights are not being used
37
+ ## but these are used during training and not inference and shouldn't affect the model performance
38
+ ## code below
39
+ ## cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = false
40
+
41
  predictor = DefaultPredictor(cfg)
42
 
43
  return predictor