Update app.py
Browse files
app.py
CHANGED
|
@@ -12,14 +12,13 @@ import gradio as gr
|
|
| 12 |
|
| 13 |
# image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
|
| 14 |
def get_masks(model_type, image):
|
| 15 |
-
if model_type == 'vit_h':
|
| 16 |
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
if model_type == 'vit_b':
|
| 20 |
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
| 21 |
|
| 22 |
-
if model_type == 'vit_l':
|
| 23 |
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
| 24 |
|
| 25 |
mask_generator = SamAutomaticMaskGenerator(sam)
|
|
|
|
| 12 |
|
| 13 |
# image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
|
| 14 |
def get_masks(model_type, image):
|
| 15 |
+
if model_type.all() == 'vit_h':
|
| 16 |
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
| 17 |
+
|
| 18 |
+
if model_type,all() == 'vit_b':
|
|
|
|
| 19 |
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
| 20 |
|
| 21 |
+
if model_type.all() == 'vit_l':
|
| 22 |
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
| 23 |
|
| 24 |
mask_generator = SamAutomaticMaskGenerator(sam)
|