Spaces:
Running
on
Zero
Running
on
Zero
Update sam2_mask.py
Browse files- sam2_mask.py +3 -3
sam2_mask.py
CHANGED
|
@@ -111,6 +111,7 @@ def process_mask(mask, expand_contract_px, expand, feathering_enabled, feather_s
|
|
| 111 |
mask = feather_mask(mask, feather_size)
|
| 112 |
return mask
|
| 113 |
|
|
|
|
| 114 |
def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
|
| 115 |
image = Image.open(input_image)
|
| 116 |
image = np.array(image.convert("RGB"))
|
|
@@ -123,7 +124,7 @@ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label,
|
|
| 123 |
# sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
|
| 124 |
# Use CPU for both model and computations
|
| 125 |
# sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
|
| 126 |
-
predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="
|
| 127 |
|
| 128 |
# predictor = SAM2ImagePredictor(sam2_model)
|
| 129 |
predictor.set_image(image)
|
|
@@ -152,8 +153,7 @@ with gr.Blocks() as demo:
|
|
| 152 |
tracking_points = gr.State([])
|
| 153 |
trackings_input_label = gr.State([])
|
| 154 |
with gr.Column():
|
| 155 |
-
gr.Markdown("# SAM2 Image Predictor
|
| 156 |
-
gr.Markdown("This version runs entirely on CPU")
|
| 157 |
with gr.Row():
|
| 158 |
with gr.Column():
|
| 159 |
input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
|
|
|
|
| 111 |
mask = feather_mask(mask, feather_size)
|
| 112 |
return mask
|
| 113 |
|
| 114 |
+
@spaces.GPU()
|
| 115 |
def sam_process(input_image, checkpoint, tracking_points, trackings_input_label, expand_contract_px, expand, feathering_enabled, feather_size):
|
| 116 |
image = Image.open(input_image)
|
| 117 |
image = np.array(image.convert("RGB"))
|
|
|
|
| 124 |
# sam2_checkpoint, model_cfg = checkpoint_map[checkpoint]
|
| 125 |
# Use CPU for both model and computations
|
| 126 |
# sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
|
| 127 |
+
predictor = SAM2ImagePredictor.from_pretrained(sam21_hfmap[checkpoint], device="cuda")
|
| 128 |
|
| 129 |
# predictor = SAM2ImagePredictor(sam2_model)
|
| 130 |
predictor.set_image(image)
|
|
|
|
| 153 |
tracking_points = gr.State([])
|
| 154 |
trackings_input_label = gr.State([])
|
| 155 |
with gr.Column():
|
| 156 |
+
gr.Markdown("# SAM2 Image Predictor / Masking Assistant")
|
|
|
|
| 157 |
with gr.Row():
|
| 158 |
with gr.Column():
|
| 159 |
input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
|