Spaces:
Runtime error
Runtime error
Update mask_adapter/sam_maskadapter.py
Browse files- mask_adapter/sam_maskadapter.py +19 -12
mask_adapter/sam_maskadapter.py
CHANGED
|
@@ -227,18 +227,18 @@ class SAMPointVisualizationDemo(object):
|
|
| 227 |
self.mask_adapter = mask_adapter
|
| 228 |
|
| 229 |
|
| 230 |
-
from .data.datasets import openseg_classes
|
| 231 |
|
| 232 |
-
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
| 233 |
#COCO_CATEGORIES_seg = openseg_classes.get_coco_stuff_categories_with_prompt_eng()
|
| 234 |
|
| 235 |
-
thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
| 236 |
-
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
| 237 |
#print(coco_metadata)
|
| 238 |
-
lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
|
| 239 |
-
lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
|
| 240 |
|
| 241 |
-
self.class_names = thing_classes + stuff_classes + lvis_classes
|
| 242 |
#self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy"))
|
| 243 |
|
| 244 |
self.class_names = self._load_class_names()
|
|
@@ -248,9 +248,11 @@ class SAMPointVisualizationDemo(object):
|
|
| 248 |
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
| 249 |
thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
| 250 |
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
def extract_features_convnext(self, x):
|
|
@@ -280,7 +282,9 @@ class SAMPointVisualizationDemo(object):
|
|
| 280 |
|
| 281 |
return clip_vis_dense
|
| 282 |
|
| 283 |
-
def run_on_image_with_points(self, ori_image, points,text_features):
|
|
|
|
|
|
|
| 284 |
height, width, _ = ori_image.shape
|
| 285 |
|
| 286 |
image = ori_image
|
|
@@ -357,7 +361,10 @@ class SAMPointVisualizationDemo(object):
|
|
| 357 |
|
| 358 |
return None, Image.fromarray(overlay)
|
| 359 |
|
| 360 |
-
def run_on_image_with_boxes(self, ori_image, bbox,text_features):
|
|
|
|
|
|
|
|
|
|
| 361 |
height, width, _ = ori_image.shape
|
| 362 |
|
| 363 |
image = ori_image
|
|
|
|
| 227 |
self.mask_adapter = mask_adapter
|
| 228 |
|
| 229 |
|
| 230 |
+
#from .data.datasets import openseg_classes
|
| 231 |
|
| 232 |
+
#COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
| 233 |
#COCO_CATEGORIES_seg = openseg_classes.get_coco_stuff_categories_with_prompt_eng()
|
| 234 |
|
| 235 |
+
#thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
| 236 |
+
#stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
| 237 |
#print(coco_metadata)
|
| 238 |
+
#lvis_classes = open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r').read().splitlines()
|
| 239 |
+
#lvis_classes = [x[x.find(':')+1:] for x in lvis_classes]
|
| 240 |
|
| 241 |
+
#self.class_names = thing_classes + stuff_classes + lvis_classes
|
| 242 |
#self.text_embedding = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy"))
|
| 243 |
|
| 244 |
self.class_names = self._load_class_names()
|
|
|
|
| 248 |
COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
|
| 249 |
thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
|
| 250 |
stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
|
| 251 |
+
ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
|
| 252 |
+
ade20k_thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES_ if k["isthing"] == 1]
|
| 253 |
+
ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]
|
| 254 |
+
class_names = thing_classes + stuff_classes + ade20k_thing_classes+ ade20k_stuff_classes
|
| 255 |
+
return [ class_name for class_name in class_names ]
|
| 256 |
|
| 257 |
|
| 258 |
def extract_features_convnext(self, x):
|
|
|
|
| 282 |
|
| 283 |
return clip_vis_dense
|
| 284 |
|
| 285 |
+
def run_on_image_with_points(self, ori_image, points,text_features,class_names=None):
|
| 286 |
+
if class_names != None:
|
| 287 |
+
self.class_names = class_names
|
| 288 |
height, width, _ = ori_image.shape
|
| 289 |
|
| 290 |
image = ori_image
|
|
|
|
| 361 |
|
| 362 |
return None, Image.fromarray(overlay)
|
| 363 |
|
| 364 |
+
def run_on_image_with_boxes(self, ori_image, bbox,text_features,class_names=None):
|
| 365 |
+
if class_names != None:
|
| 366 |
+
self.class_names = class_names
|
| 367 |
+
|
| 368 |
height, width, _ = ori_image.shape
|
| 369 |
|
| 370 |
image = ori_image
|