Spaces:
Sleeping
Sleeping
| import copy | |
| import os | |
| import torch | |
| import torch.utils.data | |
| import torchvision | |
| from torchvision import transforms as T | |
| from pycocotools import mask as coco_mask | |
| from pycocotools.coco import COCO | |
class FilterAndRemapCocoCategories:
    """Keep only annotations whose ``category_id`` is listed in ``categories``.

    When ``remap`` is true, each surviving annotation's ``category_id`` is
    replaced by its position in ``categories`` (via ``categories.index``). The
    remapping works on a deep copy so the caller's annotation dicts are not
    modified. The ``target`` dict itself is updated in place and returned.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        wanted = self.categories
        kept = [ann for ann in target["annotations"] if ann["category_id"] in wanted]
        if self.remap:
            # copy before rewriting ids so the source annotations stay intact
            kept = copy.deepcopy(kept)
            for ann in kept:
                ann["category_id"] = wanted.index(ann["category_id"])
        target["annotations"] = kept
        return image, target
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO segmentations into a stack of boolean per-object masks.

    Args:
        segmentations: one entry per object, each in a form accepted by
            ``pycocotools.mask.frPyObjects`` (e.g. a list of polygons).
        height: output mask height in pixels.
        width: output mask width in pixels.

    Returns:
        A ``torch.bool`` tensor of shape ``(len(segmentations), height, width)``.
        An empty input yields shape ``(0, height, width)`` with the same dtype,
        so callers see one dtype regardless of the object count.
    """
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        # decode may return a 2-D array; make it 3-D so the parts can be merged
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # an object may consist of several parts; OR them into one mask (bool)
        masks.append(mask.any(dim=2))
    if not masks:
        # match the dtype produced by ``any`` on the non-empty path
        return torch.zeros((0, height, width), dtype=torch.bool)
    return torch.stack(masks, dim=0)
class ConvertCocoPolysToMask:
    """Convert a raw COCO ``target`` into the tensor dict used for training.

    ``target`` must contain ``"image_id"`` and ``"annotations"`` (a list of COCO
    annotation dicts). The conversion:

    - drops crowd annotations (``iscrowd != 0``);
    - converts boxes from ``[x, y, w, h]`` to ``[x1, y1, x2, y2]`` and clips them
      to the image;
    - removes degenerate boxes (zero width or height after clipping), together
      with every per-object field that belongs to them.

    The returned target has keys ``boxes``, ``labels``, ``masks``, ``image_id``,
    optionally ``keypoints``, and ``area`` / ``iscrowd``.
    """

    def __call__(self, image, target):
        # assumes a PIL-style image whose ``size`` is (width, height) - TODO confirm for other inputs
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])

        anno = [obj for obj in target["annotations"] if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # reshape keeps a (0, 4) shape when there are no boxes
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                # flat list -> (num_objects, K, 3)
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # for conversion to coco api; built from the same filtered ``anno`` as boxes
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        target = {}
        target["boxes"] = boxes[keep]
        target["labels"] = classes[keep]
        target["masks"] = masks[keep]
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints[keep]
        # area/iscrowd must be filtered by ``keep`` too, otherwise they stop
        # lining up index-for-index with boxes/labels/masks
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]
        return image, target
