Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from torchvision import transforms | |
| from task_adapter.utils.visualizer import Visualizer | |
| from typing import Tuple | |
| from PIL import Image | |
| from detectron2.data import MetadataCatalog | |
| metadata = MetadataCatalog.get('coco_2017_train_panoptic') | |
| class SemanticSAMPredictor: | |
| def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100): | |
| """ | |
| thresh: iou thresh to filter low confidence objects | |
| text_size: resize the input image short edge for the model to process | |
| hole_scale: fill in small holes as in SAM | |
| island_scale: remove small regions as in SAM | |
| """ | |
| self.model = model | |
| self.thresh = thresh | |
| self.text_size = hole_scale | |
| self.hole_scale = hole_scale | |
| self.island_scale = island_scale | |
| self.point = None | |
| def predict(self, image_ori, image, point=None): | |
| """ | |
| produce up to 6 prediction results for each click | |
| """ | |
| width = image_ori.shape[0] | |
| height = image_ori.shape[1] | |
| data = {"image": image, "height": height, "width": width} | |
| # import ipdb; ipdb.set_trace() | |
| if point is None: | |
| point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda() | |
| else: | |
| point = torch.tensor(point).cuda() | |
| point_ = point | |
| point = point_.clone() | |
| point[0, 0] = point_[0, 0] | |
| point[0, 1] = point_[0, 1] | |
| # point = point[:, [1, 0]] | |
| point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1) | |
| self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point)) | |
| data['targets'] = [dict()] | |
| data['targets'][0]['points'] = point | |
| data['targets'][0]['pb'] = point.new_tensor([0.]) | |
| batch_inputs = [data] | |
| masks, ious = self.model.model.evaluate_demo(batch_inputs) | |
| return masks, ious | |
| def process_multi_mask(self, masks, ious, image_ori): | |
| pred_masks_poses = masks | |
| reses = [] | |
| ious = ious[0, 0] | |
| ids = torch.argsort(ious, descending=True) | |
| text_res = '' | |
| mask_ls = [] | |
| ious_res = [] | |
| areas = [] | |
| for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])): | |
| iou = round(float(iou), 2) | |
| texts = f'{iou}' | |
| mask = (pred_masks_pos > 0.0).cpu().numpy() | |
| area = mask.sum() | |
| conti = False | |
| if iou < self.thresh: | |
| conti = True | |
| for m in mask_ls: | |
| if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95: | |
| conti = True | |
| break | |
| if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []: | |
| conti = False | |
| if conti: | |
| continue | |
| ious_res.append(iou) | |
| mask_ls.append(mask) | |
| areas.append(area) | |
| mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes") | |
| mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands") | |
| mask = (mask).astype(np.float) | |
| out_txt = texts | |
| visual = Visualizer(image_ori, metadata=metadata) | |
| color = [0., 0., 1.0] | |
| demo = visual.draw_binary_mask(mask, color=color, text=texts) | |
| res = demo.get_image() | |
| point_x0 = max(0, int(self.point[0, 0]) - 3) | |
| point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3) | |
| point_y0 = max(0, int(self.point[0, 1]) - 3) | |
| point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3) | |
| res[point_y0:point_y1, point_x0:point_x1, 0] = 255 | |
| res[point_y0:point_y1, point_x0:point_x1, 1] = 0 | |
| res[point_y0:point_y1, point_x0:point_x1, 2] = 0 | |
| reses.append(Image.fromarray(res)) | |
| text_res = text_res + ';' + out_txt | |
| ids = list(torch.argsort(torch.tensor(areas), descending=False)) | |
| ids = [int(i) for i in ids] | |
| torch.cuda.empty_cache() | |
| return reses, [reses[i] for i in ids] | |
| def predict_masks(self, image_ori, image, point=None): | |
| masks, ious = self.predict(image_ori, image, point) | |
| return self.process_multi_mask(masks, ious, image_ori) | |
| def remove_small_regions( | |
| mask: np.ndarray, area_thresh: float, mode: str | |
| ) -> Tuple[np.ndarray, bool]: | |
| """ | |
| Removes small disconnected regions and holes in a mask. Returns the | |
| mask and an indicator of if the mask has been modified. | |
| """ | |
| import cv2 # type: ignore | |
| assert mode in ["holes", "islands"] | |
| correct_holes = mode == "holes" | |
| working_mask = (correct_holes ^ mask).astype(np.uint8) | |
| n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) | |
| sizes = stats[:, -1][1:] # Row 0 is background label | |
| small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] | |
| if len(small_regions) == 0: | |
| return mask, False | |
| fill_labels = [0] + small_regions | |
| if not correct_holes: | |
| fill_labels = [i for i in range(n_labels) if i not in fill_labels] | |
| # If every region is below threshold, keep largest | |
| if len(fill_labels) == 0: | |
| fill_labels = [int(np.argmax(sizes)) + 1] | |
| mask = np.isin(regions, fill_labels) | |
| return mask, True | |