alrichardbollans commited on
Commit
4b06494
·
1 Parent(s): f1640c3

Move config file to model repository

Browse files
Files changed (1) hide show
  1. python_utils/get_model.py +7 -5
python_utils/get_model.py CHANGED
@@ -1,5 +1,7 @@
1
  import urllib.request
2
 
 
 
3
  def get_set_up():
4
  import torch
5
  TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
@@ -18,9 +20,9 @@ def load_model():
18
 
19
  ## define relevant parameters
20
  cfg = get_cfg()
21
- config_url = "https://huggingface.co/TZProject/final_tz_segmentor/resolve/main/final_model_config.yaml"
22
- config_filename = "final_model_config.yaml"
23
- urllib.request.urlretrieve(config_url, filename=config_filename)
24
  cfg.merge_from_file(config_filename)
25
  if not torch.cuda.is_available():
26
  cfg.MODEL.DEVICE = "cpu"
@@ -38,7 +40,7 @@ def load_model():
38
 
39
  return predictor
40
 
41
- def mask_nms(masks, scores, nms_threshold=0.5):
42
  """
43
  Runs class agnostic NMS on masks/segmentations instead of the bounding boxes.
44
  :param masks: (list float) List of coordinates that make up the mask output from the model.
@@ -72,7 +74,7 @@ def mask_nms(masks, scores, nms_threshold=0.5):
72
  order.remove(j)
73
  return masks_kept
74
 
75
- def apply_nms(prediction, mask=False, cls_agnostic_nms=0.7):
76
  from torchvision.ops import nms
77
  from detectron2.structures import Instances
78
 
 
1
  import urllib.request
2
 
3
+ OPTIMAL_NMS_THRESHOLD = 0.7
4
+ _model_config_url = "https://huggingface.co/TZProject/final_tz_segmentor/resolve/main/final_model_config.yaml"
5
  def get_set_up():
6
  import torch
7
  TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
 
20
 
21
  ## define relevant parameters
22
  cfg = get_cfg()
23
+
24
+ config_filename = "./final_model_config.yaml"
25
+ urllib.request.urlretrieve(_model_config_url, filename=config_filename)
26
  cfg.merge_from_file(config_filename)
27
  if not torch.cuda.is_available():
28
  cfg.MODEL.DEVICE = "cpu"
 
40
 
41
  return predictor
42
 
43
+ def mask_nms(masks, scores, nms_threshold=OPTIMAL_NMS_THRESHOLD):
44
  """
45
  Runs class agnostic NMS on masks/segmentations instead of the bounding boxes.
46
  :param masks: (list float) List of coordinates that make up the mask output from the model.
 
74
  order.remove(j)
75
  return masks_kept
76
 
77
+ def apply_nms(prediction, mask=False, cls_agnostic_nms=OPTIMAL_NMS_THRESHOLD):
78
  from torchvision.ops import nms
79
  from detectron2.structures import Instances
80