# DR-Augmentation / utils.py
# (Hugging Face upload metadata: uploaded by prasannareddyp, "Upload 10 files",
#  commit 941ee5b verified — kept here as a comment so the file stays valid Python.)
import logging
import math
import os
from PIL import Image
import yaml
from sklearn.metrics import confusion_matrix
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision import transforms
from moco.loader import GaussianBlur
import numpy as np
from augmentations import JigsawPuzzle, JigsawPuzzle_l, JigsawPuzzle_all, RandomErasing, RandomPatchNoise, RandomPatchErase
LOG_FORMAT = "[%(levelname)s] %(asctime)s %(filename)s:%(lineno)s %(message)s"
LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
NUM_CLASSES = {"domainnet-126": 126, "VISDA-C": 12, "PACS": 7}
import torch
import numpy as np
from PIL import Image
def configure_logger(rank, log_path=None):
    """Configure the root logger for (possibly distributed) training.

    Args:
        rank: int, process rank. Ranks -1 and 0 ("master") log at INFO and
            may write to file; all other ranks are limited to WARNING.
        log_path: optional str, path of the log file. Only the master
            process writes to it (mode "w", truncating any previous log).

    Returns:
        None. Side effect: calls logging.basicConfig on the root logger.
    """
    if log_path:
        log_dir = os.path.dirname(log_path)
        # Bug fix: dirname is "" for a bare filename, and os.makedirs("")
        # raises FileNotFoundError — only create the directory if there is one.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
    # only master process will print & write
    level = logging.INFO if rank in {-1, 0} else logging.WARNING
    handlers = [logging.StreamHandler()]
    if rank in {0, -1} and log_path:
        handlers.append(logging.FileHandler(log_path, "w"))
    logging.basicConfig(
        level=level,
        format=LOG_FORMAT,
        datefmt=LOG_DATEFMT,
        handlers=handlers,
    )
class UnevenBatchLoader:
    """Loader that loads data from multiple datasets with different length.

    Each underlying loader cycles independently: when one is exhausted its
    epoch counter is bumped and a fresh iterator is created, so next_batch()
    always yields one item per registered loader.
    """

    def __init__(self, data_loaders, is_ddp=False):
        # Register N data loaders, each with its own epoch counter.
        self.data_loaders = data_loaders
        self.epoch_counters = [0] * len(data_loaders)
        # Under DDP, set_epoch() must be called before building an iterator
        # so the distributed sampler reshuffles deterministically per epoch.
        self.is_ddp = is_ddp
        if is_ddp:
            for dl in data_loaders:
                dl.sampler.set_epoch(0)
        self.iterators = [iter(dl) for dl in data_loaders]

    def next_batch(self):
        """Collect the next batch from every registered loader.

        Returns:
            batches: list of N items, one per data loader, each in the
                format produced by that loader. Exhausted loaders are
                transparently restarted (incrementing their epoch counter).
        """
        batches = []
        for idx, it in enumerate(self.iterators):
            try:
                item = next(it)
            except StopIteration:
                self.epoch_counters[idx] += 1
                # Restart this loader for its next epoch.
                if self.is_ddp:
                    self.data_loaders[idx].sampler.set_epoch(self.epoch_counters[idx])
                fresh = iter(self.data_loaders[idx])
                self.iterators[idx] = fresh
                item = next(fresh)
            batches.append(item)
        return batches

    def update_loader(self, idx, loader, epoch=None):
        """Replace the iterator at position `idx` with one over `loader`."""
        if self.is_ddp and isinstance(epoch, int):
            loader.sampler.set_epoch(epoch)
        self.iterators[idx] = iter(loader)
class CustomDistributedDataParallel(DistributedDataParallel):
    """A wrapper class over DDP that relay "module" attribute."""

    def __init__(self, model, **kwargs):
        # Plain pass-through: DDP wraps `model` and exposes it as `self.module`.
        super(CustomDistributedDataParallel, self).__init__(model, **kwargs)

    def __getattr__(self, name):
        # Defer to DDP's own attribute lookup first; if DDP doesn't have the
        # attribute, fall back to the wrapped model so callers can write
        # `ddp_model.foo` instead of `ddp_model.module.foo`.
        # NOTE: __getattr__ is only invoked after normal lookup fails, so this
        # never shadows DDP's real attributes.
        try:
            return super(CustomDistributedDataParallel, self).__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = dist.get_world_size()
    # One receive buffer per rank, same shape/dtype as the local tensor.
    buffers = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)
