| import copy |
| import os |
|
|
| import torch |
| import torch.utils.data |
| import torchvision |
| from PIL import Image |
| from pycocotools import mask as coco_mask |
| from transforms import Compose |
|
|
|
|
class FilterAndRemapCocoCategories:
    """Joint image/annotation transform that filters annotations by category.

    Annotations whose ``"category_id"`` is not contained in ``categories`` are
    dropped. When ``remap`` is true, every surviving ``"category_id"`` is
    replaced by its position inside ``categories``.

    Args:
        categories: Collection of accepted category ids. It must support
            ``in`` and, when ``remap`` is true, ``.index()``.
        remap: Whether to rewrite category ids to their index in
            ``categories``.
    """

    def __init__(self, categories, remap=True):
        self.remap = remap
        self.categories = categories

    def __call__(self, image, anno):
        """Return ``image`` untouched together with the filtered annotations.

        Args:
            image: Passed through unchanged.
            anno: Iterable of annotation mappings, each providing a
                ``"category_id"`` key.

        Returns:
            A ``(image, annotations)`` tuple. With remapping enabled the
            annotations are deep copies of the accepted inputs; otherwise the
            accepted input objects themselves are returned in a new list.
        """
        accepted = self.categories
        kept = [entry for entry in anno if entry["category_id"] in accepted]
        if self.remap:
            # Copy before rewriting the ids so the caller's annotation
            # objects are left unmodified.
            kept = copy.deepcopy(kept)
            for entry in kept:
                entry["category_id"] = accepted.index(entry["category_id"])
        return image, kept
|
|
|
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode per-object COCO segmentations into a stack of 2-D masks.

    Args:
        segmentations: Iterable with one segmentation entry per object, in the
            form accepted by ``pycocotools.mask.frPyObjects``.
        height: Mask height passed to the decoder.
        width: Mask width passed to the decoder.

    Returns:
        A tensor with one ``(height, width)`` mask per entry, stacked along
        dimension 0. For empty input, a ``torch.uint8`` tensor of shape
        ``(0, height, width)``.
    """

    def _decode(polygons):
        # Encode the raw segmentation as RLE, then decode it to an array.
        encoded = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(encoded)
        if len(decoded.shape) < 3:
            # Add a trailing axis so the reduction over axis 2 below is
            # always valid.
            decoded = decoded[..., None]
        as_tensor = torch.as_tensor(decoded, dtype=torch.uint8)
        # Collapse axis 2: a pixel is set if any slice along it is set.
        return as_tensor.any(dim=2)

    per_object = [_decode(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
|
|
|
class ConvertCocoPolysToMask:
    """Joint transform turning COCO annotations into one label image."""

    def __call__(self, image, anno):
        """Build a per-pixel category map from the annotations of ``image``.

        Args:
            image: Object whose ``size`` attribute unpacks as
                ``(width, height)``; returned unchanged.
            anno: Iterable of annotation mappings providing
                ``"segmentation"`` and ``"category_id"``.

        Returns:
            A ``(image, target)`` tuple where ``target`` is a PIL image built
            from the ``(height, width)`` label tensor.
        """
        width, height = image.size
        polygons = [obj["segmentation"] for obj in anno]
        labels = [obj["category_id"] for obj in anno]

        if not polygons:
            # No annotations: every pixel gets label 0.
            label_map = torch.zeros((height, width), dtype=torch.uint8)
        else:
            instance_masks = convert_coco_poly_to_mask(polygons, height, width)
            label_tensor = torch.as_tensor(labels, dtype=instance_masks.dtype)
            # Scale each instance mask by its category id and keep, per
            # pixel, the largest id across all instances.
            weighted = instance_masks * label_tensor[:, None, None]
            label_map = weighted.max(dim=0)[0]
            # Pixels covered by more than one instance mask are set to 255.
            label_map[instance_masks.sum(0) > 1] = 255

        return image, Image.fromarray(label_map.numpy())
|
|
|
|
def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Restrict a COCO detection dataset to images with enough annotation.

    An image is kept when, after optionally filtering its annotations to
    ``cat_list``, at least one annotation remains and the sum of their
    ``"area"`` values is greater than 1000.

    Args:
        dataset: A ``torchvision.datasets.CocoDetection`` instance.
        cat_list: Optional collection of category ids. When truthy, only
            annotations with these ids are counted.

    Returns:
        A ``torch.utils.data.Subset`` of ``dataset`` holding the positions of
        the kept images.

    Raises:
        TypeError: If ``dataset`` is not a ``CocoDetection`` instance.
    """
    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    def _has_valid_annotation(anno):
        # Requires at least one annotation and a total area above 1000.
        return len(anno) > 0 and sum(obj["area"] for obj in anno) > 1000

    kept_positions = []
    for position, image_id in enumerate(dataset.ids):
        annotation_ids = dataset.coco.getAnnIds(imgIds=image_id, iscrowd=None)
        annotations = dataset.coco.loadAnns(annotation_ids)
        if cat_list:
            annotations = [obj for obj in annotations if obj["category_id"] in cat_list]
        if _has_valid_annotation(annotations):
            # Record the position within the dataset, not the COCO image id,
            # because Subset indexes by position.
            kept_positions.append(position)

    return torch.utils.data.Subset(dataset, kept_positions)
|
|
|
|
def get_coco(root, image_set, transforms):
    """Create the COCO segmentation dataset for one split.

    Args:
        root: Directory containing the ``train2017``/``val2017`` image folders
            and the ``annotations`` folder.
        image_set: Either ``"train"`` or ``"val"``.
        transforms: Joint transform applied after the annotations have been
            filtered/remapped and converted to a label image.

    Returns:
        A ``torchvision.datasets.CocoDetection`` dataset. For ``"train"`` it
        is wrapped by ``_coco_remove_images_without_annotations``.

    Raises:
        KeyError: If ``image_set`` is neither ``"train"`` nor ``"val"``.
    """
    # Category ids that are kept; with remap=True the position of an id in
    # this list becomes its label.
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
    # Split name -> (image folder, annotation file), both relative to root.
    PATHS = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }

    pipeline = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])

    folder_name, annotation_rel = PATHS[image_set]
    dataset = torchvision.datasets.CocoDetection(
        os.path.join(root, folder_name),
        os.path.join(root, annotation_rel),
        transforms=pipeline,
    )

    if image_set != "train":
        return dataset
    # Only the training split is filtered for images lacking annotations.
    return _coco_remove_images_without_annotations(dataset, CAT_LIST)
|
|