"""Utility helpers for distributed training: logging, multi-loader batching, augmentations, and metrics."""
| import logging | |
| import math | |
| import os | |
| from PIL import Image | |
| import yaml | |
| from sklearn.metrics import confusion_matrix | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.distributed as dist | |
| from torch.nn.parallel import DistributedDataParallel | |
| from torchvision import transforms | |
| from moco.loader import GaussianBlur | |
| import numpy as np | |
| from augmentations import JigsawPuzzle, JigsawPuzzle_l, JigsawPuzzle_all, RandomErasing, RandomPatchNoise, RandomPatchErase | |
| LOG_FORMAT = "[%(levelname)s] %(asctime)s %(filename)s:%(lineno)s %(message)s" | |
| LOG_DATEFMT = "%Y-%m-%d %H:%M:%S" | |
| NUM_CLASSES = {"domainnet-126": 126, "VISDA-C": 12, "PACS": 7} | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
def configure_logger(rank, log_path=None):
    """Configure root logging for a (possibly distributed) run.

    Only rank 0 (or -1, i.e. non-distributed) logs at INFO level and writes
    the log file; every other rank is restricted to WARNING on stderr.
    """
    if log_path:
        os.makedirs(os.path.dirname(log_path), exist_ok=True)

    is_main = rank in (-1, 0)
    # non-master processes stay quiet except for warnings
    level = logging.INFO if is_main else logging.WARNING
    handlers = [logging.StreamHandler()]
    if is_main and log_path:
        handlers.append(logging.FileHandler(log_path, "w"))
    logging.basicConfig(
        level=level,
        format=LOG_FORMAT,
        datefmt=LOG_DATEFMT,
        handlers=handlers,
    )
class UnevenBatchLoader:
    """Loader that draws one item per step from several loaders of unequal length.

    Each underlying loader is restarted independently when it runs out, so a
    single "batch" always contains exactly one item from every registered
    loader regardless of their relative sizes.
    """

    def __init__(self, data_loaders, is_ddp=False):
        self.data_loaders = data_loaders
        # per-loader epoch counts, bumped whenever a loader wraps around
        self.epoch_counters = [0] * len(data_loaders)
        self.is_ddp = is_ddp
        if is_ddp:
            # DistributedSampler requires set_epoch() before the iterator exists
            for dl in data_loaders:
                dl.sampler.set_epoch(0)
        self.iterators = [iter(dl) for dl in data_loaders]

    def next_batch(self):
        """Return a list with the next item from each registered loader.

        Returns:
            data: a list of N items, one per loader, each in that loader's
            own output format. Exhausted loaders are transparently restarted
            (advancing their sampler epoch under DDP).
        """
        batch = []
        for idx, it in enumerate(self.iterators):
            try:
                item = next(it)
            except StopIteration:
                self.epoch_counters[idx] += 1
                if self.is_ddp:
                    self.data_loaders[idx].sampler.set_epoch(self.epoch_counters[idx])
                fresh = iter(self.data_loaders[idx])
                self.iterators[idx] = fresh
                item = next(fresh)
            batch.append(item)
        return batch

    def update_loader(self, idx, loader, epoch=None):
        """Swap in a new iterator for position `idx`.

        NOTE(review): self.data_loaders[idx] is intentionally left untouched,
        so a later wrap-around re-iterates the ORIGINAL loader — confirm this
        is the intended semantics with callers.
        """
        if self.is_ddp and isinstance(epoch, int):
            loader.sampler.set_epoch(epoch)
        self.iterators[idx] = iter(loader)
class CustomDistributedDataParallel(DistributedDataParallel):
    """DDP wrapper that forwards unknown attribute lookups to the wrapped module."""

    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)

    def __getattr__(self, name):
        # Let DDP resolve its own attributes first; anything it does not know
        # is delegated to the underlying model so callers can treat the
        # wrapper as the model itself.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