def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Return a ``Subset`` of ``dataset`` holding only images with usable annotations.

    Annotations are first restricted to ``cat_list`` when it is truthy. An image
    is kept when all of the following hold:

    - it has at least one annotation;
    - not every box has a width or height <= 1;
    - if the first annotation carries ``"keypoints"``, at least 10 keypoints
      across the image have a visibility value > 0.

    Raises:
        TypeError: if ``dataset`` is not a ``torchvision.datasets.CocoDetection``.
    """
    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    min_keypoints_per_image = 10

    def _is_tiny(ann):
        # width or height (bbox[2:]) of at most one pixel
        return any(side <= 1 for side in ann["bbox"][2:])

    def _is_usable(anns):
        if len(anns) == 0:
            return False
        if all(_is_tiny(ann) for ann in anns):
            return False
        # keypoint tasks use a stricter criterion; detection tasks stop here
        if "keypoints" not in anns[0]:
            return True
        # every third value is a visibility flag
        visible = sum(flag > 0 for ann in anns for flag in ann["keypoints"][2::3])
        return visible >= min_keypoints_per_image

    coco = dataset.coco
    kept_indices = []
    for position, img_id in enumerate(dataset.ids):
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None))
        if cat_list:
            anns = [ann for ann in anns if ann["category_id"] in cat_list]
        if _is_usable(anns):
            kept_indices.append(position)
    return torch.utils.data.Subset(dataset, kept_indices)
def convert_to_coco_api(ds):
    """Build an in-memory pycocotools ``COCO`` index from a detection dataset.

    Each ``ds[i]`` must return ``(img, targets)``. ``img.shape[-2:]`` is
    ``(height, width)``. ``targets`` contains:

    - ``"image_id"``: a one-element tensor;
    - ``"boxes"``: in ``[x1, y1, x2, y2]`` form;
    - ``"labels"``, ``"area"`` and ``"iscrowd"``;
    - optionally ``"masks"`` and ``"keypoints"``.

    Returns:
        A ``COCO`` object with one image entry per sample, one annotation per
        box (ids starting at 1) and one category entry per distinct label.
    """
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        # NOTE: indexing loads (and transforms) the whole image just to read its target
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        dataset["images"].append(
            {"id": image_id, "height": img.shape[-2], "width": img.shape[-1]}
        )

        bboxes = targets["boxes"].clone()
        # COCO stores boxes as [x, y, width, height]
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()

        has_masks = "masks" in targets
        if has_masks:
            # pycocotools' encoder only accepts uint8 arrays in Fortran order;
            # masks may arrive as bool (e.g. produced by Tensor.any), so cast first
            masks = targets["masks"].to(torch.uint8)
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        has_keypoints = "keypoints" in targets
        if has_keypoints:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()

        # one annotation per box
        for i, bbox in enumerate(bboxes):
            ann = {
                "image_id": image_id,
                "bbox": bbox,
                "category_id": labels[i],
                "area": areas[i],
                "iscrowd": iscrowd[i],
                "id": ann_id,
            }
            categories.add(labels[i])
            if has_masks:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if has_keypoints:
                ann["keypoints"] = keypoints[i]
                # count non-zero entries among every third value (visibility in [x, y, v] triples)
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
def get_coco_api_from_dataset(dataset):
    """Return a pycocotools ``COCO`` object for ``dataset``.

    Up to ten layers of ``torch.utils.data.Subset`` are peeled off. If a
    ``torchvision.datasets.CocoDetection`` is reached, its existing ``coco``
    index is returned. Otherwise ``convert_to_coco_api`` is run on the
    unwrapped dataset (not on the original ``Subset``).
    """
    remaining = 10
    while remaining and not isinstance(dataset, torchvision.datasets.CocoDetection):
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        remaining -= 1

    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO dataset that yields ``(image, target)`` with a dict-shaped target.

    The annotation list returned by the parent class is wrapped as
    ``{"image_id": <id from self.ids>, "annotations": <list>}``. When
    ``transforms`` is given, it is called as ``transforms(image, target)``.
    """

    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        # joint (image, target) transform applied in __getitem__; not forwarded to the parent
        self._transforms = transforms

    def __getitem__(self, idx):
        image, annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": annotations}
        if self._transforms is not None:
            image, target = self._transforms(image, target)
        return image, target
class _PairCompose:
    """Apply ``(image, target) -> (image, target)`` callables in sequence.

    Defined at module level (rather than as a closure) so datasets holding it
    remain picklable, e.g. for DataLoader worker processes.
    """

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target


def get_coco(root, image_set, transforms, mode="instances"):
    """Build the COCO 2017 ``train`` or ``val`` split located under ``root``.

    Args:
        root: directory containing ``train2017``/``val2017`` and ``annotations/``.
        image_set: ``"train"`` or ``"val"``; any other value raises ``KeyError``.
        transforms: optional callable ``(image, target) -> (image, target)``,
            applied after :class:`ConvertCocoPolysToMask`.
        mode: annotation-file prefix, e.g. ``"instances"`` or ``"person_keypoints"``.

    Returns:
        A :class:`CocoDetection`. For ``"train"``, it is wrapped in a ``Subset``
        without images that lack usable annotations.
    """
    anno_file_template = "{}_{}2017.json"
    paths = {
        "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
        "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
    }

    steps = [ConvertCocoPolysToMask()]
    if transforms is not None:
        steps.append(transforms)
    # torchvision.transforms.Compose (v1) calls each transform with a single
    # argument; every step here takes and returns an (image, target) pair.
    transforms = _PairCompose(steps)

    img_folder, ann_file = paths[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset)
    return dataset
def get_coco_kp(root, image_set, transforms):
    """Build a COCO person-keypoints split; arguments are as for ``get_coco``."""
    return get_coco(
        root=root,
        image_set=image_set,
        transforms=transforms,
        mode="person_keypoints",
    )