Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| import faster_coco_eval | |
| import faster_coco_eval.core.mask as coco_mask | |
| import torch | |
| import torch.utils.data | |
| import torchvision | |
| import os | |
| from PIL import Image | |
| from ...core import register | |
| from .._misc import convert_to_tv_tensor | |
| from ._dataset import DetDataset | |
# Silence torchvision's warning about the beta transforms API at import time.
torchvision.disable_beta_transforms_warning()
# NOTE(review): presumably makes faster_coco_eval stand in for pycocotools
# for the rest of the process — confirm against the faster_coco_eval docs.
faster_coco_eval.init_as_pycocotools()
# Lift Pillow's pixel-count ceiling so very large images can be opened.
# NOTE(review): this also switches off Pillow's decompression-bomb check,
# so it is only appropriate when the image files come from a trusted source.
Image.MAX_IMAGE_PIXELS = None
# Public API of this module.
__all__ = ["CocoDetection"]
class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    Raw samples come from ``torchvision.datasets.CocoDetection``; each one is
    converted by ``ConvertCocoPolysToMask`` into a tensor-based target dict
    and then, if configured, passed through ``transforms``.

    Args:
        img_folder: Directory containing the images.
        ann_file: Path to the COCO annotation file.
        transforms: Callable invoked as ``transforms(img, target, dataset)``
            that returns a 3-tuple whose first two items are the transformed
            image and target; ``None`` disables it.
        return_masks: When true, targets also carry a ``"masks"`` entry.
        remap_mscoco_category: When true, category ids are translated to
            labels through the module-level ``mscoco_category2label`` table.
    """

    # NOTE(review): `register` is imported at file level but is not applied to
    # this class in the visible code; given __inject__/__share__ below, a
    # class decorator may be missing — confirm against the project registry.
    __inject__ = [
        "transforms",
    ]
    __share__ = ["remap_mscoco_category"]

    def __init__(
        self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False
    ):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        """Return the transformed ``(image, target)`` pair for sample ``idx``."""
        img, target = self.load_item(idx)
        if self._transforms is not None:
            # The third value returned by the transform pipeline is not used here.
            img, target, _ = self._transforms(img, target, self)
        return img, target

    def load_item(self, idx):
        """Load sample ``idx`` and convert its annotations, without transforms.

        Returns:
            ``(image, target)`` where ``target`` is the dict produced by
            ``self.prepare`` plus an ``"idx"`` tensor; ``"boxes"`` and
            ``"masks"`` (when present) are wrapped by ``convert_to_tv_tensor``.
        """
        image, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        image_path = os.path.join(self.img_folder, self.coco.loadImgs(image_id)[0]["file_name"])
        target = {"image_id": image_id, "image_path": image_path, "annotations": target}

        if self.remap_mscoco_category:
            image, target = self.prepare(image, target, category2label=mscoco_category2label)
        else:
            image, target = self.prepare(image, target)

        target["idx"] = torch.tensor([idx])

        if "boxes" in target:
            # image.size is (width, height); the reversed tuple is passed as spatial_size.
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], key="boxes", spatial_size=image.size[::-1]
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], key="masks")

        return image, target

    def extra_repr(self) -> str:
        """Return a multi-line summary of the dataset configuration."""
        s = f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n"
        s += f" return_masks: {self.return_masks}\n"
        if hasattr(self, "_transforms") and self._transforms is not None:
            s += f" transforms:\n {repr(self._transforms)}"
        if hasattr(self, "_preset") and self._preset is not None:
            s += f" preset:\n {repr(self._preset)}"
        return s

    # The accessors below are properties: the derived mappings iterate
    # ``self.categories`` directly, which only works with attribute-style
    # access (iterating a bound method raises TypeError).
    @property
    def categories(
        self,
    ):
        """Category records from the loaded annotation file."""
        return self.coco.dataset["categories"]

    @property
    def category2name(
        self,
    ):
        """Mapping from category id to category name."""
        return {cat["id"]: cat["name"] for cat in self.categories}

    @property
    def category2label(
        self,
    ):
        """Mapping from category id to its position in ``categories``."""
        return {cat["id"]: i for i, cat in enumerate(self.categories)}

    @property
    def label2category(
        self,
    ):
        """Mapping from position in ``categories`` back to category id."""
        return {i: cat["id"] for i, cat in enumerate(self.categories)}
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterise per-object COCO segmentations into a stack of masks.

    Args:
        segmentations: Iterable with one segmentation entry per object, in
            the form accepted by ``coco_mask.frPyObjects``.
        height: Mask height passed to the rasteriser.
        width: Mask width passed to the rasteriser.

    Returns:
        The per-object masks stacked along a new leading dimension, or an
        empty ``(0, height, width)`` uint8 tensor when there are no objects.
    """
    per_object = []
    for polys in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polys, height, width))
        # Give a 2-D result a trailing axis so the reduction below applies uniformly.
        if len(decoded.shape) < 3:
            decoded = decoded[..., None]
        # Collapse the trailing axis: a pixel is set if any part sets it.
        merged = torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)
        per_object.append(merged)

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class ConvertCocoPolysToMask(object):
    """Convert the raw COCO annotations of one image into a tensor target.

    Args:
        return_masks: When true, the resulting target also contains a
            ``"masks"`` entry built from each object's ``"segmentation"``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image: "Image.Image", target, **kwargs):
        """Build the target dict for ``image``.

        Args:
            image: Object whose ``size`` attribute is ``(width, height)``.
            target: Dict with ``"image_id"``, ``"image_path"`` and
                ``"annotations"`` (a list of COCO annotation dicts).
            **kwargs: Optional ``category2label`` mapping used to translate
                each ``"category_id"`` into a label.

        Returns:
            ``(image, target)`` where ``image`` is returned unchanged and
            ``target`` holds ``boxes`` (xyxy, clamped to the image),
            ``labels``, ``image_id``, ``image_path``, ``area``, ``iscrowd``,
            ``orig_size`` and, when applicable, ``masks`` / ``keypoints``.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])
        image_path = target["image_path"]

        # Crowd annotations are dropped; a missing "iscrowd" key counts as non-crowd.
        anno = [
            obj for obj in target["annotations"] if "iscrowd" not in obj or obj["iscrowd"] == 0
        ]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Convert (x, y, w, h) to (x1, y1, x2, y2), then clip to the image bounds.
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        category2label = kwargs.get("category2label", None)
        if category2label is not None:
            labels = [category2label[obj["category_id"]] for obj in anno]
        else:
            labels = [obj["category_id"] for obj in anno]
        labels = torch.tensor(labels, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                # Regroup each object's flat list into triples.
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Discard boxes that have no positive width or height after clipping.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        labels = labels[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        target["image_path"] = image_path
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        # Pin the dtype: an empty list would otherwise produce a float32
        # tensor, unlike the int64 obtained for images that have objects.
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in anno], dtype=torch.int64)
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(w), int(h)])
        # target["size"] = torch.as_tensor([int(w), int(h)])

        return image, target
# MS-COCO category id -> human-readable name. The ids are not contiguous
# (for example 12, 26, 29 and 30 do not appear), which is why the dense
# label tables below are derived from this dict's insertion order.
mscoco_category2name = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

# Category id -> dense zero-based label, numbered in the order listed above.
mscoco_category2label = {
    category_id: label for label, category_id in enumerate(mscoco_category2name)
}

# Inverse lookup: dense label -> original category id.
mscoco_label2category = dict(
    (label, category_id) for category_id, label in mscoco_category2label.items()
)