| """ |
| Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
| Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py |
| |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import torch |
| import torch.utils.data |
|
|
| import torchvision |
|
|
| from PIL import Image |
| import faster_coco_eval |
| import faster_coco_eval.core.mask as coco_mask |
| from ._dataset import DetDataset |
| from .._misc import convert_to_tv_tensor |
| from ...core import register |
|
|
# Public API of this module.
__all__ = ['CocoDetection']

# Remove Pillow's upper bound on image pixel count so very large images
# can still be opened by the dataset.
Image.MAX_IMAGE_PIXELS = None

# Silence torchvision's warning about the beta transforms API.
torchvision.disable_beta_transforms_warning()

# NOTE(review): presumably makes faster_coco_eval answer to the
# ``pycocotools`` import name used by torchvision's COCO dataset — confirm
# against the faster_coco_eval documentation. Must run before any dataset
# is constructed.
faster_coco_eval.init_as_pycocotools()
|
|
|
|
@register()
class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    Raw samples are read by ``torchvision.datasets.CocoDetection``. The
    annotation list of each sample is converted into a dict of tensors by
    ``ConvertCocoPolysToMask`` and, when a transform pipeline was supplied,
    passed through it before being returned.

    Args:
        img_folder: Image root directory, forwarded to the torchvision base
            class.
        ann_file: Path of the COCO annotation file, forwarded to the
            torchvision base class.
        transforms: Callable invoked as ``transforms(img, target, dataset)``
            and expected to return a 3-tuple whose first two items are the
            transformed image and target; ``None`` disables it.
        return_masks: When true, segmentation masks are decoded into the
            target.
        remap_mscoco_category: When true, category ids are mapped through
            the module-level ``mscoco_category2label`` table.
    """
    __inject__ = ['transforms', ]
    __share__ = ['remap_mscoco_category']

    def __init__(self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False):
        super().__init__(img_folder, ann_file)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        """Return the (optionally transformed) sample at position ``idx``."""
        img, target = self.load_item(idx)
        if self._transforms is None:
            return img, target
        # The pipeline returns a third value that is not used here.
        img, target, _ = self._transforms(img, target, self)
        return img, target

    def load_item(self, idx):
        """Load sample ``idx`` and build its tensor target, without transforms."""
        image, annotations = super().__getitem__(idx)
        raw_target = {'image_id': self.ids[idx], 'annotations': annotations}

        # Only pass the remapping table when remapping was requested.
        prepare_kwargs = {}
        if self.remap_mscoco_category:
            prepare_kwargs['category2label'] = mscoco_category2label
        image, target = self.prepare(image, raw_target, **prepare_kwargs)

        # Remember the dataset position alongside the COCO image id.
        target['idx'] = torch.tensor([idx])

        if 'boxes' in target:
            # ``image.size`` is reversed to obtain the spatial size argument.
            target['boxes'] = convert_to_tv_tensor(
                target['boxes'], key='boxes', spatial_size=image.size[::-1])

        if 'masks' in target:
            target['masks'] = convert_to_tv_tensor(target['masks'], key='masks')

        return image, target

    def extra_repr(self) -> str:
        """Describe the dataset configuration for ``repr`` output."""
        parts = [
            f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n',
            f' return_masks: {self.return_masks}\n',
        ]
        if getattr(self, '_transforms', None) is not None:
            parts.append(f' transforms:\n {repr(self._transforms)}')
        if getattr(self, '_preset', None) is not None:
            parts.append(f' preset:\n {repr(self._preset)}')
        return ''.join(parts)

    @property
    def categories(self):
        """The ``categories`` list of the loaded annotation file."""
        return self.coco.dataset['categories']

    @property
    def category2name(self):
        """Mapping from category id to category name."""
        return {entry['id']: entry['name'] for entry in self.categories}

    @property
    def category2label(self):
        """Mapping from category id to its position in ``categories``."""
        return {entry['id']: label for label, entry in enumerate(self.categories)}

    @property
    def label2category(self):
        """Mapping from position in ``categories`` back to category id."""
        return dict(enumerate(entry['id'] for entry in self.categories))
