import math
import random
from copy import deepcopy
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler


class SAMDistributedSampler(DistributedSampler):
    """
    Distributed sampler that splits each rank's samples for an epoch into
    ``sub_epochs_per_epoch`` interleaved sub-epochs.

    Modified from https://github.com/pytorch/pytorch/blob/97261be0a8f09bed9ab95d0cee82e75eebd249c3/torch/utils/data/distributed.py.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        sub_epochs_per_epoch: int = 1,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.sub_epoch = 0
        self.sub_epochs_per_epoch = sub_epochs_per_epoch
        self.set_sub_num_samples()

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make the total evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample: every num_replicas-th index, starting at this rank
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        # keep only the interleaved slice belonging to the current sub-epoch
        indices = indices[(self.sub_epoch % self.sub_epochs_per_epoch) :: self.sub_epochs_per_epoch]
        return iter(indices)

    def __len__(self) -> int:
        return self.sub_num_samples

    def set_sub_num_samples(self) -> None:
        # indices[offset::k] has length num_samples // k, plus one when the
        # offset falls inside the remainder num_samples % k
        self.sub_num_samples = self.num_samples // self.sub_epochs_per_epoch
        if self.num_samples % self.sub_epochs_per_epoch > self.sub_epoch % self.sub_epochs_per_epoch:
            self.sub_num_samples += 1

    def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None:
        r"""
        Set the epoch and sub-epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
            sub_epoch (int): Sub-epoch number.
        """
        self.epoch = epoch
        self.sub_epoch = sub_epoch
        self.set_sub_num_samples()
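

# Usage sketch (not part of the original file): driving the sampler across
# sub-epochs in a DDP training loop. The dataset, batch size, and sub-epoch
# count below are illustrative; torch.distributed is assumed to be initialized
# before the sampler is constructed.
def _demo_sub_epoch_loop(dataset: Dataset, num_epochs: int = 1) -> None:
    from torch.utils.data import DataLoader

    sampler = SAMDistributedSampler(dataset, shuffle=True, sub_epochs_per_epoch=4)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)
    for epoch in range(num_epochs):
        for sub_epoch in range(sampler.sub_epochs_per_epoch):
            # must be called before iterating, so every rank shuffles with the
            # same seed and then takes its matching sub-epoch slice
            sampler.set_epoch_and_sub_epoch(epoch, sub_epoch)
            for batch in loader:
                pass  # training step goes here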


class RandomHFlip(object):
    """Horizontally flip image, masks, points, and boxes with probability ``prob``."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )
        if random.random() < self.prob:
            # flip image (C x H x W) and masks (N x H x W) along the width axis
            image = torch.flip(image, dims=[2])
            masks = torch.flip(masks, dims=[2])
            # mirror x-coordinates; boxes are in (x, y, w, h) format, so the
            # new x-origin is W - w - x
            points = deepcopy(points).to(torch.float)
            bboxs = deepcopy(bboxs).to(torch.float)
            points[:, 0] = shape[-1] - points[:, 0]
            bboxs[:, 0] = shape[-1] - bboxs[:, 2] - bboxs[:, 0]
        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}


class ResizeLongestSide(object):
    """
    Resizes images, masks, points, and boxes so that the longest side of the
    image matches ``target_length``, preserving aspect ratio.

    Modified from https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/transforms.py.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """
        Expects a batched image tensor of shape BxCxHxW. Requires the original
        image size in (H, W) format.
        """
        target_size = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
        return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)

    def apply_boxes(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_coords(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires
        the original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given the input size and the target long-side
        length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )
        image = self.apply_image(image.unsqueeze(0), shape).squeeze(0)
        masks = self.apply_image(masks.unsqueeze(1), shape).squeeze(1)
        points = self.apply_coords(points, shape)
        bboxs = self.apply_boxes(bboxs, shape)
        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}


class Normalize_and_Pad(object):
    """Normalize the image with ImageNet statistics (0-255 scale) and zero-pad
    image and masks to a square of side ``target_length``."""

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length
        self.transform = transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )
        h, w = image.shape[-2:]
        image = self.transform(image)

        # pad only on the right and bottom so existing coordinates stay valid
        padh = self.target_length - h
        padw = self.target_length - w
        image = F.pad(image.unsqueeze(0), (0, padw, 0, padh), value=0).squeeze(0)
        masks = F.pad(masks.unsqueeze(1), (0, padw, 0, padh), value=0).squeeze(1)
        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}