anas-gouda commited on
Commit
f775271
·
1 Parent(s): c55b78a

delete image_predictor, only use automatic mode

Browse files
Files changed (2) hide show
  1. app.py +6 -63
  2. utils/models.py +1 -3
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
6
  import supervision as sv
7
  import torch
8
  from PIL import Image
9
- from gradio_image_prompter import ImagePrompter
10
 
11
  from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
12
  MASK_GENERATION_MODE, BOX_PROMPT_MODE
@@ -46,7 +45,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
46
  torch.backends.cudnn.allow_tf32 = True
47
 
48
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
49
- IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
50
 
51
 
52
  @spaces.GPU
@@ -54,39 +53,13 @@ IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
54
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
55
  def process(
56
  checkpoint_dropdown,
57
- mode_dropdown,
58
  image_input,
59
- image_prompter_input
60
  ) -> Optional[Image.Image]:
61
- if mode_dropdown == BOX_PROMPT_MODE:
62
- image_input = image_prompter_input["image"]
63
- prompt = image_prompter_input["points"]
64
- if len(prompt) == 0:
65
- return image_input
66
-
67
- model = IMAGE_PREDICTORS[checkpoint_dropdown]
68
- image = np.array(image_input.convert("RGB"))
69
- box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])
70
-
71
- model.set_image(image)
72
- masks, _, _ = model.predict(box=box, multimask_output=False)
73
-
74
- # dirty fix; remove this later
75
- if len(masks.shape) == 4:
76
- masks = np.squeeze(masks)
77
-
78
- detections = sv.Detections(
79
- xyxy=sv.mask_to_xyxy(masks=masks),
80
- mask=masks.astype(bool)
81
- )
82
- return MASK_ANNOTATOR.annotate(image_input, detections)
83
-
84
- if mode_dropdown == MASK_GENERATION_MODE:
85
- model = MASK_GENERATORS[checkpoint_dropdown]
86
- image = np.array(image_input.convert("RGB"))
87
- result = model.generate(image)
88
- detections = sv.Detections.from_sam(result)
89
- return MASK_ANNOTATOR.annotate(image_input, detections)
90
 
91
 
92
  with gr.Blocks() as demo:
@@ -98,21 +71,10 @@ with gr.Blocks() as demo:
98
  label="Checkpoint", info="Select a SAM2 checkpoint to use.",
99
  interactive=True
100
  )
101
- mode_dropdown_component = gr.Dropdown(
102
- choices=MODE_NAMES,
103
- value=MODE_NAMES[0],
104
- label="Mode",
105
- info="Select a mode to use. `box prompt` if you want to generate masks for "
106
- "selected objects, `mask generation` if you want to generate masks "
107
- "for the whole image.",
108
- interactive=True
109
- )
110
  with gr.Row():
111
  with gr.Column():
112
  image_input_component = gr.Image(
113
  type='pil', label='Upload image')
114
- image_prompter_input_component = ImagePrompter(
115
- type='pil', label='Image prompt', visible=False)
116
  submit_button_component = gr.Button(
117
  value='Submit', variant='primary')
118
  with gr.Column():
@@ -123,37 +85,18 @@ with gr.Blocks() as demo:
123
  examples=EXAMPLES,
124
  inputs=[
125
  checkpoint_dropdown_component,
126
- mode_dropdown_component,
127
  image_input_component,
128
- image_prompter_input_component,
129
  ],
130
  outputs=[image_output_component],
131
  cache_examples=False,
132
  run_on_click=True
133
  )
134
 
135
-
136
- def on_mode_dropdown_change(text):
137
- return [
138
- gr.Image(visible=text == MASK_GENERATION_MODE),
139
- ImagePrompter(visible=text == BOX_PROMPT_MODE)
140
- ]
141
-
142
- mode_dropdown_component.change(
143
- on_mode_dropdown_change,
144
- inputs=[mode_dropdown_component],
145
- outputs=[
146
- image_input_component,
147
- image_prompter_input_component
148
- ]
149
- )
150
  submit_button_component.click(
151
  fn=process,
152
  inputs=[
153
  checkpoint_dropdown_component,
154
- mode_dropdown_component,
155
  image_input_component,
156
- image_prompter_input_component,
157
  ],
158
  outputs=[image_output_component]
159
  )
 
6
  import supervision as sv
7
  import torch
8
  from PIL import Image
 
9
 
10
  from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
11
  MASK_GENERATION_MODE, BOX_PROMPT_MODE
 
45
  torch.backends.cudnn.allow_tf32 = True
46
 
47
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
48
+ MASK_GENERATORS = load_models(device=DEVICE)
49
 
50
 
51
  @spaces.GPU
 
53
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
54
  def process(
55
  checkpoint_dropdown,
 
56
  image_input,
 
57
  ) -> Optional[Image.Image]:
58
+ model = MASK_GENERATORS[checkpoint_dropdown]
59
+ image = np.array(image_input.convert("RGB"))
60
+ result = model.generate(image)
61
+ detections = sv.Detections.from_sam(result)
62
+ return MASK_ANNOTATOR.annotate(image_input, detections)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  with gr.Blocks() as demo:
 
71
  label="Checkpoint", info="Select a SAM2 checkpoint to use.",
72
  interactive=True
73
  )
 
 
 
 
 
 
 
 
 
74
  with gr.Row():
75
  with gr.Column():
76
  image_input_component = gr.Image(
77
  type='pil', label='Upload image')
 
 
78
  submit_button_component = gr.Button(
79
  value='Submit', variant='primary')
80
  with gr.Column():
 
85
  examples=EXAMPLES,
86
  inputs=[
87
  checkpoint_dropdown_component,
 
88
  image_input_component,
 
89
  ],
90
  outputs=[image_output_component],
91
  cache_examples=False,
92
  run_on_click=True
93
  )
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  submit_button_component.click(
96
  fn=process,
97
  inputs=[
98
  checkpoint_dropdown_component,
 
99
  image_input_component,
 
100
  ],
101
  outputs=[image_output_component]
102
  )
utils/models.py CHANGED
@@ -21,11 +21,9 @@ CHECKPOINTS = {
21
  def load_models(
22
  device: torch.device
23
  ) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
24
- image_predictors = {}
25
  mask_generators = {}
26
  for key, (config, checkpoint) in CHECKPOINTS.items():
27
  model = build_sam2(config, checkpoint, device=device)
28
- image_predictors[key] = SAM2ImagePredictor(sam_model=model)
29
  mask_generators[key] = SAM2AutomaticMaskGenerator(
30
  model=model,
31
  points_per_side=32,
@@ -36,4 +34,4 @@ def load_models(
36
  crop_n_layers=1,
37
  box_nms_thresh=0.7,
38
  )
39
- return image_predictors, mask_generators
 
21
  def load_models(
22
  device: torch.device
23
  ) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
 
24
  mask_generators = {}
25
  for key, (config, checkpoint) in CHECKPOINTS.items():
26
  model = build_sam2(config, checkpoint, device=device)
 
27
  mask_generators[key] = SAM2AutomaticMaskGenerator(
28
  model=model,
29
  points_per_side=32,
 
34
  crop_n_layers=1,
35
  box_nms_thresh=0.7,
36
  )
37
+ return mask_generators