|
|
|
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode per-object segmentations into one stacked mask tensor.

    Args:
        segmentations: Iterable with one segmentation entry per object, each
            passed as-is to ``coco_mask.frPyObjects``.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        A tensor with one ``height`` x ``width`` mask per object stacked
        along dim 0, or an empty ``(0, height, width)`` uint8 tensor when
        there are no objects.
    """
    def _decode_one(polygons):
        encoded = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(encoded)
        # Guarantee a trailing axis so the reduction below always applies.
        if len(decoded.shape) < 3:
            decoded = decoded[..., None]
        # Collapse the trailing axis: a pixel is set if any part covers it.
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [_decode_one(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
|
|
|
class ConvertCocoPolysToMask:
    """Convert a raw COCO annotation list into a dict of tensors.

    Args:
        return_masks: When true, each object's ``"segmentation"`` entry is
            decoded with ``convert_coco_poly_to_mask`` and stored under
            ``"masks"``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image: "Image.Image", target, **kwargs):
        """Build the tensor target for one image.

        Args:
            image: Object exposing ``size`` as ``(width, height)``; it is
                returned unchanged.
            target: Dict with ``"image_id"`` and ``"annotations"`` (a list
                of annotation dicts carrying ``"bbox"``, ``"category_id"``
                and ``"area"``).
            **kwargs: ``category2label`` may supply a mapping applied to
                every ``category_id``; without it the raw ids are used.

        Returns:
            ``(image, target)`` where ``target`` holds ``boxes``,
            ``labels``, optional ``masks``, ``image_id``, optional
            ``keypoints``, ``area``, ``iscrowd`` and ``orig_size``.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])

        # Annotations explicitly flagged as crowd are discarded up front.
        anno = [
            obj for obj in target["annotations"]
            if 'iscrowd' not in obj or obj['iscrowd'] == 0
        ]

        # Boxes arrive as [x, y, w, h]; adding the origin to the last two
        # columns gives [x1, y1, x2, y2], which is then clipped to the image.
        boxes = torch.as_tensor(
            [obj["bbox"] for obj in anno], dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        category2label = kwargs.get('category2label', None)
        if category2label is not None:
            labels = [category2label[obj["category_id"]] for obj in anno]
        else:
            labels = [obj["category_id"] for obj in anno]
        labels = torch.tensor(labels, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        # Keypoint presence is decided from the first annotation only.
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                # Regroup the flat per-object list into triples.
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Drop boxes that have no positive width or height after clipping.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        labels = labels[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        area = torch.tensor([obj["area"] for obj in anno])
        # Pin the dtype: without it an image with no remaining annotations
        # would produce a float32 tensor while every other image yields
        # int64, giving inconsistent dtypes across samples.
        iscrowd = torch.tensor(
            [obj.get("iscrowd", 0) for obj in anno], dtype=torch.int64)
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        # Stored as (width, height).
        target["orig_size"] = torch.as_tensor([int(w), int(h)])

        return image, target
|
|
|
|
# Category id -> name table. The ids are not contiguous: some values
# between 1 and 90 are absent.
mscoco_category2name = {
    1: 'person',
    2: 'bicycle',
    3: 'car',
    4: 'motorcycle',
    5: 'airplane',
    6: 'bus',
    7: 'train',
    8: 'truck',
    9: 'boat',
    10: 'traffic light',
    11: 'fire hydrant',
    13: 'stop sign',
    14: 'parking meter',
    15: 'bench',
    16: 'bird',
    17: 'cat',
    18: 'dog',
    19: 'horse',
    20: 'sheep',
    21: 'cow',
    22: 'elephant',
    23: 'bear',
    24: 'zebra',
    25: 'giraffe',
    27: 'backpack',
    28: 'umbrella',
    31: 'handbag',
    32: 'tie',
    33: 'suitcase',
    34: 'frisbee',
    35: 'skis',
    36: 'snowboard',
    37: 'sports ball',
    38: 'kite',
    39: 'baseball bat',
    40: 'baseball glove',
    41: 'skateboard',
    42: 'surfboard',
    43: 'tennis racket',
    44: 'bottle',
    46: 'wine glass',
    47: 'cup',
    48: 'fork',
    49: 'knife',
    50: 'spoon',
    51: 'bowl',
    52: 'banana',
    53: 'apple',
    54: 'sandwich',
    55: 'orange',
    56: 'broccoli',
    57: 'carrot',
    58: 'hot dog',
    59: 'pizza',
    60: 'donut',
    61: 'cake',
    62: 'chair',
    63: 'couch',
    64: 'potted plant',
    65: 'bed',
    67: 'dining table',
    70: 'toilet',
    72: 'tv',
    73: 'laptop',
    74: 'mouse',
    75: 'remote',
    76: 'keyboard',
    77: 'cell phone',
    78: 'microwave',
    79: 'oven',
    80: 'toaster',
    81: 'sink',
    82: 'refrigerator',
    84: 'book',
    85: 'clock',
    86: 'vase',
    87: 'scissors',
    88: 'teddy bear',
    89: 'hair drier',
    90: 'toothbrush',
}

# Contiguous labels are assigned in the table's insertion order, starting
# at 0; the second map is the exact inverse of the first.
mscoco_category2label = {
    category_id: label
    for label, category_id in enumerate(mscoco_category2name)
}
mscoco_label2category = dict(enumerate(mscoco_category2name))
|
|