"""
copy and modified https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.utils.data
import torchvision
import torchvision.transforms.functional as TVF

import faster_coco_eval.core.mask as coco_mask
from faster_coco_eval import COCO
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize COCO polygon segmentations into a stacked mask tensor.

    Each element of *segmentations* describes one object (possibly several
    polygon parts); the parts are decoded to bitmaps and merged with a
    logical OR. Returns a tensor of shape ``(num_objects, height, width)``
    (uint8 and empty when *segmentations* is empty).
    """
    per_object = []
    for polygons in segmentations:
        rle = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(rle)
        # A single polygon decodes to (H, W); normalize to (H, W, parts)
        # so the OR-reduction below works uniformly.
        if decoded.ndim < 3:
            decoded = decoded[..., None]
        merged = torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)
        per_object.append(merged)
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class ConvertCocoPolysToMask:
    """Convert raw COCO annotation dicts into tensor-based detection targets.

    Produces xyxy boxes, int64 labels, instance masks, and (when present)
    keypoints, dropping crowd annotations and degenerate (zero-area) boxes.
    """

    def __call__(self, image, target):
        """Transform ``(image, {"image_id", "annotations"})`` into tensors.

        Returns the unmodified image and a new target dict with keys
        ``boxes`` (xyxy, float32), ``labels``, ``masks``, ``image_id``,
        ``area``, ``iscrowd`` and optionally ``keypoints``. All per-object
        entries are aligned: they share the same first-dimension length.
        """
        w, h = image.size

        image_id = target["image_id"]
        # Crowd regions are excluded from training targets.
        anno = [obj for obj in target["annotations"] if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # COCO boxes are xywh; convert to xyxy and clip to the image bounds.
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor(
                [obj["keypoints"] for obj in anno], dtype=torch.float32
            )
            num_objs = keypoints.shape[0]
            if num_objs:
                # Reshape flat [x, y, v, x, y, v, ...] to (objects, K, 3).
                keypoints = keypoints.view(num_objs, -1, 3)

        # Keep only boxes with strictly positive width and height.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        # BUG FIX: area/iscrowd must be filtered with the same `keep` mask as
        # boxes/labels/masks; the original stored the unfiltered tensors, so
        # their length disagreed with `boxes` whenever a degenerate box was
        # dropped above.
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        return image, target
| def _coco_remove_images_without_annotations(dataset, cat_list=None): | |
| def _has_only_empty_bbox(anno): | |
| return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) | |
| def _count_visible_keypoints(anno): | |
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) | |
| min_keypoints_per_image = 10 | |
| def _has_valid_annotation(anno): | |
| # if it's empty, there is no annotation | |
| if len(anno) == 0: | |
| return False | |
| # if all boxes have close to zero area, there is no annotation | |
| if _has_only_empty_bbox(anno): | |
| return False | |
| # keypoints task have a slight different criteria for considering | |
| # if an annotation is valid | |
| if "keypoints" not in anno[0]: | |
| return True | |
| # for keypoint detection tasks, only consider valid images those | |
| # containing at least min_keypoints_per_image | |
| if _count_visible_keypoints(anno) >= min_keypoints_per_image: | |
| return True | |
| return False | |
| ids = [] | |
| for ds_idx, img_id in enumerate(dataset.ids): | |
| ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) | |
| anno = dataset.coco.loadAnns(ann_ids) | |
| if cat_list: | |
| anno = [obj for obj in anno if obj["category_id"] in cat_list] | |
| if _has_valid_annotation(anno): | |
| ids.append(ds_idx) | |
| dataset = torch.utils.data.Subset(dataset, ids) | |
| return dataset | |
def convert_to_coco_api(ds):
    """Build an in-memory COCO API object from a dataset exposing ``load_item``.

    Iterates the whole dataset, re-encoding boxes (xyxy -> xywh), masks, and
    keypoints into the COCO annotation schema, then indexes the result.
    """
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    seen_categories = set()

    for img_idx in range(len(ds)):
        # find better way to get target
        img, targets = ds.load_item(img_idx)
        width, height = img.size
        image_id = targets["image_id"].item()
        dataset["images"].append({"id": image_id, "width": width, "height": height})

        boxes = targets["boxes"].clone()
        boxes[:, 2:] -= boxes[:, :2]  # xyxy -> xywh
        boxes = boxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()

        has_masks = "masks" in targets
        if has_masks:
            masks = targets["masks"]
            # make each (H, W) mask Fortran contiguous for coco_mask.encode
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)

        has_keypoints = "keypoints" in targets
        if has_keypoints:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()

        for i, bbox in enumerate(boxes):
            ann = {
                "image_id": image_id,
                "bbox": bbox,
                "category_id": labels[i],
                "area": areas[i],
                "iscrowd": iscrowd[i],
                "id": ann_id,
            }
            seen_categories.add(labels[i])
            if has_masks:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if has_keypoints:
                ann["keypoints"] = keypoints[i]
                # visibility flags at every third slot; non-zero means labeled
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1

    dataset["categories"] = [{"id": cid} for cid in sorted(seen_categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
def get_coco_api_from_dataset(dataset):
    """Return the COCO API object backing *dataset*.

    Peels off up to 10 layers of ``torch.utils.data.Subset`` wrappers looking
    for a ``torchvision.datasets.CocoDetection``; if one is found its ``coco``
    handle is returned directly, otherwise the dataset is converted on the fly.
    """
    candidate = dataset
    # Bounded unwrapping guards against pathological nesting.
    for _ in range(10):
        if isinstance(candidate, torchvision.datasets.CocoDetection):
            break
        if isinstance(candidate, torch.utils.data.Subset):
            candidate = candidate.dataset
    if isinstance(candidate, torchvision.datasets.CocoDetection):
        return candidate.coco
    return convert_to_coco_api(candidate)