| | import copy |
| | import os |
| |
|
| | import torch |
| | import torch.utils.data |
| | import torchvision |
| | from PIL import Image |
| | from pycocotools import mask as coco_mask |
| | from transforms import Compose |
| |
|
| |
|
class FilterAndRemapCocoCategories:
    """Keep only annotations of selected categories, optionally renumbering them.

    Args:
        categories: Sequence of category ids to keep. When ``remap`` is true,
            each surviving ``category_id`` is replaced by its position in this
            sequence (``.index`` lookup, so the first occurrence wins).
        remap: When false, the surviving annotations are returned as-is
            (the very same dict objects, not copies).
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        """Return ``(image, filtered_anno)``; ``image`` is passed through unchanged."""
        wanted = self.categories
        kept = [ann for ann in anno if ann["category_id"] in wanted]
        if self.remap:
            # Work on deep copies so rewriting ids never mutates the caller's
            # annotation dicts in place.
            kept = copy.deepcopy(kept)
            for ann in kept:
                ann["category_id"] = wanted.index(ann["category_id"])
        return image, kept
| |
|
| |
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO segmentations into a stack of per-instance masks.

    Args:
        segmentations: One entry per instance, each in a form accepted by
            ``pycocotools.mask.frPyObjects`` (e.g. a list of polygons).
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        A tensor of shape ``(len(segmentations), height, width)``. For empty
        input this is an all-zero ``uint8`` tensor of shape ``(0, height, width)``.
        Otherwise each mask is the result of ``Tensor.any`` over the decoded
        parts of that instance (``uint8`` in PyTorch, where ``any`` keeps the
        ``uint8`` dtype).
    """

    def _rasterize(polygons):
        # Decode one instance; several parts are OR-ed into a single (H, W) mask.
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        # Some inputs decode to a 2-D array; add a trailing parts axis so the
        # reduction below always runs over dim 2.
        if decoded.ndim < 3:
            decoded = decoded[..., None]
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    instance_masks = [_rasterize(polygons) for polygons in segmentations]
    if not instance_masks:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(instance_masks, dim=0)
| |
|
| |
|
class ConvertCocoPolysToMask:
    """Collapse COCO instance annotations into one semantic label image.

    Invoked as ``transform(image, anno)``. Returns ``(image, target)``:
    ``image`` is untouched, and ``target`` is a ``PIL.Image`` built from a 2-D
    label map of shape ``(h, w)``. The size comes from ``w, h = image.size``,
    which assumes a PIL-style ``(width, height)`` size.

    Each annotation dict must provide ``"segmentation"`` and ``"category_id"``.
    In the label map:
    - a pixel covered by exactly one instance holds that instance's
      ``category_id``;
    - a pixel covered by no instance holds 0;
    - a pixel covered by two or more instances holds 255 (presumably used as
      an ignore label downstream; confirm).

    The dtype is ``uint8`` for empty ``anno``; otherwise it is the dtype of the
    masks returned by ``convert_coco_poly_to_mask``.
    """

    def __call__(self, image, anno):
        width, height = image.size
        polys = [obj["segmentation"] for obj in anno]
        labels = [obj["category_id"] for obj in anno]
        if not polys:
            semantic = torch.zeros((height, width), dtype=torch.uint8)
        else:
            instance_masks = convert_coco_poly_to_mask(polys, height, width)
            # Shape (N, 1, 1) so each instance mask is scaled by its own label.
            label_tensor = torch.as_tensor(labels, dtype=instance_masks.dtype).view(-1, 1, 1)
            # Merge the labelled instance masks into a single label map.
            semantic = (instance_masks * label_tensor).max(dim=0).values
            # Pixels claimed by more than one instance are marked 255.
            overlap = instance_masks.sum(dim=0) > 1
            semantic[overlap] = 255
        return image, Image.fromarray(semantic.numpy())
| |
|
| |
|
def _coco_remove_images_without_annotations(dataset, cat_list=None, *, min_area=1000):
    """Return a ``Subset`` of ``dataset`` holding only images with enough annotated area.

    An image is kept when both of these hold for its annotations, after they
    are optionally restricted to the categories in ``cat_list``:
    - at least one annotation remains;
    - their summed ``"area"`` is strictly greater than ``min_area``.

    Args:
        dataset: A ``torchvision.datasets.CocoDetection`` instance.
        cat_list: Optional collection of category ids to keep when measuring
            area. A falsy value (``None`` or empty) disables category filtering.
        min_area: Strict lower bound on the summed annotation area. The
            default of 1000 matches the previously hard-coded threshold.

    Returns:
        A ``torch.utils.data.Subset`` over the indices of the kept images.

    Raises:
        TypeError: If ``dataset`` is not a ``CocoDetection``.
    """
    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    # Build the membership set once instead of scanning cat_list for every annotation.
    allowed = set(cat_list) if cat_list else None

    def _has_valid_annotation(anno):
        # An image with no (remaining) annotations is never kept.
        if not anno:
            return False
        # Otherwise require the total annotated area to exceed the threshold.
        return sum(obj["area"] for obj in anno) > min_area

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        # iscrowd=None: no crowd filtering, so crowd and non-crowd annotations both count.
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if allowed is not None:
            anno = [obj for obj in anno if obj["category_id"] in allowed]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    return torch.utils.data.Subset(dataset, ids)
| |
|
| |
|
def get_coco(root, image_set, transforms):
    """Build the COCO semantic-segmentation dataset for ``image_set``.

    Paths, relative to ``root``:
    - images: ``train2017`` / ``val2017``;
    - annotations: ``annotations/instances_{train,val}2017.json``.

    Each sample passes through these steps in order:
    1. category filtering, with remapping to indices of ``CAT_LIST``;
    2. conversion of polygons to a label map;
    3. ``transforms``.

    Args:
        root: COCO root directory.
        image_set: ``"train"`` or ``"val"``.
        transforms: Joint ``(image, target)`` transform applied last, or ``None``
            to apply only the built-in conversion steps.

    Returns:
        A ``torchvision.datasets.CocoDetection``. For ``"train"`` it is wrapped
        in a ``Subset`` that drops images with too little annotated area in the
        selected categories.

    Raises:
        KeyError: If ``image_set`` is not ``"train"`` or ``"val"``.
    """
    PATHS = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    # Category ids to keep. A kept id's position in this list becomes its label.
    # Index 0 is presumably a background placeholder; confirm.
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    try:
        img_folder, ann_file = PATHS[image_set]
    except KeyError:
        raise KeyError(f"Unknown image_set {image_set!r}; expected one of {sorted(PATHS)}") from None

    steps = [FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask()]
    # Skip a missing user transform instead of placing None in the pipeline,
    # where it would fail only when a sample is loaded.
    if transforms is not None:
        steps.append(transforms)
    pipeline = Compose(steps)

    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=pipeline)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

    return dataset
| |
|