import json
import os
import random
import sys
from types import MethodType

import cv2
import matplotlib.colors as mplc
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from distinctipy import distinctipy
from PIL import Image
from pycocotools import mask as mask_util

from detectron2.data import MetadataCatalog
from detectron2.data.detection_utils import read_image
from detectron2.structures import BitMasks, PolygonMasks
from detectron2.utils.visualizer import ColorMode, GenericMask, Visualizer
from fvcore.common.timer import Timer

from third_parts.APE.build_ape import build_ape_predictor
from third_parts.recognize_anything.build_ram_plus import build_ram_predictor
from third_parts.segment_anything import (SamAutomaticMaskGenerator,
                                          SamPredictor, build_sam_vit_h)


def show_mask(mask, ax, random_color=False):
    """Overlay a single binary mask on a matplotlib axis."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) grid from `target_ratios` whose aspect ratio best matches the box."""
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Break ties in favor of the finer grid when the box is large enough.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def sample_points(box, mask, min_points=3, max_points=16):
    """Sample foreground point prompts inside `box` that land on `mask`."""
    x0, y0, w, h = box
    aspect_ratio = w / h  # aspect ratio of the box

    # Enumerate candidate (cols, rows) grids with min_points <= cols * rows <= max_points.
    target_ratios = set(
        (i, j) for n in range(min_points, max_points + 1)
        for i in range(1, n + 1) for j in range(1, n + 1)
        if min_points <= i * j <= max_points)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Find the grid whose aspect ratio is closest to the box's.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, w, h, 50)

    # Take the bin centers of the grid that fall on the mask.
    width_bin = w / target_aspect_ratio[0]
    height_bin = h / target_aspect_ratio[1]
    ret_points = []
    for wi in range(target_aspect_ratio[0]):
        xi = x0 + (wi + 0.5) * width_bin
        for hi in range(target_aspect_ratio[1]):
            yi = y0 + (hi + 0.5) * height_bin
            if mask[int(yi), int(xi)] > 0:
                ret_points.append((xi, yi))

    # Additionally subsample foreground pixels uniformly across the box.
    # if len(ret_points) < min_points:
    temp_points = []
    for wi in range(int(x0), int(x0 + w)):
        for hi in range(int(y0), int(y0 + h)):
            if mask[int(hi), int(wi)] > 0:
                temp_points.append((wi, hi))
    if len(temp_points) // max_points < 1:
        uniform_indices = list(range(0, len(temp_points)))
    else:
        uniform_indices = list(range(0, len(temp_points), len(temp_points) // max_points))
    additional_points = [temp_points[uniform_idx] for uniform_idx in uniform_indices[1:-1]]
    # ret_points = [temp_points[uniform_indices[1]], temp_points[uniform_indices[2]], temp_points[uniform_indices[3]]]
    ret_points = ret_points + additional_points
    return ret_points


def mask_iou(masks, chunk_size=50, chunk_mode=False):
    """Pairwise intersection and union between all masks in `masks` (n, h, w)."""
    masks1 = masks.unsqueeze(1).char()  # (n, 1, h, w)
    masks2 = masks.unsqueeze(0).char()  # (1, n, h, w)
    if not chunk_mode:
        intersection = (masks1 * masks2)
        union = (masks1 + masks2 - intersection).sum(-1).sum(-1)
        intersection = intersection.sum(-1).sum(-1)
        return intersection, union

    def chunk_mask_iou(_chunk_size=50):
        num_chunks = masks1.shape[0] // _chunk_size
        if masks1.shape[0] % _chunk_size > 0:
            num_chunks += 1
        row_chunks_intersection, row_chunks_union = [], []
        for row_idx in range(num_chunks):
            col_chunks_intersection, col_chunks_union = [], []
            masks1_chunk = masks1[row_idx * _chunk_size:(row_idx + 1) * _chunk_size]
            for col_idx in range(num_chunks):
                masks2_chunk = masks2[:, col_idx * _chunk_size:(col_idx + 1) * _chunk_size]
                try:
                    intersection = masks1_chunk * masks2_chunk
                    temp_sum = masks1_chunk + masks2_chunk
                    union = (temp_sum - intersection).sum(-1).sum(-1)
                    intersection = intersection.sum(-1).sum(-1)
                except torch.cuda.OutOfMemoryError:
                    return False, None, None
                col_chunks_intersection.append(intersection)
                col_chunks_union.append(union)
            row_chunks_intersection.append(torch.cat(col_chunks_intersection, dim=1))
            row_chunks_union.append(torch.cat(col_chunks_union, dim=1))
        intersection = torch.cat(row_chunks_intersection, dim=0)
        union = torch.cat(row_chunks_union, dim=0)
        return True, intersection, union

    # Retry with smaller chunks whenever a chunk size runs out of GPU memory.
    for c_size in [chunk_size, chunk_size // 2, chunk_size // 4]:
        is_ok, intersection, union = chunk_mask_iou(c_size)
        if not is_ok:
            continue
        return intersection, union
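
# Minimal usage sketch for the chunked IoU helper above (shapes are made up
# for illustration; `mask_iou` returns intersection and union, not the ratio):
#
#   masks = torch.randint(0, 2, (4, 64, 64), dtype=torch.uint8)
#   inter, union = mask_iou(masks, chunk_size=2, chunk_mode=True)
#   iou = inter / union  # (4, 4) pairwise IoU matrix, diagonal == 1
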
def mask_iou_v2(masks1, masks2, chunk_size=50, chunk_mode=False):
    """Pairwise intersection and union between two mask sets (n, h, w) and (m, h, w)."""
    masks1 = masks1.unsqueeze(1).char()  # (n, 1, h, w)
    masks2 = masks2.unsqueeze(0).char()  # (1, m, h, w)
    if not chunk_mode:
        intersection = (masks1 * masks2)
        union = (masks1 + masks2 - intersection).sum(-1).sum(-1)
        intersection = intersection.sum(-1).sum(-1)
        return intersection, union

    def chunk_mask_iou(_chunk_size=50):
        num_chunks1 = masks1.shape[0] // _chunk_size
        if masks1.shape[0] % _chunk_size > 0:
            num_chunks1 += 1
        num_chunks2 = masks2.shape[1] // _chunk_size
        if masks2.shape[1] % _chunk_size > 0:
            num_chunks2 += 1
        row_chunks_intersection, row_chunks_union = [], []
        for row_idx in range(num_chunks1):
            col_chunks_intersection, col_chunks_union = [], []
            masks1_chunk = masks1[row_idx * _chunk_size:(row_idx + 1) * _chunk_size]
            for col_idx in range(num_chunks2):
                masks2_chunk = masks2[:, col_idx * _chunk_size:(col_idx + 1) * _chunk_size]
                try:
                    intersection = masks1_chunk * masks2_chunk
                    temp_sum = masks1_chunk + masks2_chunk
                    union = (temp_sum - intersection).sum(-1).sum(-1)
                    intersection = intersection.sum(-1).sum(-1)
                except torch.cuda.OutOfMemoryError:
                    return False, None, None
                col_chunks_intersection.append(intersection)
                col_chunks_union.append(union)
            row_chunks_intersection.append(torch.cat(col_chunks_intersection, dim=1))
            row_chunks_union.append(torch.cat(col_chunks_union, dim=1))
        intersection = torch.cat(row_chunks_intersection, dim=0)
        union = torch.cat(row_chunks_union, dim=0)
        return True, intersection, union

    # Retry with smaller chunks whenever a chunk size runs out of GPU memory.
    for c_size in [chunk_size, chunk_size // 2, chunk_size // 4]:
        is_ok, intersection, union = chunk_mask_iou(c_size)
        if not is_ok:
            continue
        return intersection, union
    return intersection, union  # fallback if every chunk size ran out of memory


def mask_area(masks, chunk_size=50, chunk_mode=False):
    """Per-mask pixel area, optionally computed in chunks to bound memory."""
    if not chunk_mode:
        return masks.sum(-1).sum(-1)
    num_chunks = masks.shape[0] // chunk_size
    if masks.shape[0] % chunk_size > 0:
        num_chunks += 1
    areas = []
    for i in range(num_chunks):
        masks_i = masks[i * chunk_size:(i + 1) * chunk_size]
        areas.append(masks_i.sum(-1).sum(-1))
    return torch.cat(areas, dim=0)


def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
    """
    Draw instance-level prediction results on an image.

    Args:
        labels (list[str]): one text label per mask.
        np_masks (np.ndarray): binary masks of shape (n, h, w).
        jittering: if True, in color mode SEGMENTATION, randomly jitter the
            colors per class to distinguish instances from the same class.

    Returns:
        output (VisImage): image object with visualizations.
    """
    boxes = None
    scores = None
    classes = None  # unused; the callers below run with ColorMode.IMAGE
    keypoints = None
    masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]

    if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
        colors = (
            [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
            if jittering
            else [
                tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
                for c in classes
            ]
        )
        alpha = 0.8
    else:
        colors = None
        alpha = 0.5

    self.overlay_instances(
        masks=masks,
        boxes=boxes,
        labels=labels,
        keypoints=keypoints,
        assigned_colors=colors,
        alpha=alpha,
    )
    return self.output
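
# `draw_instance_predictions_cache` is bound onto a detectron2 Visualizer at
# the call sites below via types.MethodType, roughly:
#
#   visualizer = Visualizer(image_rgb, metadata, instance_mode=ColorMode.IMAGE)
#   visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
#   vis_output = visualizer.draw_instance_predictions(labels=tags, np_masks=masks)
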
def merge_sa1b_image(image_file, anno_file, save_path, generated_annos, visualize=False):
    file_name = os.path.basename(image_file).split('.')[0]
    if anno_file is not None:
        with open(anno_file, 'r') as f:
            json_results = json.load(f)
        generated_annos = json_results["annotations"]
    assert generated_annos is not None, \
        "Provide a valid annotation file or `generated_annos` from the SAM automatic mask generator."

    _all_sam_masks, predicted_iou_scores = [], []
    for object_anno in generated_annos:
        object_mask = object_anno["segmentation"]
        if isinstance(object_mask["counts"], list):
            # Uncompressed RLE: convert to compressed RLE before decoding.
            object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1])
        mask = mask_util.decode(object_mask)
        mask = mask.astype(np.uint8).squeeze()
        _all_sam_masks.append(torch.from_numpy(mask))
        predicted_iou_scores.append(object_anno['predicted_iou'])

    # Sort the masks by predicted IoU score, from high to low.
    sorted_idx = sorted(range(len(predicted_iou_scores)), key=lambda k: predicted_iou_scores[k], reverse=True)
    all_sam_masks = []
    for idx in sorted_idx:
        all_sam_masks.append(_all_sam_masks[idx])
    all_sam_masks = torch.stack(all_sam_masks)

    # Compute pairwise IoU on 4x-downsampled masks to save GPU memory.
    ori_height, ori_width = all_sam_masks.shape[-2:]
    downsampled_sam_masks = torch.nn.functional.interpolate(
        all_sam_masks[None].to(torch.float32),
        size=(ori_height // 4, ori_width // 4), mode="bilinear")
    downsampled_sam_masks = (downsampled_sam_masks[0] > 0.5).to(all_sam_masks.dtype).to("cuda")
    intersection, union = mask_iou(downsampled_sam_masks, chunk_size=100, chunk_mode=True)
    mask_iou_matrix = intersection / union

    # NMS: drop lower-scoring masks that overlap a kept mask with IoU > 0.8.
    num_instances = len(mask_iou_matrix)
    keep = [True] * num_instances
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(ins_i, num_instances):
            if ins_j == ins_i:
                continue
            if mask_iou_matrix[ins_i, ins_j] > 0.8:
                keep[ins_j] = False

    # Merge: drop a mask if more than 80% of its area lies inside another kept mask.
    # area = downsampled_sam_masks.sum(-1).sum(-1)
    area = mask_area(downsampled_sam_masks, chunk_mode=True, chunk_size=100)
    roc = intersection / area[:, None]
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(num_instances):
            if ins_i == ins_j:
                continue
            if not keep[ins_j]:
                continue
            if roc[ins_i, ins_j] > 0.8:
                keep[ins_i] = False
                break

    left_masks = [all_sam_masks[ins_i] for ins_i in range(len(keep)) if keep[ins_i]]
    left_tags = ['object' for _ in range(len(left_masks))]
    unique_tags = list(set(left_tags))
    text_prompt = ','.join(unique_tags)
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags
    if not visualize:
        return torch.stack(left_masks)
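
# run_on_image_v2 fuses the class-agnostic SAM masks (from SA-1B annotations or
# the automatic generator) with APE's open-vocabulary masks: SAM masks that
# agree with an APE mask are kept, SAM masks contained in an APE mask are
# re-prompted through SAM, APE masks with no SAM counterpart are added back
# (after SAM verification), and the union goes through a second round of
# NMS + containment merging.
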
def run_on_image_v2(image_file, anno_file, save_path, ram_predictor, ape_predictor,
                    sam_predictor, sam_auto_mask_generator, visualize=False):
    if not os.path.exists(image_file):
        return None
    file_name = os.path.basename(image_file).split('.')[0]

    if (anno_file is None) or (not os.path.exists(anno_file)):
        # No SA-1B annotation available: generate masks with SAM's automatic generator.
        image = cv2.imread(image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        generated_annos = sam_auto_mask_generator.generate(image)
        sam_masks = merge_sa1b_image(image_file, None, save_path, generated_annos, visualize=False)
    else:
        sam_masks = merge_sa1b_image(image_file, anno_file, save_path, None, visualize=False)

    ape_masks, ape_tags = run_on_image(image_file, save_path, ram_predictor, ape_predictor,
                                       sam_predictor, visualize=False)
    if ape_masks is None:
        return None

    sam_image = cv2.imread(image_file)
    ori_height, ori_width = sam_image.shape[:2]
    sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB)
    sam_predictor.set_image(sam_image)  # has already been set inside `run_on_image`

    ori_height, ori_width = sam_masks.shape[-2:]
    downsampled_sam_masks = torch.nn.functional.interpolate(
        sam_masks[None].to(torch.float32),
        size=(ori_height // 4, ori_width // 4), mode="bilinear")
    downsampled_sam_masks = (downsampled_sam_masks[0] > 0.5).to(sam_masks.dtype).to("cuda")
    downsampled_ape_masks = torch.nn.functional.interpolate(
        ape_masks[None].to(torch.float32),
        size=(ori_height // 4, ori_width // 4), mode="bilinear")
    downsampled_ape_masks = (downsampled_ape_masks[0] > 0.5).to(ape_masks.dtype).to("cuda")

    sam_ape_masks_intersection, sam_ape_masks_union = mask_iou_v2(
        downsampled_sam_masks, downsampled_ape_masks, chunk_size=100, chunk_mode=True)
    sam_ape_masks_iou = sam_ape_masks_intersection / sam_ape_masks_union
    # sam_area = downsampled_sam_masks.sum(-1).sum(-1)
    sam_area = mask_area(downsampled_sam_masks, chunk_mode=True, chunk_size=100)
    sam_masks_roc = sam_ape_masks_intersection / sam_area[:, None]

    sam_boxes = torchvision.ops.masks_to_boxes(sam_masks)
    ape_boxes = torchvision.ops.masks_to_boxes(ape_masks)

    first_round_masks = []
    iou_target_indices = torch.argmax(sam_ape_masks_iou, dim=1)
    roc_target_indices = torch.argmax(sam_masks_roc, dim=1)
    for sam_idx in range(downsampled_sam_masks.shape[0]):
        iou_tgt_idx = iou_target_indices[sam_idx]
        roc_tgt_idx = roc_target_indices[sam_idx]
        if sam_ape_masks_iou[sam_idx, iou_tgt_idx] > 0.8:
            # Case 1: the SAM mask closely matches some APE mask; keep it as-is.
            first_round_masks.append(sam_masks[sam_idx])
        elif sam_masks_roc[sam_idx, roc_tgt_idx] > 0.8:
            # Case 2: the SAM mask sits inside an APE mask; re-prompt SAM with
            # points sampled from the SAM mask and keep the prediction that
            # best matches the enclosing APE mask.
            box_x1, box_y1, box_x2, box_y2 = sam_boxes[sam_idx]
            box_w = box_x2 - box_x1
            box_h = box_y2 - box_y1
            ret_points = sample_points([box_x1, box_y1, box_w, box_h], sam_masks[sam_idx],
                                       min_points=1, max_points=3)
            if len(ret_points) == 0:
                first_round_masks.append(sam_masks[sam_idx])
            else:
                point_labels = [1 for _ in range(len(ret_points))]
                temp_masks, scores, _ = sam_predictor.predict(
                    point_coords=np.array(ret_points),
                    point_labels=np.array(point_labels),
                    multimask_output=True,
                )
                # SAM returns boolean masks; unify on uint8 so the later
                # torch.stack over the fused list sees a single dtype.
                temp_masks = torch.from_numpy(temp_masks).to(torch.uint8)
                downsampled_temp_masks = torch.nn.functional.interpolate(
                    temp_masks[None].to(torch.float32),
                    size=(ori_height // 4, ori_width // 4), mode="bilinear")
                downsampled_temp_masks = (downsampled_temp_masks[0] > 0.5).to(temp_masks.dtype).to("cuda")
                downsampled_ape_mask = downsampled_ape_masks[roc_tgt_idx][None]
                ape_temp_masks_intersection, ape_temp_masks_union = mask_iou_v2(
                    downsampled_ape_mask, downsampled_temp_masks)
                ape_temp_masks_iou = ape_temp_masks_intersection / ape_temp_masks_union
                iou_temp_indices = torch.argmax(ape_temp_masks_iou, dim=1)
                iou_temp_idx = iou_temp_indices[0]
                if ape_temp_masks_iou[0, iou_temp_idx] > 0.8 and scores[iou_temp_idx] > 0.9:
                    first_round_masks.append(temp_masks[iou_temp_idx])
                else:
                    first_round_masks.append(sam_masks[sam_idx])
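        # Case 3: the SAM mask neither matches nor sits inside any APE mask;
        # re-prompt SAM from its own foreground points and keep the largest
        # high-confidence prediction, falling back to the original mask.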
        else:
            # first_round_masks.append(sam_masks[sam_idx])
            box_x1, box_y1, box_x2, box_y2 = sam_boxes[sam_idx]
            box_w = box_x2 - box_x1
            box_h = box_y2 - box_y1
            ret_points = sample_points([box_x1, box_y1, box_w, box_h], sam_masks[sam_idx],
                                       min_points=1, max_points=3)
            if len(ret_points) == 0:
                first_round_masks.append(sam_masks[sam_idx])
            else:
                point_labels = [1 for _ in range(len(ret_points))]
                temp_masks, scores, _ = sam_predictor.predict(
                    point_coords=np.array(ret_points),
                    point_labels=np.array(point_labels),
                    multimask_output=True,
                )
                temp_masks = torch.from_numpy(temp_masks).to(torch.uint8)
                temp_masks_area = temp_masks.sum(-1).sum(-1)
                tgt_idx = torch.argmax(temp_masks_area)
                if scores[tgt_idx] > 0.9:
                    first_round_masks.append(temp_masks[tgt_idx])
                else:
                    first_round_masks.append(sam_masks[sam_idx])

    # Add back APE masks that have no SAM counterpart.
    ape_sam_masks_intersection, ape_sam_masks_union = (
        sam_ape_masks_intersection.transpose(0, 1), sam_ape_masks_union.transpose(0, 1))
    # ape_area = ape_masks.sum(-1).sum(-1)
    ape_area = mask_area(downsampled_ape_masks, chunk_mode=True, chunk_size=100)
    ape_masks_roc = ape_sam_masks_intersection / ape_area[:, None]
    roc_target_indices = torch.argmax(ape_masks_roc, dim=1)
    for ape_idx in range(ape_masks.shape[0]):
        roc_tgt_idx = roc_target_indices[ape_idx]
        if ape_masks_roc[ape_idx, roc_tgt_idx] < 0.2:
            if sam_masks_roc[roc_tgt_idx, ape_idx] < 0.2:
                # The APE mask is essentially disjoint from every SAM mask:
                # re-prompt SAM with points sampled from the APE mask.
                box_x1, box_y1, box_x2, box_y2 = ape_boxes[ape_idx]
                box_w = box_x2 - box_x1
                box_h = box_y2 - box_y1
                ret_points = sample_points([box_x1, box_y1, box_w, box_h], ape_masks[ape_idx],
                                           min_points=3, max_points=16)
                if len(ret_points) == 0:
                    first_round_masks.append(ape_masks[ape_idx])
                else:
                    point_labels = [1 for _ in range(len(ret_points))]
                    temp_masks, scores, _ = sam_predictor.predict(
                        point_coords=np.array(ret_points),
                        point_labels=np.array(point_labels),
                        multimask_output=False,
                    )
                    temp_masks = torch.from_numpy(temp_masks).to(torch.uint8)
                    if scores[0] > 0.9:
                        first_round_masks.append(temp_masks[0])
                    else:
                        first_round_masks.append(ape_masks[ape_idx])
        else:
            # Some SAM masks lie inside this APE mask, but they are not
            # object-level: probe single points and keep the APE mask if SAM
            # can reproduce it from any of them.
            box_x1, box_y1, box_x2, box_y2 = ape_boxes[ape_idx]
            box_w = box_x2 - box_x1
            box_h = box_y2 - box_y1
            ret_points = sample_points([box_x1, box_y1, box_w, box_h], ape_masks[ape_idx],
                                       min_points=3, max_points=8)
            for point in ret_points:
                temp_masks, scores, _ = sam_predictor.predict(
                    point_coords=np.array([point]),
                    point_labels=np.array([1]),
                    multimask_output=True,
                )
                temp_masks = torch.from_numpy(temp_masks).to(torch.uint8)
                downsampled_temp_masks = torch.nn.functional.interpolate(
                    temp_masks[None].to(torch.float32),
                    size=(ori_height // 4, ori_width // 4), mode="bilinear")
                downsampled_temp_masks = (downsampled_temp_masks[0] > 0.5).to(temp_masks.dtype).to("cuda")
                downsampled_ape_mask = downsampled_ape_masks[ape_idx][None]
                ape_temp_masks_intersection, ape_temp_masks_union = mask_iou_v2(
                    downsampled_ape_mask, downsampled_temp_masks)
                ape_temp_masks_iou = ape_temp_masks_intersection / ape_temp_masks_union
                iou_temp_indices = torch.argmax(ape_temp_masks_iou, dim=1)
                iou_temp_idx = iou_temp_indices[0]
                if ape_temp_masks_iou[0, iou_temp_idx] > 0.8:
                    first_round_masks.append(ape_masks[ape_idx])
                    break  # one confirming point is enough; avoid duplicate appends

    # Sort the fused masks by area, from large to small, for the final dedup.
    # first_round_scores = [mask.sum(-1).sum(-1) for mask in first_round_masks]
    first_round_scores = mask_area(torch.stack(first_round_masks), chunk_mode=True, chunk_size=100)
    sorted_idx = sorted(range(len(first_round_masks)), key=lambda k: first_round_scores[k], reverse=True)
    sorted_first_round_masks = []
    for idx in sorted_idx:
        sorted_first_round_masks.append(first_round_masks[idx])
    sorted_first_round_masks = torch.stack(sorted_first_round_masks)
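
    # Second round of deduplication over the fused mask set: NMS at IoU > 0.8,
    # then drop any mask with more than half of its area inside another kept one.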
    downsampled_first_round_masks = torch.nn.functional.interpolate(
        sorted_first_round_masks[None].to(torch.float32),
        size=(ori_height // 4, ori_width // 4), mode="bilinear")
    downsampled_first_round_masks = (downsampled_first_round_masks[0] > 0.5).to(sorted_first_round_masks.dtype)
    intersection, union = mask_iou(downsampled_first_round_masks, chunk_mode=True, chunk_size=100)
    mask_iou_matrix = intersection / union

    # nms
    num_instances = len(mask_iou_matrix)
    keep = [True] * num_instances
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(ins_i, num_instances):
            if ins_j == ins_i:
                continue
            if mask_iou_matrix[ins_i, ins_j] > 0.8:
                keep[ins_j] = False

    # merge
    # area = downsampled_first_round_masks.sum(-1).sum(-1)
    area = mask_area(downsampled_first_round_masks, chunk_mode=True, chunk_size=100)
    roc = intersection / area[:, None]
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(num_instances):
            if ins_i == ins_j:
                continue
            if not keep[ins_j]:
                continue
            if roc[ins_i, ins_j] > 0.5:
                keep[ins_i] = False
                break

    left_masks = [sorted_first_round_masks[ins_i] for ins_i in range(len(keep)) if keep[ins_i]]

    if visualize:
        left_tags = ['object' for _ in range(len(left_masks))]
        unique_tags = list(set(left_tags))
        text_prompt = ','.join(unique_tags)
        metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
        metadata.thing_classes = unique_tags
        metadata.stuff_classes = unique_tags
        result_masks = torch.stack(left_masks).cpu().numpy()
        input_image = read_image(image_file, format="BGR")
        visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
        visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
        vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks)
        output_image = vis_output.get_image()
        output_image = Image.fromarray(output_image)
        final_out_path = "./work_dirs/visualize_object_level"
        if not os.path.exists(final_out_path):
            os.makedirs(final_out_path)
        output_image.save(os.path.join(final_out_path, file_name + '.jpg'))
    else:
        result_masks = torch.stack(left_masks).cpu().numpy()
        save_json_results = []
        for ins_i, mask in enumerate(result_masks):
            rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            save_json_results.append({
                "ins_id": ins_i,
                "segmentation": rle,
            })
        with open(os.path.join(save_path, file_name + '.json'), 'w') as f:
            json.dump(save_json_results, f)
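
# run_on_image produces the open-vocabulary masks: RAM++ proposes image tags,
# APE grounds them as instance masks, SAM refines each mask from sampled
# points, duplicates are removed, and each surviving instance is re-examined
# on a zoomed-in crop to recover finer labels.
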
def run_on_image(image_file, save_path, ram_predictor, ape_predictor, sam_predictor, visualize=False):
    # Tag the image with RAM++, then prompt APE with the union of tags.
    res = ram_predictor.run_on_image(image_file_path=image_file, dynamic_resolution=True)
    tag_list = []
    for tag_string in res[0]:
        tags = tag_string.split(' | ')
        tag_list += tags
    tags = list(set(tag_list))
    text_prompt = ','.join(tags)

    output_image, json_results = ape_predictor.run_on_image(
        image_file,
        input_text=text_prompt,
        visualize=True,
        score_threhold=0.1,  # keyword spelling as defined by the APE predictor
        output_type=["instance segmentation"],
    )
    if visualize:
        file_name = os.path.basename(image_file).split('.')[0]
        raw_ape_out_path = os.path.join(save_path, 'raw_ape_out_0116')
        if not os.path.exists(raw_ape_out_path):
            os.makedirs(raw_ape_out_path)
        output_image.save(os.path.join(raw_ape_out_path, file_name + '.jpg'))

    # Refine every APE mask with SAM, prompted by points sampled on the mask.
    # colors = distinctipy.get_colors(len(json_results)+1)
    sam_image = cv2.imread(image_file)
    ori_height, ori_width = sam_image.shape[:2]
    sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB)
    sam_predictor.set_image(sam_image)

    new_masks_from_sam = []
    corresponding_tags = []
    corresponding_scores = []  # the scores have already been sorted inside APE
    for idx, item in enumerate(json_results):
        object_mask = item["segmentation"]
        if isinstance(object_mask["counts"], list):
            object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1])
        mask = mask_util.decode(object_mask)
        mask = mask.astype(np.uint8).squeeze()
        box = item["bbox"]
        ret_points = sample_points(box, mask)
        if len(ret_points) == 0:
            continue
        # Rescale the sampled points from mask resolution to image resolution.
        mask_h, mask_w = object_mask["size"]
        input_point, input_label = [], []
        for point in ret_points:
            _x = point[0] / mask_w * ori_width
            _y = point[1] / mask_h * ori_height
            input_point.append([int(_x), int(_y)])
            input_label.append(1)
        masks, scores, logits = sam_predictor.predict(
            point_coords=np.array(input_point),
            point_labels=np.array(input_label),
            multimask_output=False,
        )
        # SAM returns boolean masks; unify on uint8 for the later torch.stack.
        new_masks_from_sam.append(torch.from_numpy(masks).to(torch.uint8))
        corresponding_tags.append(item["category_name"])
        corresponding_scores.append(item["score"])

    if len(new_masks_from_sam) == 0:
        return None, None
    new_masks_from_sam = torch.cat(new_masks_from_sam)
    downsampled_new_masks_from_sam = torch.nn.functional.interpolate(
        new_masks_from_sam[None].to(torch.float32),
        size=(ori_height // 4, ori_width // 4), mode="bilinear")
    downsampled_new_masks_from_sam = (downsampled_new_masks_from_sam[0] > 0.5).to(new_masks_from_sam.dtype).to("cuda")
    intersection, union = mask_iou(downsampled_new_masks_from_sam, chunk_mode=True, chunk_size=100)
    mask_iou_matrix = intersection / union

    # nms
    num_instances = len(mask_iou_matrix)
    keep = [True] * num_instances
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(ins_i, num_instances):
            if ins_j == ins_i:
                continue
            if mask_iou_matrix[ins_i, ins_j] > 0.8:
                keep[ins_j] = False

    # merge
    # area = downsampled_new_masks_from_sam.sum(-1).sum(-1)
    area = mask_area(downsampled_new_masks_from_sam, chunk_mode=True, chunk_size=100)
    roc = intersection / area[:, None]
    for ins_i in range(num_instances):
        if not keep[ins_i]:
            continue
        for ins_j in range(num_instances):
            if ins_i == ins_j:
                continue
            if not keep[ins_j]:
                continue
            if roc[ins_i, ins_j] > 0.8:
                keep[ins_i] = False
                break

    left_masks = [new_masks_from_sam[ins_i] for ins_i in range(len(keep)) if keep[ins_i]]
    left_masks = torch.stack(left_masks)
    left_boxes = torchvision.ops.masks_to_boxes(left_masks)
    left_tags = [corresponding_tags[ins_i] for ins_i in range(len(keep)) if keep[ins_i]]
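
    # Zoom-in refinement: for each kept instance, crop a loosened box around it
    # (25% margin per side, padded to at least 256 px and squared, clipped to
    # the image), re-run RAM++ and APE on the crop, and replace or split the
    # original mask when the crop-level masks agree with it.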
    result_mask_list = []
    result_tag_list = []
    ori_image = Image.open(image_file)
    for ins_i, ins_box in enumerate(left_boxes):
        ins_box = ins_box.numpy().tolist()
        box_w = ins_box[2] - ins_box[0]
        box_h = ins_box[3] - ins_box[1]
        # Loosen the box by a quarter of its size on each side, clipped to the image.
        loose_box_x0 = int(ins_box[0] - box_w // 4)
        loose_box_y0 = int(ins_box[1] - box_h // 4)
        loose_box_x1 = int(ins_box[2] + box_w // 4)
        loose_box_y1 = int(ins_box[3] + box_h // 4)
        loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0
        loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0
        loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width
        loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height
        loose_box_w = loose_box_x1 - loose_box_x0
        loose_box_h = loose_box_y1 - loose_box_y0
        assert loose_box_w >= box_w and loose_box_h >= box_h

        # Pad the crop to at least 256 px per side, shifting the padding inward
        # when it would cross an image border.
        if loose_box_w < 256:
            padded_length_w = 256 - loose_box_w
            left_padded = padded_length_w // 2
            right_padded = padded_length_w - left_padded
            if loose_box_x0 - left_padded < 0:
                right_padded = right_padded + left_padded - loose_box_x0
                left_padded = loose_box_x0
            if loose_box_x1 + right_padded > ori_width:
                left_padded = left_padded + loose_box_x1 + right_padded - ori_width
                right_padded = ori_width - loose_box_x1
            loose_box_x0 = int(loose_box_x0 - left_padded)
            loose_box_x1 = int(loose_box_x1 + right_padded)
            loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0
            loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width
        if loose_box_h < 256:
            padded_length_h = 256 - loose_box_h
            top_padded = padded_length_h // 2
            bottom_padded = padded_length_h - top_padded
            if loose_box_y0 - top_padded < 0:
                bottom_padded = bottom_padded + top_padded - loose_box_y0
                top_padded = loose_box_y0
            if loose_box_y1 + bottom_padded > ori_height:
                top_padded = top_padded + loose_box_y1 + bottom_padded - ori_height
                bottom_padded = ori_height - loose_box_y1
            loose_box_y0 = int(loose_box_y0 - top_padded)
            loose_box_y1 = int(loose_box_y1 + bottom_padded)
            loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0
            loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height

        # Square the crop by padding the shorter side.
        loose_box_w = loose_box_x1 - loose_box_x0
        loose_box_h = loose_box_y1 - loose_box_y0
        if loose_box_w > loose_box_h:
            padded_length_h = loose_box_w - loose_box_h
            top_padded = padded_length_h // 2
            bottom_padded = padded_length_h - top_padded
            if loose_box_y0 - top_padded < 0:
                bottom_padded = bottom_padded + top_padded - loose_box_y0
                top_padded = loose_box_y0
            if loose_box_y1 + bottom_padded > ori_height:
                top_padded = top_padded + loose_box_y1 + bottom_padded - ori_height
                bottom_padded = ori_height - loose_box_y1
            loose_box_y0 = int(loose_box_y0 - top_padded)
            loose_box_y1 = int(loose_box_y1 + bottom_padded)
            loose_box_y0 = loose_box_y0 if loose_box_y0 > 0 else 0
            loose_box_y1 = loose_box_y1 if loose_box_y1 < ori_height else ori_height
        elif loose_box_h > loose_box_w:
            padded_length_w = loose_box_h - loose_box_w
            left_padded = padded_length_w // 2
            right_padded = padded_length_w - left_padded
            if loose_box_x0 - left_padded < 0:
                right_padded = right_padded + left_padded - loose_box_x0
                left_padded = loose_box_x0
            if loose_box_x1 + right_padded > ori_width:
                left_padded = left_padded + loose_box_x1 + right_padded - ori_width
                right_padded = ori_width - loose_box_x1
            loose_box_x0 = int(loose_box_x0 - left_padded)
            loose_box_x1 = int(loose_box_x1 + right_padded)
            loose_box_x0 = loose_box_x0 if loose_box_x0 > 0 else 0
            loose_box_x1 = loose_box_x1 if loose_box_x1 < ori_width else ori_width

        image_patch = ori_image.crop((loose_box_x0, loose_box_y0, loose_box_x1, loose_box_y1))
        image_patch_w, image_patch_h = image_patch.size

        # Re-tag the crop and re-run APE on it, resized to 1024 px on the long side.
        res = ram_predictor.run_on_image(image_file_path=image_patch, dynamic_resolution=False)
        tag_list = []
        for tag_string in res[0]:
            tags = tag_string.split(' | ')
            tag_list += tags
        tags = list(set(tag_list))
        text_prompt = ','.join(tags)

        if image_patch_w > image_patch_h:
            rescaled_image_patch_w = 1024
            rescaled_image_patch_h = int(image_patch_h / image_patch_w * 1024)
        else:
            rescaled_image_patch_h = 1024
            rescaled_image_patch_w = int(image_patch_w / image_patch_h * 1024)
        image_patch = image_patch.resize((rescaled_image_patch_w, rescaled_image_patch_h))
        output_image, json_results = ape_predictor.run_on_image(
            image_patch,
            input_text=text_prompt,
            visualize=True,
            score_threhold=0.1,
            output_type=["instance segmentation"],
        )

        all_masks, all_tags = [], []
        for idx, item in enumerate(json_results):
            object_mask = item["segmentation"]
            if isinstance(object_mask["counts"], list):
                object_mask = mask_util.frPyObjects(object_mask, object_mask["size"][0], object_mask["size"][1])
            mask = mask_util.decode(object_mask)
            mask = torch.as_tensor(mask.astype(np.uint8))
            all_masks.append(mask)
            all_tags.append(item['category_name'])
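
        # Match the crop-level masks back to this instance: take the best crop
        # mask if its IoU with the original exceeds 0.8; otherwise split into
        # crop masks contained in the original (need at least two); failing
        # that, keep the original mask and tag.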
        # if len(all_masks) == 0:
        #     continue
        if len(all_masks) == 0:
            result_mask_list.append(left_masks[ins_i])
            result_tag_list.append(left_tags[ins_i])
            continue
        all_masks = torch.stack(all_masks)
        all_masks_ori_size = torch.nn.functional.interpolate(
            all_masks.unsqueeze(0).to(torch.float32),
            size=(image_patch_h, image_patch_w), mode='bilinear')
        all_masks_ori_size = all_masks_ori_size > 0.4
        ori_mask_crop = left_masks[ins_i, loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1]

        # IoU between the original mask crop and every crop-level mask.
        masks1 = ori_mask_crop[None, None, :, :].char().to('cuda')
        masks2 = all_masks_ori_size.char().to('cuda')
        intersection = (masks1 * masks2)
        union = (masks1 + masks2 - intersection).sum(-1).sum(-1)
        intersection = intersection.sum(-1).sum(-1)
        area = masks2.sum(-1).sum(-1)
        # area = mask_area(masks2, chunk_mode=True)
        masks_iou = intersection / union

        target_idx = torch.argmax(masks_iou, dim=1)
        if masks_iou[0, target_idx] < 0.8:
            # No single crop mask matches: collect crop masks that lie inside
            # the original mask and use them only if there are at least two.
            temp_result_mask_list = []
            temp_result_tag_list = []
            for ins_j, mask_j_iou in enumerate(masks_iou[0]):
                if mask_j_iou < 0.1:
                    continue
                roc_j = intersection[0, ins_j] / area[0, ins_j]
                if roc_j < 0.8:
                    continue
                result_mask = torch.zeros((ori_height, ori_width)).to(all_masks.dtype)
                result_mask[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = all_masks_ori_size[0, ins_j]
                temp_result_mask_list.append(result_mask)
                temp_result_tag_list.append(all_tags[ins_j])
            if len(temp_result_mask_list) > 1:
                result_mask_list.extend(temp_result_mask_list)
                result_tag_list.extend(temp_result_tag_list)
            else:
                result_mask_list.append(left_masks[ins_i])
                result_tag_list.append(left_tags[ins_i])
        else:
            result_mask = torch.zeros((ori_height, ori_width)).to(all_masks.dtype)
            result_mask[loose_box_y0:loose_box_y1, loose_box_x0:loose_box_x1] = all_masks_ori_size[0, target_idx.item()]
            result_mask_list.append(result_mask)
            result_tag_list.append(all_tags[target_idx.item()])

    unique_tags = list(set(result_tag_list))
    text_prompt = ','.join(unique_tags)
    metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
    metadata.thing_classes = unique_tags
    metadata.stuff_classes = unique_tags
    if not visualize:
        return torch.stack(result_mask_list), result_tag_list
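
# main shards the workload by rank: each (node_id, local_rank) pair reads its
# own task list (8 local ranks per node), skips images whose output JSON
# already exists, and visualizes a random ~30% of the processed images.
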
def main(node_id=0, local_rank=0, work_dir="./work_dirs/object_level"):
    global_rank_id = int(node_id * 8 + local_rank)
    task_file = f"./work_dirs/object_level_task/rank{global_rank_id}.json"
    if not os.path.exists(task_file):
        print(f"No task file: {task_file}")
        return None
    with open(task_file, 'r') as f:
        sam_images = json.load(f)

    ram_predictor = build_ram_predictor(
        override_ckpt_file="third_parts/recognize_anything/xinyu1205/recognize-anything-plus-model/ram_plus_swin_large_14m.pth")
    ape_predictor = build_ape_predictor(
        which_categories='COCO',
        override_ckpt_file="third_parts/APE/shenyunhang/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth")
    sam = build_sam_vit_h("third_parts/zhouyik/zt_any_visual_prompt/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")
    sam_predictor = SamPredictor(sam)
    sam_auto_mask_generator = SamAutomaticMaskGenerator(sam)

    timer = Timer()
    past_time = 0
    total_images = len(sam_images)
    for idx, sam_image_file in enumerate(sam_images):
        image_name = os.path.basename(sam_image_file).split('.')[0]
        dir_name = os.path.dirname(sam_image_file)
        sam_anno_file = os.path.join(dir_name, image_name + ".json")
        save_dir = os.path.join(work_dir, os.path.basename(dir_name))
        if os.path.exists(os.path.join(save_dir, image_name + '.json')):
            # Output already exists; skip this image.
            continue
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        visualize = random.random() < 0.3
        run_on_image_v2(sam_image_file, sam_anno_file, save_dir, ram_predictor, ape_predictor,
                        sam_predictor, sam_auto_mask_generator, visualize=visualize)
        consume_time = "%.2f" % (timer.seconds() - past_time)
        past_time = timer.seconds()
        print(f"RANK#{local_rank}: {idx+1}/{total_images}, consumed {consume_time} seconds.")


if __name__ == "__main__":
    work_dir, local_rank, node_id = sys.argv[1:]
    main(node_id=int(node_id), local_rank=int(local_rank), work_dir=work_dir)
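
# Example launch on local rank 0 of node 0 (a sketch; the script name is a
# placeholder, and argv order is work_dir, local_rank, node_id):
#
#   python <this_script>.py ./work_dirs/object_level 0 0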