@torch.no_grad()
def remove_wrap_arounds(tensor, ranks):
    """Strip DistributedSampler wrap-around padding from a gathered tensor.

    Args:
        tensor: concatenation of equal-length per-rank chunks (as produced
            by concat_all_gather).
        ranks: int, number of ranks WITHOUT a padded (wrap-around) last
            sample; ranks >= this index each drop their final element.
            0 means no padding anywhere.

    Returns:
        The tensor with padded trailing samples removed.
    """
    if ranks == 0:
        return tensor
    world_size = dist.get_world_size()
    per_rank = len(tensor) // world_size
    chunks = []
    for r in range(world_size):
        chunk = tensor[r * per_rank : (r + 1) * per_rank]
        # Ranks at or beyond `ranks` carry one wrap-around sample at the end.
        chunks.append(chunk if r < ranks else chunk[:-1])
    return torch.cat(chunks)
def get_categories(category_file):
    """Return a list of categories ordered by corresponding label.

    Args:
        category_file: str, path to the category file. Can be .yaml or .txt.
            A .yaml file maps category name -> integer label; a .txt file
            lists one category per line, already ordered by label.

    Returns:
        categories: List[str], category names ordered by label, with
            underscores replaced by spaces.

    Raises:
        NotImplementedError: if the file extension is not supported.
    """
    if category_file.endswith(".yaml"):
        with open(category_file, "r") as fd:
            cat_mapping = yaml.load(fd, Loader=yaml.SafeLoader)
        categories = list(cat_mapping.keys())
        categories.sort(key=lambda x: cat_mapping[x])
    elif category_file.endswith(".txt"):
        with open(category_file, "r") as fd:
            lines = fd.readlines()
        # Bug fix: strip BEFORE filtering. A blank line is "\n", which is
        # truthy, so the original `if cat` check let empty categories through.
        categories = [line.strip() for line in lines if line.strip()]
    else:
        raise NotImplementedError()
    categories = [cat.replace("_", " ") for cat in categories]
    return categories
