"""Shared utilities for all main scripts."""

import os
import pickle
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange


class BaseTrainTester:
    """Basic train/test class to be inherited.

    ``main`` wires together the ``get_*`` factories, ``train_one_step`` and
    ``evaluate_nsteps``; in this base class those hooks are placeholders
    that return ``None`` / do nothing.
    """

    def __init__(self, args):
        """Store the run configuration.

        Args:
            args: configuration object. Methods of this class read the
                attributes ``lr``, ``local_rank``, ``checkpoint``,
                ``eval_only``, ``tasks``, ``batch_size_val``,
                ``train_iters``, ``val_freq`` and ``log_dir`` from it.
        """
        # Every other method reads ``self.args``, so it has to be set here.
        self.args = args
        # NOTE: the rank-0 side effects below are currently disabled:
        # if dist.get_rank() == 0:
        #     args.save(str(args.log_dir / "hparams.json"))
        #     self.writer = SummaryWriter(log_dir=args.log_dir)

    @staticmethod
    def get_datasets():
        """Initialize datasets (placeholder: returns ``None``)."""
        train_dataset = None
        return train_dataset

    @staticmethod
    def seed_worker(worker_id):
        """Seed NumPy's and Python's RNGs; used as ``worker_init_fn``.

        Args:
            worker_id: integer id of the worker being initialized.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        # Re-seed NumPy with an offset per worker. The sum is computed as a
        # Python int and wrapped, because np.random.seed only accepts
        # values in [0, 2**32 - 1].
        first_key = int(np.random.get_state()[1][0])
        np.random.seed((first_key + worker_id) % 2**32)

    def get_loaders(self, collate_fn):
        """Initialize data loaders.

        Args:
            collate_fn: callable forwarded to ``DataLoader(collate_fn=...)``.

        Returns:
            The training ``DataLoader``.
        """
        # Datasets
        train_dataset = self.get_datasets()
        # Samplers and loaders
        g = torch.Generator()
        g.manual_seed(0)
        train_sampler = DistributedSampler(train_dataset)
        # NOTE(review): batch_size and num_workers are hard-coded here
        # rather than read from self.args -- confirm this is intended.
        train_loader = DataLoader(
            train_dataset,
            batch_size=2,
            shuffle=False,
            num_workers=1,
            worker_init_fn=BaseTrainTester.seed_worker,
            collate_fn=collate_fn,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True,
            generator=g
        )
        return train_loader

    @staticmethod
    def get_model():
        """Initialize the model (placeholder: returns ``None``)."""
        return None

    @staticmethod
    def get_criterion():
        """Get loss criterion for training (placeholder: returns ``None``)."""
        # criterion is a class, must have compute_loss and compute_metrics
        return None

    def get_optimizer(self, model):
        """Initialize optimizer.

        Parameters whose name contains ``bias``, ``LayerNorm.weight`` or
        ``LayerNorm.bias`` go into a group with no weight decay; all other
        parameters use a weight decay of 5e-4. Both groups use
        ``self.args.lr``.

        Args:
            model: module providing ``named_parameters()``.

        Returns:
            An ``AdamW`` optimizer over the two parameter groups.
        """
        no_decay = ("bias", "LayerNorm.weight", "LayerNorm.bias")
        no_decay_params, decay_params = [], []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": no_decay_params,
             "weight_decay": 0.0, "lr": self.args.lr},
            {"params": decay_params,
             "weight_decay": 5e-4, "lr": self.args.lr},
        ]
        return optim.AdamW(optimizer_grouped_parameters)

    @staticmethod
    def _unpack_loaders(loaders):
        """Normalize ``get_loaders`` output to ``(train_loader, test_loader)``.

        A two-element tuple/list is unpacked as-is; anything else is treated
        as a lone training loader and paired with ``None``.
        """
        if isinstance(loaders, (tuple, list)):
            train_loader, test_loader = loaders
            return train_loader, test_loader
        return loaders, None

    def _val_iters(self):
        """Return the number of evaluation steps for one evaluation round."""
        return max(
            5, int(4 * len(self.args.tasks) / self.args.batch_size_val)
        )

    def main(self, collate_fn=default_collate):
        """Run main training/testing pipeline.

        Args:
            collate_fn: collate function handed to ``get_loaders``.

        Returns:
            The (DistributedDataParallel-wrapped) model.

        Raises:
            FileNotFoundError: if ``args.checkpoint`` is set but is not a file.
            RuntimeError: if ``args.eval_only`` is set but ``get_loaders``
                provided no test loader.
        """
        # Get loaders. get_loaders may return one loader or a
        # (train, test) pair; test_loader is None in the former case.
        train_loader, test_loader = self._unpack_loaders(
            self.get_loaders(collate_fn)
        )

        # Get model
        model = self.get_model()

        # Get criterion
        criterion = self.get_criterion()

        # Get optimizer
        optimizer = self.get_optimizer(model)

        # Move model to devices
        if torch.cuda.is_available():
            model = model.cuda()
        model = DistributedDataParallel(
            model, device_ids=[self.args.local_rank],
            broadcast_buffers=False, find_unused_parameters=True
        )

        # Check for a checkpoint
        start_iter, best_loss = 0, None
        if self.args.checkpoint:
            # Explicit check instead of ``assert`` so it survives ``-O``.
            if not os.path.isfile(self.args.checkpoint):
                raise FileNotFoundError(
                    "checkpoint not found: {}".format(self.args.checkpoint)
                )
            start_iter, best_loss = self.load_checkpoint(model, optimizer)

        # Eval only
        if bool(self.args.eval_only):
            if test_loader is None:
                raise RuntimeError(
                    "eval_only requires get_loaders() to return a test loader"
                )
            print("Test evaluation.......")
            model.eval()
            self.evaluate_nsteps(
                model, criterion, test_loader, step_id=-1,
                val_iters=self._val_iters()
            )
            return model

        # Training loop
        iter_loader = iter(train_loader)
        model.train()
        for step_id in trange(start_iter, self.args.train_iters):
            try:
                sample = next(iter_loader)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                iter_loader = iter(train_loader)
                sample = next(iter_loader)

            self.train_one_step(model, criterion, optimizer, step_id, sample)
            if (step_id + 1) % self.args.val_freq == 0:
                print("Train evaluation.......")
                model.eval()
                new_loss = self.evaluate_nsteps(
                    model, criterion, train_loader, step_id,
                    val_iters=self._val_iters(),
                    split='train'
                )
                # The test loss, when available, is the one used for
                # checkpoint selection; otherwise the train loss is kept.
                if test_loader is not None:
                    print("Test evaluation.......")
                    model.eval()
                    new_loss = self.evaluate_nsteps(
                        model, criterion, test_loader, step_id,
                        val_iters=self._val_iters()
                    )
                if dist.get_rank() == 0:  # save model
                    best_loss = self.save_checkpoint(
                        model, optimizer, step_id,
                        new_loss, best_loss
                    )
                model.train()

        return model

    def train_one_step(self, model, criterion, optimizer, step_id, sample):
        """Run a single training step (placeholder: does nothing)."""
        pass

    @torch.no_grad()
    def evaluate_nsteps(self, model, criterion, loader, step_id, val_iters,
                        split='val'):
        """Run a given number of evaluation steps (placeholder: returns None)."""
        return None

    def load_checkpoint(self, model, optimizer):
        """Load from checkpoint.

        Reads ``self.args.checkpoint``, restores the model weights and, when
        present, the optimizer state (with every group's learning rate reset
        to ``self.args.lr``).

        Args:
            model: module whose ``load_state_dict`` receives ``"weight"``.
            optimizer: optimizer whose state is restored if saved.

        Returns:
            Tuple ``(start_iter, best_loss)`` taken from the checkpoint,
            defaulting to ``0`` and ``None`` when the keys are missing.
        """
        print("=> loading checkpoint '{}'".format(self.args.checkpoint))

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model_dict = torch.load(self.args.checkpoint, map_location="cpu")
        model.load_state_dict(model_dict["weight"])
        if 'optimizer' in model_dict:
            optimizer.load_state_dict(model_dict["optimizer"])
            # The configured learning rate overrides the saved one.
            for group in optimizer.param_groups:
                group['lr'] = self.args.lr
        start_iter = model_dict.get("iter", 0)
        best_loss = model_dict.get("best_loss", None)

        print("=> loaded successfully '{}' (step {})".format(
            self.args.checkpoint, model_dict.get("iter", 0)
        ))
        del model_dict
        torch.cuda.empty_cache()
        return start_iter, best_loss

    def save_checkpoint(self, model, optimizer, step_id,
                        new_loss, best_loss):
        """Save checkpoint if requested.

        ``last.pth`` is always written to ``self.args.log_dir``; ``best.pth``
        is also written when either loss is ``None`` or ``new_loss`` does not
        exceed ``best_loss``.

        Args:
            model: module whose ``state_dict()`` is saved.
            optimizer: optimizer whose ``state_dict()`` is saved.
            step_id: current step; ``step_id + 1`` is stored as ``"iter"``.
            new_loss: loss of the latest evaluation (may be ``None``).
            best_loss: best loss so far (may be ``None``).

        Returns:
            The updated best loss.
        """
        is_best = (
            new_loss is None or best_loss is None or new_loss <= best_loss
        )
        if is_best:
            best_loss = new_loss
        state = {
            "weight": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "iter": step_id + 1,
            "best_loss": best_loss
        }
        if is_best:
            torch.save(state, self.args.log_dir / "best.pth")
        torch.save(state, self.args.log_dir / "last.pth")
        return best_loss

    def synchronize_between_processes(self, a_dict):
        """Gather a dict from all ranks and concatenate its values per key.

        The merge is only performed on rank 0 (or when torch.distributed is
        not initialized); other ranks get ``a_dict`` back unchanged.

        Args:
            a_dict: dict whose values expose ``.device``/``.to`` and can be
                concatenated with ``torch.cat``.

        Returns:
            The merged dict on rank 0, otherwise the input dict.
        """
        all_dicts = all_gather(a_dict)

        if not is_dist_avail_and_initialized() or dist.get_rank() == 0:
            merged = {}
            for key, first_value in all_dicts[0].items():
                # Concatenate on the device of the first rank's value.
                device = first_value.device
                merged[key] = torch.cat([
                    p[key].to(device) for p in all_dicts
                    if key in p
                ])
            a_dict = merged
        return a_dict


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_numel = tensor.numel()
    local_size = torch.tensor([local_numel], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda")
        for _ in size_list
    ]
    if local_numel != max_size:
        padding = torch.empty(
            size=(max_size - local_numel,), dtype=torch.uint8, device="cuda"
        )
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, gathered in zip(size_list, tensor_list):
        # Drop the padding before unpickling. NOTE: pickle.loads runs on
        # bytes received from the other ranks, which must be trusted.
        buffer = gathered.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def is_dist_avail_and_initialized():
    """Return True if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when not initialized."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()