def concat_all_gather(tensor):
    """
    Gather `tensor` from every rank and concatenate the results along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = dist.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
def remove_wrap_arounds(tensor, ranks):
    """Strip sampler wrap-around padding from an all-gathered tensor.

    `ranks` is presumably the number of ranks holding only genuine samples;
    every rank at index >= `ranks` carried one padded (wrapped) sample at the
    end of its equally-sized shard, which is dropped here. `ranks == 0`
    means no padding anywhere — TODO confirm against the gathering caller.
    """
    if ranks == 0:
        return tensor
    world_size = dist.get_world_size()
    shard_len = len(tensor) // world_size
    pieces = []
    for r in range(world_size):
        shard = tensor[r * shard_len : (r + 1) * shard_len]
        # the trailing element of a padded shard is a duplicate — drop it
        pieces.append(shard[:-1] if r >= ranks else shard)
    return torch.cat(pieces)
def get_categories(category_file):
    """Return a list of category names ordered by their numeric label.

    Args:
        category_file: str, path to the category file. Either .yaml
            (mapping of name -> label) or .txt (one name per line,
            line number == label).

    Returns:
        categories: List[str], names ordered by label, with "_" replaced
        by " ".

    Raises:
        NotImplementedError: for any other file extension.
    """
    if category_file.endswith(".yaml"):
        with open(category_file, "r") as fd:
            cat_mapping = yaml.load(fd, Loader=yaml.SafeLoader)
        categories = sorted(cat_mapping.keys(), key=lambda name: cat_mapping[name])
    elif category_file.endswith(".txt"):
        with open(category_file, "r") as fd:
            # Strip BEFORE filtering: a blank line ("\n") is truthy, so the
            # old `if cat` test ran too early and let empty categories through.
            categories = [line.strip() for line in fd if line.strip()]
    else:
        raise NotImplementedError()
    categories = [cat.replace("_", " ") for cat in categories]
    return categories
def get_augmentation(aug_type, patch_height=28, mix_prob=0.8, normalize=None):
    """Build the image pipeline for `aug_type` and wrap it in a DualTransform.

    Args:
        aug_type: str, name of the augmentation recipe ("moco-v2", "moco-v1",
            "plain", "clip_inference", "test", "jigsaw", "jigsaw_all",
            "jigsaw_l", "rpe", "rpn", "ours", "ours_1", ...). Unrecognized
            names produce a DualTransform with image_transform=None, whose
            mask-based branches may still apply.
        patch_height: int, square patch size forwarded (as both height and
            width) to the patch-based augmentations and to DualTransform.
        mix_prob: float, mixing probability for the "ours"/"ours_1" jigsaw
            and for DualTransform's internal patch transforms.
        normalize: optional transforms.Normalize; defaults to ImageNet
            mean/std when falsy.

    Returns:
        DualTransform wrapping the composed pipeline.
    """
    if not normalize:
        # standard ImageNet statistics
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    if aug_type == "moco-v2":
        # MoCo v2 recipe: crop + jitter + grayscale + blur + flip
        image_aug = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomApply(
                    [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
                    p=0.8,  # not strengthened
                ),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "moco-v1":
        # MoCo v1 recipe: always-on jitter, no blur
        image_aug = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "plain":
        # weak train-time augmentation: random crop + flip only
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "clip_inference":
        # CLIP-style preprocessing (bicubic resize of the short side)
        image_aug = transforms.Compose(
            [
                transforms.Resize(224, interpolation=Image.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "test":
        # deterministic evaluation transform
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "jigsaw":
        # project-local patch shuffling; mix_prob fixed to 1 here
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                JigsawPuzzle(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "jigsaw_all":
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                JigsawPuzzle_all(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "jigsaw_l":
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                JigsawPuzzle_l(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "rpe":
        # random patch erase (project-local)
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                RandomPatchErase(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "rpn":
        # random patch noise (project-local)
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                RandomPatchNoise(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type in ["ours", "ours_1"]:
        # NOTE(review): this branch deliberately omits `normalize`;
        # DualTransform's "ours" path normalizes later — confirm.
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                JigsawPuzzle_all(patch_height=patch_height, patch_width=patch_height, mix_prob=mix_prob),
                transforms.ToTensor(),
            ]
        )
    else:
        image_aug = None
    # patch_width is intentionally set to patch_height (square patches)
    return DualTransform(
        aug_type=aug_type,
        image_transform=image_aug,
        patch_height=patch_height,
        patch_width=patch_height,
        mix_prob=mix_prob,
    )
def fuse_foreground_background(img1, img2, mask):
    """Composite two (C, H, W) image tensors using a binary mask.

    Pixels where mask > 0.5 come from `img1` (foreground); the rest come
    from `img2`. A mask value of 0 denotes background.
    """
    foreground = mask > 0.5
    return img1 * foreground + img2 * (~foreground)
def normalize(tensor):
    """Apply standard ImageNet mean/std normalization to a tensor image."""
    return transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(tensor)
class DualTransform:
    """
    A wrapper that can apply image-only transforms or image+mask transforms.

    Depending on `aug_type`, __call__ either runs the plain image transform
    or fuses foreground/background augmentations guided by a binary mask.
    """

    def __init__(self, aug_type, image_transform=None, patch_height=28, patch_width=28, mix_prob=1.0):
        # pipeline used for aug_types not handled by the mask branches below
        self.image_transform = image_transform
        self.aug_type = aug_type
        # deterministic resize/crop -> tensor, shared by all mask branches
        self.base_transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        # MoCo-v2-style strong augmentation applied on top of fused images
        self.moco_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomApply(
                    [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
                    p=0.8,  # not strengthened
                ),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.RandomHorizontalFlip(),
                # RandomErasing(mode='soft_pixel'),
                transforms.ToTensor(),
                normalize,
            ]
        )
        # NOTE(review): fpn/jigsaw use a hard-coded 28x28 patch here, not the
        # patch_height/patch_width arguments — confirm this is intended.
        self.fpn = RandomPatchNoise(patch_height=28, patch_width=28, mix_prob=mix_prob)
        self.to_pil = transforms.ToPILImage()
        self.to_tensor = transforms.ToTensor()
        self.jigsaw = JigsawPuzzle(patch_height=28, patch_width=28, mix_prob=mix_prob)
        self.jigsaw_all = JigsawPuzzle_all(mix_prob=mix_prob)

    def __call__(self, img, mask=None):
        # Mask-driven branches assume `mask` is a PIL image with 0 = background.
        if self.aug_type == "mask":
            # return the normalized mask itself
            mask = self.base_transform(mask)
            return normalize(mask)
        elif self.aug_type == "foreground":
            # keep only the foreground pixels of the image
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            return normalize(img * (mask > 0.5).float())
        elif self.aug_type == "fpn":
            # foreground of the patch-noised image
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            return normalize(img_n * (mask > 0.5).float())
        elif self.aug_type == "bps":
            # background of the fully jigsawed image (note the < comparison)
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            return normalize(img_jigsaw * (mask < 0.5).float())
        elif self.aug_type == "ours_raw":
            # patch-noised foreground fused with jigsawed background, no extra aug
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            img_out = fuse_foreground_background(img_n, img_jigsaw, mask)
            return normalize(img_out)
        elif self.aug_type == "ours":
            # same fusion as "ours_raw", then MoCo-style strong augmentation
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            img_out = fuse_foreground_background(img_n, img_jigsaw, mask)
            return self.moco_transform(self.to_pil(img_out))
        elif self.aug_type == "ours_fpn":
            # ablation: only the foreground is patch-noised
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_out = fuse_foreground_background(img_n, img, mask)
            return self.moco_transform(self.to_pil(img_out))
        elif self.aug_type == "ours_bps":
            # ablation: only the background is jigsawed
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            # img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            img_out = fuse_foreground_background(img, img_jigsaw, mask)
            return self.moco_transform(self.to_pil(img_out))
        # Always transform the image if we have an image_transform
        else:
            return self.image_transform(img)
        # Dead code retained from an earlier revision (unreachable string
        # literal; the stray quote below is inside the string and harmless).
        '''
        elif self.aug_type == "ours_old":
            img_t = self.image_transform(img)
            img = self.base_transform(img)
            mask = self.base_transform(mask)
            img_t1 = fuse_foreground_background(img, img_t, mask)
            img_t1_pil = self.to_pil(img_t1)
            output = self.moco_transform(img_t1_pil)
            return output
        elif self.aug_type == "ours_1":
            img_t = self.image_transform(img)
            img = self.base_transform(img)
            mask = self.base_transform(mask)
            img_t1 = fuse_foreground_background(img, img_t, mask)
            return normalize(img_t1)'
        '''
class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec used when rendering via __str__
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Formats and logs a set of AverageMeters with a batch-progress prefix."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log one tab-separated line: prefix, [batch/total], then each meter."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logging.info("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total count,
        # e.g. num_batches=100 -> "[{:3d}/100]".
        # (Dropped the original's redundant `num_batches // 1` no-op.)
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def save_checkpoint(model, optimizer, epoch, save_path="checkpoint.pth.tar"):
    """Serialize the model and optimizer state plus the epoch to `save_path`."""
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        save_path,
    )
def adjust_learning_rate(optimizer, progress, args):
    """
    Scale each param group's learning rate (from its stored "lr0") by a
    decay factor based on training progress.

    The factor follows a cosine schedule (args.optim.cos), an inverse-power
    schedule (args.optim.exp), or stays at 1.0 — and is additionally
    multiplied by args.optim.gamma for every milestone already passed.
    Returns the decay factor that was applied.
    """
    if args.optim.cos:
        # cosine annealing from 1 down to 0 over the full progress span
        decay = 0.5 * (1.0 + math.cos(math.pi * progress / args.learn.full_progress))
    elif args.optim.exp:
        # inverse-power decay (as used in domain-adversarial training)
        decay = (1 + 10 * progress / args.learn.full_progress) ** -0.75
    else:
        decay = 1.0

    # step decay on top of the continuous schedule
    for milestone in args.optim.schedule:
        if progress >= milestone:
            decay *= args.optim.gamma

    for group in optimizer.param_groups:
        group["lr"] = group["lr0"] * decay
    return decay
def per_class_accuracy(y_true, y_pred):
    """Compute, log, and return per-class accuracy (%) from a confusion matrix."""
    cm = confusion_matrix(y_true, y_pred)
    # diagonal = correct predictions; row sum = true count per class
    acc_per_class = (cm.diagonal() / cm.sum(axis=1) * 100.0).round(2)
    logging.info(
        f"Accuracy per class: {acc_per_class}, mean: {acc_per_class.mean().round(2)}"
    )
    return acc_per_class
def get_distances(X, Y, dist_type="euclidean"):
    """Pairwise distances between rows of X and rows of Y.

    Args:
        X: (N, D) tensor
        Y: (M, D) tensor
        dist_type: "euclidean" (L2) or "cosine" (1 - cosine similarity).

    Returns:
        (N, M) tensor of distances.

    Raises:
        NotImplementedError: for any other dist_type.
    """
    if dist_type == "euclidean":
        return torch.cdist(X, Y)
    if dist_type == "cosine":
        X_unit = F.normalize(X, dim=1)
        Y_unit = F.normalize(Y, dim=1)
        return 1 - torch.matmul(X_unit, Y_unit.T)
    raise NotImplementedError(f"{dist_type} distance not implemented.")
def is_master(args):
    """Return True when this process is the master (local rank 0) on its node."""
    local_rank = args.rank % args.ngpus_per_node
    return local_rank == 0
def use_wandb(args):
    """wandb logging is active only on the master process with the flag set."""
    if not is_master(args):
        return False
    return args.use_wandb