def get_augmentation(aug_type, patch_height=28, mix_prob=0.8, normalize=None):
    """Build the image augmentation pipeline for `aug_type`, wrapped in a DualTransform.

    Args:
        aug_type: str, augmentation recipe name. Recognized here: "moco-v2",
            "moco-v1", "plain", "clip_inference", "test", "jigsaw",
            "jigsaw_all", "jigsaw_l", "rpe", "rpn", "ours", "ours_1".
            Any other value leaves the image transform as None — presumably
            those are mask-based aug types handled inside DualTransform
            itself (e.g. "mask", "foreground", "fpn"); verify against
            DualTransform.__call__.
        patch_height: int, patch edge length for the patch-based
            augmentations; also reused as the patch width (square patches).
        mix_prob: float, mixing probability forwarded to JigsawPuzzle_all in
            the "ours" branch and to DualTransform.
        normalize: optional normalization transform; falls back to ImageNet
            mean/std when falsy.

    Returns:
        A DualTransform wrapping the composed image transform (may wrap None).
    """
    if not normalize:
        # Default: standard ImageNet channel statistics.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    if aug_type == "moco-v2":
        # MoCo v2: random crop + color jitter + grayscale + Gaussian blur + flip.
        image_aug = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomApply(
                    [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
                    p=0.8,  # not strengthened
                ),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "moco-v1":
        # MoCo v1: like v2 but always-on jitter and no blur.
        image_aug = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "plain":
        # Weak train-time augmentation: resize, random crop, flip.
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "clip_inference":
        # CLIP-style eval preprocessing (bicubic resize of the short side).
        image_aug = transforms.Compose(
            [
                transforms.Resize(224, interpolation=Image.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "test":
        # Deterministic eval preprocessing.
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "jigsaw":
        # Jigsaw shuffle always applied (mix_prob=1).
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                JigsawPuzzle(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "jigsaw_all":
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                JigsawPuzzle_all(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "jigsaw_l":
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                JigsawPuzzle_l(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "rpe":
        # Random patch erase.
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                RandomPatchErase(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type == "rpn":
        # Random patch noise.
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                RandomPatchNoise(patch_height=patch_height, patch_width=patch_height, mix_prob=1),
                transforms.ToTensor(),
                normalize,
            ]
        )
    elif aug_type in ["ours", "ours_1"]:
        # NOTE(review): no `normalize` in this pipeline — presumably
        # normalization happens later inside DualTransform; confirm.
        image_aug = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                JigsawPuzzle_all(patch_height=patch_height, patch_width=patch_height, mix_prob=mix_prob),
                transforms.ToTensor(),
            ]
        )
    else:
        # Unknown aug types fall through; DualTransform dispatches on aug_type.
        image_aug = None
    return DualTransform(
        aug_type=aug_type,
        image_transform=image_aug,
        patch_height=patch_height,
        patch_width=patch_height,
        mix_prob=mix_prob,
    )
def fuse_foreground_background(img1, img2, mask):
    """
    Composite img1's foreground over img2's background.

    `mask` (broadcastable to the images) uses 0 for background; pixels where
    mask > 0.5 come from img1, the rest from img2.
    """
    foreground = mask > 0.5
    return img1 * foreground + img2 * (~foreground)
def normalize(tensor):
    """Normalize a (C,H,W) tensor with ImageNet mean/std statistics."""
    return transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(tensor)
class DualTransform:
    """
    A wrapper that can apply image-only transforms or image+mask transforms.
    """

    def __init__(self, aug_type, image_transform=None, patch_height=28, patch_width=28, mix_prob=1.0):
        # Image-only pipeline built by get_augmentation(); may be None for
        # mask-based aug types, which never reach the final `else` branch.
        self.image_transform = image_transform
        self.aug_type = aug_type
        # Deterministic resize/crop so image and mask stay pixel-aligned.
        self.base_transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        # MoCo-v2-style stochastic pipeline. NOTE: the trailing `normalize`
        # is the module-level normalize() function used as a callable transform.
        self.moco_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomApply(
                    [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
                    p=0.8,  # not strengthened
                ),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.RandomHorizontalFlip(),
                # RandomErasing(mode='soft_pixel'),
                transforms.ToTensor(),
                normalize,
            ]
        )
        # NOTE(review): patch size is hard-coded to 28 below, so the
        # patch_height/patch_width parameters are effectively ignored here —
        # confirm this is intended.
        self.fpn = RandomPatchNoise(patch_height=28, patch_width=28, mix_prob=mix_prob)
        self.to_pil = transforms.ToPILImage()
        self.to_tensor = transforms.ToTensor()
        self.jigsaw = JigsawPuzzle(patch_height=28, patch_width=28, mix_prob=mix_prob)
        self.jigsaw_all = JigsawPuzzle_all(mix_prob=mix_prob)

    def __call__(self, img, mask=None):
        """Transform `img` (and `mask`, where required) according to self.aug_type.

        Args:
            img: PIL image.
            mask: foreground mask aligned with img (0 = background). Required
                by every branch except the final image-only fallback.

        Returns:
            A transformed image tensor (normalized in every explicit branch;
            the fallback returns whatever self.image_transform produces).
        """
        if self.aug_type == "mask":
            # Return the normalized mask itself.
            mask = self.base_transform(mask)
            return normalize(mask)
        elif self.aug_type == "foreground":
            # Keep only foreground pixels (mask > 0.5), black out the rest.
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            return normalize(img * (mask > 0.5).float())
        elif self.aug_type == "fpn":
            # Patch noise on the whole image, then keep only the foreground.
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            return normalize(img_n * (mask > 0.5).float())
        elif self.aug_type == "bps":
            # Jigsaw-shuffled image restricted to the BACKGROUND (mask < 0.5).
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            return normalize(img_jigsaw * (mask < 0.5).float())
        elif self.aug_type == "ours_raw":
            # Foreground gets patch noise, background gets jigsaw shuffle.
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            img_out = fuse_foreground_background(img_n, img_jigsaw, mask)
            return normalize(img_out)
        elif self.aug_type == "ours":
            # Same fusion as "ours_raw", then the MoCo-v2 pipeline on top.
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            img_out = fuse_foreground_background(img_n, img_jigsaw, mask)
            return self.moco_transform(self.to_pil(img_out))
        elif self.aug_type == "ours_fpn":
            # Ablation: patch-noised foreground over the ORIGINAL background.
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_out = fuse_foreground_background(img_n, img, mask)
            return self.moco_transform(self.to_pil(img_out))
        elif self.aug_type == "ours_bps":
            # Ablation: original foreground over a jigsaw-shuffled background.
            mask = self.base_transform(mask)
            img = self.base_transform(img)
            # img_n = self.to_tensor(self.fpn(self.to_pil(img)))
            img_jigsaw = self.to_tensor(self.jigsaw_all(self.to_pil(img)))
            img_out = fuse_foreground_background(img, img_jigsaw, mask)
            return self.moco_transform(self.to_pil(img_out))
        # Always transform the image if we have an image_transform
        else:
            return self.image_transform(img)

    # Dead code preserved verbatim below as a string literal (disabled
    # "ours_old"/"ours_1" branches; note the stray quote before the closer).
    '''
    elif self.aug_type == "ours_old":
        img_t = self.image_transform(img)
        img = self.base_transform(img)
        mask = self.base_transform(mask)
        img_t1 = fuse_foreground_background(img, img_t, mask)
        img_t1_pil = self.to_pil(img_t1)
        output = self.moco_transform(img_t1_pil)
        return output
    elif self.aug_type == "ours_1":
        img_t = self.image_transform(img)
        img = self.base_transform(img)
        mask = self.base_transform(mask)
        img_t1 = fuse_foreground_background(img, img_t, mask)
        return normalize(img_t1)'
    '''
class AverageMeter(object):
    """Tracks the latest value and running average of a scalar metric."""

    def __init__(self, name, fmt=":f"):
        # `fmt` is a format-spec fragment spliced into __str__ (e.g. ":.3f").
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Formats and logs a batch counter plus a set of AverageMeters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log one progress line: prefix, [batch/total], then each meter."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        logging.info("\t".join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Right-align the batch index to the width of the total, e.g. "[  7/100]".
        width = len(str(num_batches))
        fmt = "{:" + str(width) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def save_checkpoint(model, optimizer, epoch, save_path="checkpoint.pth.tar"):
    """Serialize model + optimizer state and the epoch number to `save_path`."""
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        save_path,
    )
def adjust_learning_rate(optimizer, progress, args):
    """
    Decay the learning rate based on epoch or iteration.

    Base decay is cosine (args.optim.cos), exponential (args.optim.exp), or
    constant 1.0; each passed milestone in args.optim.schedule multiplies it
    by args.optim.gamma. Each param group's "lr" is set to its "lr0" * decay.

    Returns:
        decay: float, the combined decay factor applied.
    """
    if args.optim.cos:
        decay = 0.5 * (1.0 + math.cos(math.pi * progress / args.learn.full_progress))
    elif args.optim.exp:
        decay = (1 + 10 * progress / args.learn.full_progress) ** -0.75
    else:
        decay = 1.0
    for milestone in args.optim.schedule:
        if progress >= milestone:
            decay *= args.optim.gamma
    for group in optimizer.param_groups:
        group["lr"] = group["lr0"] * decay
    return decay
def per_class_accuracy(y_true, y_pred):
    """Compute, log, and return the per-class accuracy (percent, 2 decimals)."""
    cm = confusion_matrix(y_true, y_pred)
    # Diagonal = correct predictions; row sum = total true samples per class.
    acc_per_class = (cm.diagonal() / cm.sum(axis=1) * 100.0).round(2)
    logging.info(
        f"Accuracy per class: {acc_per_class}, mean: {acc_per_class.mean().round(2)}"
    )
    return acc_per_class
def get_distances(X, Y, dist_type="euclidean"):
    """Pairwise distances between row vectors of X and Y.

    Args:
        X: (N, D) tensor
        Y: (M, D) tensor
        dist_type: "euclidean" or "cosine".

    Returns:
        (N, M) tensor of distances.

    Raises:
        NotImplementedError: for an unknown dist_type.
    """
    if dist_type == "euclidean":
        return torch.cdist(X, Y)
    if dist_type == "cosine":
        # Cosine distance = 1 - cosine similarity of L2-normalized rows.
        X_unit = F.normalize(X, dim=1)
        Y_unit = F.normalize(Y, dim=1)
        return 1 - torch.matmul(X_unit, Y_unit.T)
    raise NotImplementedError(f"{dist_type} distance not implemented.")
def is_master(args):
    """True iff this process is the first rank on its node."""
    return args.rank % args.ngpus_per_node == 0


def use_wandb(args):
    """Only the master process logs to wandb, and only when enabled in args."""
    return is_master(args) and args.use_wandb