ldmae / VMAE /util /misc.py
isno0907's picture
Upload 115 files
6c49103 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import torch
import torch.distributed as dist
from torch import inf
import numpy as np
from torchvision.transforms import functional as F
from typing import Optional, Tuple, Union, List
def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional["torch.device"] = None,
dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
):
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
is always created on the CPU.
"""
# device on which tensor is created defaults to device
rand_device = device
batch_size = shape[0]
layout = layout or torch.strided
device = device or torch.device("cpu")
if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
print(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slighly speed up this function by passing a generator that was created on the {device} device."
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
# make sure generator list of length 1 is treated like a non-list
if isinstance(generator, list) and len(generator) == 1:
generator = generator[0]
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
return latents
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False,
fixed_std=None):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
self.fixed_std = fixed_std
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "DiagonalGaussianDistribution" = None, fixed_std=None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
sum_dim = self.mean.dim()
if self.fixed_std is not None:
fixed_var = torch.tensor(self.fixed_std) ** 2
fixed_logvar = torch.log(fixed_var)
return 0.5 * torch.sum(
self.var / fixed_var
- 1.0
- self.logvar
+ fixed_logvar,
dim=list(range(1,sum_dim)),
)
else:
if other is None:
# return 0.5 * torch.sum(
# torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
# dim=list(range(1,sum_dim)),
# )
return 0.5 * torch.sum(self.var - 1.0 - self.logvar,
dim=list(range(1,sum_dim)),
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=list(range(1,sum_dim)),
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
def set_for_tuning_decoder(args, model):
args.mask_ratio = 0.0
model.mask_token = None
# return it back
for name, param in model.named_parameters():
if 'decoder' not in name and 'from_latent' not in name:
param.requires_grad = False
else:
param.requires_grad = True
for name, param in model.named_parameters():
if param.requires_grad == False:
print(f"{name}: requires_grad = {param.requires_grad}")
else:
print(f"{name}: requires_grad = {param.requires_grad}")
def set_for_tuning_decoder_vae(args, model):
for name, param in model.named_parameters():
if 'post_quant_conv' in name or 'decoder' in name:
param.requires_grad = True
else:
param.requires_grad = False
for name, param in model.named_parameters():
print(f"{name}: requires_grad = {param.requires_grad}")
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
builtin_print = builtins.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = force or (get_world_size() > 8)
if is_master or force:
now = datetime.datetime.now().time()
builtin_print('[{}] '.format(now), end='') # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if args.dist_on_itp:
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
os.environ['LOCAL_RANK'] = str(args.gpu)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
# HERE
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
setup_for_distributed(is_master=True) # hack
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
from datetime import timedelta
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank,
timeout=timedelta(minutes=30) )
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"
def __init__(self):
# self._scaler = torch.cuda.amp.GradScaler()
self._scaler = torch.amp.GradScaler("cuda")
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
# loss.backward()
# optimizer.step()
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
return total_norm
def save_model_vqvae(args, epoch, model, model_without_ddp, optimizer_ae, optimizer_disc):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer_ae': optimizer_ae.state_dict(),
'optimizer_disc': optimizer_disc.state_dict(),
'epoch': epoch,
'args': args,
}
save_on_master(to_save, checkpoint_path)
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
if loss_scaler is not None:
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch}
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
def resize_pos_embed(pos_embed, new_size):
_, HW, D = pos_embed.shape
H = int(HW ** 0.5)
assert H * H == HW
pos_embed_2d_resized = torch.nn.functional.interpolate(
pos_embed.reshape(1,H,H,D).permute(0, 3, 1, 2), # (batch, channels, height, width)
size=(new_size, new_size),
mode='bilinear',
align_corners=False
).permute(0, 2, 3, 1).reshape(1,-1,D) # (batch, height, width, channels)
return pos_embed_2d_resized
def load_model(args, model_without_ddp, optimizer, loss_scaler):
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
state_dict = checkpoint['model']
if state_dict['pos_embed'].shape[1] != model_without_ddp.pos_embed.shape[1]:
new_size = int(model_without_ddp.pos_embed.shape[1] ** 0.5)
print(f'latent resolution is {new_size} x {new_size}, reshape pos embedding')
print(f"prev pos embedding size: {state_dict['pos_embed'].shape}")
state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], new_size)
print(f"new pos embedding size: {state_dict['pos_embed'].shape}")
print(f"prev dec pos embedding size: {state_dict['decoder_pos_embed'].shape}")
state_dict['decoder_pos_embed'] = resize_pos_embed(state_dict['decoder_pos_embed'], new_size)
print(f"new dec pos embedding size: {state_dict['decoder_pos_embed'].shape}")
msg = model_without_ddp.load_state_dict(state_dict, strict=False)
print(msg)
print("Resume checkpoint %s" % args.resume)
if not args.tune_decoder:
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
def all_reduce_mean(x):
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
x_reduce /= world_size
return x_reduce.item()
else:
return x
def all_reduce_sum(x):
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
return x_reduce.item()
else:
return x
def write_stat(t, num_rows, path, len_dataset):
t_np = t.numpy().reshape(num_rows, -1)
if os.path.isfile(path):
stat = np.loadtxt(path, delimiter=',').astype(np.int64)
stat = stat.reshape(num_rows, -1)
stat = np.concatenate((stat, t_np), axis=1)
save = np.savetxt(path, stat, delimiter=',', fmt='%d')
else:
save = np.savetxt(path, t_np, delimiter=',', fmt='%d')
check = np.loadtxt(path, delimiter=',').astype(np.int64)/len_dataset
print(f'count_convergence is activated.\n{check.round(3)*100}')
import math
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
"""
Distributed Sampler that subsamples indicies sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.batch_size = batch_size
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += [indices[-1]] * (self.total_size - len(indices))
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
return iter(indices)
def __len__(self):
return self.num_samples
@torch.no_grad()
def update_mask(model, data_loader, device, dataset_train, target_attn,
mask_ratio = 0.75, ref_cluster = 'eigen', store_mask = False):
print("Starts upadating informed mask...")
len_ds = len(dataset_train)
model.eval()
metric_logger = MetricLogger(delimiter=" ")
header = 'Upadating informed mask...'
print_freq = 20
masks_weights =[]
mask_indices = []
for data_iter_step, (samples, _, index, _, path_first, path_second, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
samples = samples.to(device, non_blocking=True)
path_first = path_first.to(device, non_blocking=True)
path_second = path_second.to(device, non_blocking=True)
index = index.to(device, non_blocking=True)
with torch.cuda.amp.autocast():
if ref_cluster == 'alternate':
new_ids_shuffle_1, ref_cluster_size = model.forward_encoder_inference(samples, target_attn,
mask_ratio = mask_ratio, ref_cluster = 'small',
return_score = True)
new_ids_shuffle_2, ref_cluster_size = model.forward_encoder_inference(samples, target_attn,
mask_ratio = mask_ratio, ref_cluster = 'small',
return_score = True, force_flip=True)
new_ids_shuffle = torch.stack([new_ids_shuffle_1, new_ids_shuffle_2], dim=0) # 2 B N
new_ids_shuffle = new_ids_shuffle.permute(1,0,2) # B 2 N
# print(f'alternate: {new_ids_shuffle.shape}')
else:
new_ids_shuffle, ref_cluster_size = model.forward_encoder_inference(samples, target_attn,
mask_ratio = mask_ratio, ref_cluster = ref_cluster,
return_score = True)
# print(torch.cat([index, path_first, path_second]))
# mask_info = torch.cat([new_ids_shuffle, ref_cluster_size.unsqueeze(-1), path_first.unsqueeze(-1), path_second.unsqueeze(-1)], dim=-1)
mask_info = new_ids_shuffle
# mask_index = torch.cat([index.unsqueeze(-1),path_first.unsqueeze(-1), path_second.unsqueeze(-1)], dim=-1)
if data_iter_step % 200 == 0:
print(f'ids_shuffle: {new_ids_shuffle.shape}, ref_cluster_size: {ref_cluster_size.shape}')
print(f'mask_info: {mask_info.shape}')
if data_iter_step ==0:
print('Saving...')
examples = mask_info.detach().cpu().numpy()
store_path = f'/data2/projects/jeongwoo/jeongwoo/mae/analysis/convergence/mask_samples_{ref_cluster}'
save = np.save(store_path, examples)
masks_weights.append(mask_info)
# mask_indices.append(mask_index)
masks_weights = torch.cat(masks_weights, dim=0)
# mask_indices = torch.cat(mask_indices, dim=0)
print(f'masks_weights: {masks_weights.shape}')
dist.barrier()
gather_masks = [torch.ones_like(masks_weights) for _ in range(dist.get_world_size())]
# gather_mask_index = [torch.ones_like(mask_indices) for _ in range(dist.get_world_size())]
dist.all_gather(gather_masks, masks_weights)
# dist.all_gather(gather_mask_index, mask_indices)
all_mask_weights = torch.cat(gather_masks)
# all_mask_indices = torch.cat(gather_mask_index)
all_mask_weights = all_mask_weights[:len_ds]
# all_mask_indices = all_mask_indices[:len_ds]
if store_mask:
weights_to_store = all_mask_weights.cpu().numpy()
store_path = f'/data2/projects/jeongwoo/jeongwoo/mae/analysis/convergence/stored_masks_{ref_cluster}'
save = np.save(store_path, weights_to_store)
dataset_train.mask = all_mask_weights.cpu()
# dataset_train.mask_index = all_mask_indices.cpu()
print("Informed masks have been updated")
import torchvision.transforms as transforms
class maskRandomResizedCrop(transforms.RandomResizedCrop):
def __init__(self, size, **kwargs):
super().__init__(size, **kwargs)
self.mask_size = 14
def forward(self, img, mask):
mask = mask.reshape(14,14)
i, j, h, w = self.get_params(img, self.scale, self.ratio)
m_h_s = int(14 * (i/img.size[1]))
m_h_e = int(14 * ((i+h)/img.size[1])) + 1
m_w_s = int(14 * (j/img.size[0]))
m_w_e = int(14 * ((j+w)/img.size[0])) + 1
mask = mask[m_h_s:m_h_e, m_w_s:m_w_e]
img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
mask = F.resize(mask.unsqueeze(0), (self.mask_size, self.mask_size))
mask = mask.flatten()
return img, mask
class maskRandomHorizontalFlip(transforms.RandomHorizontalFlip):
def __init__(self):
super().__init__()
def forward(self, img, mask):
mask = mask.reshape(14, 14)
if torch.rand(1) < self.p:
img, mask = F.hflip(img), F.hflip(mask)
mask = mask.flatten()
return img, mask
class trainCompose(transforms.Compose):
def __call__(self, img, mask, hint_prob=False):
#mask needs to be processed individually for some operations
for i in self.transforms[:2]:
img, mask = i(img,mask)
# if not hint_prob:
# mask = torch.argsort(
# mask, dim=0, descending=False
# )
for t in self.transforms[2:]:
img = t(img)
return img, mask
# def schedule_hint(hint_ratio, hint_portion, do_schedule, cur_epoch, total_epoch):
# if hint_ratio is None: return None
# L = 196
# if do_schedule:
# alpha = 1 - (cur_epoch/total_epoch)**3 # 1 to 0
# hint_ratio = hint_ratio * alpha
# hint_portion = max(hint_portion, 0.2)
# hint_portion = alpha * (hint_portion - 0.2) + 0.2
# cluster_size = int(hint_portion*L)
# hint_num = max(int(hint_ratio * cluster_size), 2)
# print(f'{hint_num} tokens for hint in epoch {cur_epoch}')
# return hint_num
def schedule_hint(hint_ratio, hint_portion, do_schedule, cur_epoch, total_epoch, min_portion, min_ratio, schedule_exp, full_schedule = False):
if hint_ratio is None: return None
L = 196
if do_schedule:
assert hint_portion >= min_portion, 'min_portion is bigger than hint_portion.'
assert hint_ratio >= min_ratio, 'min_ratio is bigger than hint_ratio.'
if full_schedule:
total_epoch = 800
alpha = 1 - ((cur_epoch-0)/(total_epoch-0))**schedule_exp # 1 to 0
else:
alpha = 1 - ((cur_epoch-100)/(total_epoch-100))**schedule_exp # 1 to 0
hint_ratio = alpha * (hint_ratio - min_ratio) + min_ratio
hint_portion = alpha * (hint_portion - min_portion) + min_portion
hint_num = max(int(hint_ratio * L), 2)
print(f'{hint_num} tokens for hint in epoch {cur_epoch}')
print(f'Hint ratio & hint_portion: {hint_ratio, hint_portion} in epoch {cur_epoch}')
return hint_ratio, hint_portion
import torchvision.datasets as datasets
import random
class NormalImgDataset(datasets.ImageFolder):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.num_retries = 10
def __getitem__(self, index: int):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
failed = []
for _ in range(self.num_retries):
path, target = self.samples[index]
try:
sample = self.loader(path)
except:
try:
sample = self.loader(path) # one more time
except:
failed.append(path)
index = random.randint(0, len(self.samples) - 1)
continue
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, torch.tensor(1)
else:
print('Failed to load {} after {} retries'.format(
failed, self.num_retries
))