anas-gouda commited on
Commit
5797fff
·
1 Parent(s): ae90ebf

remove unused params, tune SAM2 params

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. utils/models.py +3 -7
app.py CHANGED
@@ -9,8 +9,7 @@ from PIL import Image
9
  import dounseen.utils as dounseen_utils
10
  import cv2
11
 
12
- from utils.models import load_sam2_models, CHECKPOINT_NAMES, MODE_NAMES, \
13
- MASK_GENERATION_MODE, BOX_PROMPT_MODE, load_dounseen_model
14
 
15
  # TODO add presentation on YouTube and add link here
16
  MARKDOWN = """
@@ -89,9 +88,10 @@ with gr.Blocks() as demo:
89
  with gr.Row():
90
  object_image_1 = gr.Image(type='pil', label=f'Object Image 1')
91
  object_image_2 = gr.Image(type='pil', label=f'Object Image 2')
92
- object_image_3 = gr.Image(type='pil', label=f'Object Image 3')
93
  with gr.Row():
 
94
  object_image_4 = gr.Image(type='pil', label=f'Object Image 4')
 
95
  object_image_5 = gr.Image(type='pil', label=f'Object Image 5')
96
  object_image_6 = gr.Image(type='pil', label=f'Object Image 6')
97
  object_images = [object_image_1, object_image_2, object_image_3, object_image_4, object_image_5, object_image_6]
 
9
  import dounseen.utils as dounseen_utils
10
  import cv2
11
 
12
+ from utils.models import load_sam2_models, CHECKPOINT_NAMES, load_dounseen_model
 
13
 
14
  # TODO add presentation on YouTube and add link here
15
  MARKDOWN = """
 
88
  with gr.Row():
89
  object_image_1 = gr.Image(type='pil', label=f'Object Image 1')
90
  object_image_2 = gr.Image(type='pil', label=f'Object Image 2')
 
91
  with gr.Row():
92
+ object_image_3 = gr.Image(type='pil', label=f'Object Image 3')
93
  object_image_4 = gr.Image(type='pil', label=f'Object Image 4')
94
+ with gr.Row():
95
  object_image_5 = gr.Image(type='pil', label=f'Object Image 5')
96
  object_image_6 = gr.Image(type='pil', label=f'Object Image 6')
97
  object_images = [object_image_1, object_image_2, object_image_3, object_image_4, object_image_5, object_image_6]
utils/models.py CHANGED
@@ -6,10 +6,6 @@ from sam2.build_sam import build_sam2
6
  from sam2.sam2_image_predictor import SAM2ImagePredictor
7
  from dounseen.core import UnseenClassifier
8
 
9
- BOX_PROMPT_MODE = "box prompt"
10
- MASK_GENERATION_MODE = "mask generation"
11
- MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]
12
-
13
  CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
14
  CHECKPOINTS = {
15
  "tiny": ["sam2_hiera_t.yaml", "models/sam2/sam2_hiera_tiny.pt"],
@@ -27,12 +23,12 @@ def load_sam2_models(
27
  model = build_sam2(config, checkpoint, device=device)
28
  mask_generators[key] = SAM2AutomaticMaskGenerator(
29
  model=model,
30
- points_per_side=16,
31
- points_per_batch=64,
32
  pred_iou_thresh=0.7,
33
  stability_score_thresh=0.92,
34
  stability_score_offset=0.7,
35
- crop_n_layers=0,
36
  box_nms_thresh=0.7,
37
  )
38
  return mask_generators
 
6
  from sam2.sam2_image_predictor import SAM2ImagePredictor
7
  from dounseen.core import UnseenClassifier
8
 
 
 
 
 
9
  CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
10
  CHECKPOINTS = {
11
  "tiny": ["sam2_hiera_t.yaml", "models/sam2/sam2_hiera_tiny.pt"],
 
23
  model = build_sam2(config, checkpoint, device=device)
24
  mask_generators[key] = SAM2AutomaticMaskGenerator(
25
  model=model,
26
+ points_per_side=64,
27
+ points_per_batch=128,
28
  pred_iou_thresh=0.7,
29
  stability_score_thresh=0.92,
30
  stability_score_offset=0.7,
31
+ crop_n_layers=1,
32
  box_nms_thresh=0.7,
33
  )
34
  return mask_generators