| """Shared utilities for all main scripts.""" |
|
|
| import os |
| import pickle |
| import random |
|
|
| import numpy as np |
| import torch |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, default_collate |
| from torch.utils.data.distributed import DistributedSampler |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import trange |
|
|
|
|
class BaseTrainTester:
    """Basic train/test class to be inherited.

    The ``get_datasets``, ``get_model``, ``get_criterion``,
    ``train_one_step`` and ``evaluate_nsteps`` methods defined here are
    placeholders (they return ``None`` or do nothing) and are meant to be
    overridden by subclasses.
    """

    def __init__(self, args):
        """Initialize.

        Args:
            args: run configuration. The methods of this class read the
                attributes ``lr``, ``local_rank``, ``checkpoint``,
                ``eval_only``, ``tasks``, ``batch_size_val``,
                ``train_iters``, ``val_freq`` and ``log_dir`` from it.
        """
        # Every other method accesses the configuration via self.args.
        self.args = args

    @staticmethod
    def get_datasets():
        """Initialize datasets.

        Returns:
            The training dataset (``None`` in this base implementation).
        """
        train_dataset = None
        return train_dataset

    @staticmethod
    def seed_worker(worker_id):
        """Seed ``random`` and NumPy inside a DataLoader worker.

        Args:
            worker_id: worker index, as handed to ``worker_init_fn``.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        # Re-seed NumPy with an offset that depends on the worker id.
        # np.random.seed only accepts values in [0, 2**32 - 1], so the sum
        # is wrapped instead of being allowed to overflow that range.
        first_state_word = int(np.random.get_state()[1][0])
        np.random.seed((first_state_word + worker_id) % 2**32)

    def get_loaders(self, collate_fn):
        """Initialize data loaders.

        Args:
            collate_fn: batch collation function given to the DataLoader.

        Returns:
            The training ``DataLoader``, built on ``get_datasets()`` with a
            ``DistributedSampler``.
        """
        train_dataset = self.get_datasets()

        # Fixed-seed generator handed to the DataLoader.
        g = torch.Generator()
        g.manual_seed(0)
        train_sampler = DistributedSampler(train_dataset)
        train_loader = DataLoader(
            train_dataset,
            batch_size=2,
            shuffle=False,  # ordering is delegated to the sampler
            num_workers=1,
            worker_init_fn=BaseTrainTester.seed_worker,
            collate_fn=collate_fn,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True,
            generator=g
        )
        return train_loader

    @staticmethod
    def get_model():
        """Initialize the model."""
        return None

    @staticmethod
    def get_criterion():
        """Get loss criterion for training."""
        return None

    def get_optimizer(self, model):
        """Initialize optimizer.

        Parameters whose name contains ``bias``, ``LayerNorm.weight`` or
        ``LayerNorm.bias`` go to a group without weight decay; all other
        parameters use a weight decay of 5e-4. Both groups use
        ``self.args.lr``.

        Args:
            model: module whose ``named_parameters()`` are optimized.

        Returns:
            An ``AdamW`` optimizer with two parameter groups
            (index 0: no decay, index 1: decay).
        """
        no_decay = ("bias", "LayerNorm.weight", "LayerNorm.bias")
        plain_params, decayed_params = [], []
        for name, param in model.named_parameters():
            if any(tag in name for tag in no_decay):
                plain_params.append(param)
            else:
                decayed_params.append(param)
        return optim.AdamW([
            {"params": plain_params, "weight_decay": 0.0,
             "lr": self.args.lr},
            {"params": decayed_params, "weight_decay": 5e-4,
             "lr": self.args.lr},
        ])

    @staticmethod
    def _unpack_loaders(loaders):
        """Normalize the result of ``get_loaders`` to (train, test)."""
        if isinstance(loaders, (tuple, list)):
            train_loader, test_loader = loaders
            return train_loader, test_loader
        # A single loader is treated as the training loader.
        return loaders, None

    def _val_iters(self):
        """Return the number of evaluation iterations (never below 5)."""
        return max(
            5,
            int(4 * len(self.args.tasks) / self.args.batch_size_val)
        )

    def main(self, collate_fn=default_collate):
        """Run main training/testing pipeline.

        Args:
            collate_fn: batch collation function forwarded to
                ``get_loaders``.

        Returns:
            The model wrapped in ``DistributedDataParallel``.

        Raises:
            ValueError: if ``args.eval_only`` is set but ``get_loaders``
                did not provide a test loader.
        """
        # get_loaders may return a single train loader (as the version in
        # this class does) or a (train, test) pair; accept both.
        train_loader, test_loader = self._unpack_loaders(
            self.get_loaders(collate_fn)
        )

        model = self.get_model()
        criterion = self.get_criterion()
        optimizer = self.get_optimizer(model)

        # DistributedDataParallel rejects device_ids for CPU modules, so
        # only pass them when the model was actually moved to a GPU.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            model = model.cuda()
        model = DistributedDataParallel(
            model,
            device_ids=[self.args.local_rank] if use_cuda else None,
            broadcast_buffers=False, find_unused_parameters=True
        )

        # Optionally resume from a checkpoint.
        start_iter, best_loss = 0, None
        if self.args.checkpoint:
            assert os.path.isfile(self.args.checkpoint), \
                "checkpoint not found: {}".format(self.args.checkpoint)
            start_iter, best_loss = self.load_checkpoint(model, optimizer)

        # Evaluation-only mode: one test pass, no training.
        if bool(self.args.eval_only):
            if test_loader is None:
                raise ValueError(
                    "eval_only requires get_loaders() to return a test loader"
                )
            print("Test evaluation.......")
            model.eval()
            self.evaluate_nsteps(
                model, criterion, test_loader, step_id=-1,
                val_iters=self._val_iters()
            )
            return model

        # Training loop over iterations rather than epochs: the loader is
        # restarted whenever it is exhausted.
        iter_loader = iter(train_loader)
        model.train()
        for step_id in trange(start_iter, self.args.train_iters):
            try:
                sample = next(iter_loader)
            except StopIteration:
                # NOTE(review): the sampler's set_epoch() is not called
                # before restarting, so each pass presumably reuses the
                # same ordering -- confirm this is intended.
                iter_loader = iter(train_loader)
                sample = next(iter_loader)

            self.train_one_step(model, criterion, optimizer, step_id, sample)
            if (step_id + 1) % self.args.val_freq == 0:
                val_iters = self._val_iters()
                print("Train evaluation.......")
                model.eval()
                new_loss = self.evaluate_nsteps(
                    model, criterion, train_loader, step_id,
                    val_iters=val_iters,
                    split='train'
                )
                # When a test loader exists its loss replaces the train
                # loss for checkpoint selection.
                if test_loader is not None:
                    print("Test evaluation.......")
                    model.eval()
                    new_loss = self.evaluate_nsteps(
                        model, criterion, test_loader, step_id,
                        val_iters=val_iters
                    )
                # Only rank 0 writes checkpoints.
                if dist.get_rank() == 0:
                    best_loss = self.save_checkpoint(
                        model, optimizer, step_id,
                        new_loss, best_loss
                    )
                model.train()

        return model

    def train_one_step(self, model, criterion, optimizer, step_id, sample):
        """Run a single training step."""
        pass

    @torch.no_grad()
    def evaluate_nsteps(self, model, criterion, loader, step_id, val_iters,
                        split='val'):
        """Run a given number of evaluation steps."""
        return None

    def load_checkpoint(self, model, optimizer):
        """Load from checkpoint.

        Restores the model weights from ``self.args.checkpoint`` and, when
        the file contains one, the optimizer state (whose learning rates
        are then reset to ``self.args.lr``).

        Args:
            model: module receiving the stored ``"weight"`` state dict.
            optimizer: optimizer receiving the stored ``"optimizer"`` state.

        Returns:
            Tuple ``(start_iter, best_loss)`` read from the checkpoint,
            defaulting to ``(0, None)`` when the keys are missing.
        """
        print("=> loading checkpoint '{}'".format(self.args.checkpoint))

        model_dict = torch.load(self.args.checkpoint, map_location="cpu")
        model.load_state_dict(model_dict["weight"])
        if 'optimizer' in model_dict:
            optimizer.load_state_dict(model_dict["optimizer"])
            # The configured learning rate wins over the stored one.
            for group in optimizer.param_groups:
                group['lr'] = self.args.lr
        start_iter = model_dict.get("iter", 0)
        best_loss = model_dict.get("best_loss", None)

        print("=> loaded successfully '{}' (step {})".format(
            self.args.checkpoint, start_iter
        ))
        # Drop the loaded copy before training continues.
        del model_dict
        torch.cuda.empty_cache()
        return start_iter, best_loss

    def save_checkpoint(self, model, optimizer, step_id, new_loss, best_loss):
        """Save checkpoint if requested.

        ``last.pth`` is always written to ``self.args.log_dir``;
        ``best.pth`` is written as well when either loss is ``None`` or
        ``new_loss <= best_loss``.

        Args:
            model: module whose state dict is stored under ``"weight"``.
            optimizer: optimizer whose state dict is stored.
            step_id: index of the step just finished (``step_id + 1`` is
                stored as ``"iter"``).
            new_loss: loss of the latest evaluation.
            best_loss: best loss so far, or ``None``.

        Returns:
            The (possibly updated) best loss.
        """
        is_best = (
            new_loss is None or best_loss is None or new_loss <= best_loss
        )
        if is_best:
            best_loss = new_loss
        # Both files hold the same content, so build the state once.
        state = {
            "weight": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "iter": step_id + 1,
            "best_loss": best_loss
        }
        if is_best:
            torch.save(state, self.args.log_dir / "best.pth")
        torch.save(state, self.args.log_dir / "last.pth")
        return best_loss

    def synchronize_between_processes(self, a_dict):
        """Gather a dict of tensors from all ranks and merge it.

        Args:
            a_dict: mapping from key to tensor held by this rank.

        Returns:
            On rank 0, or when the distributed package is not available or
            not initialized, a dict that concatenates, for every key of the
            first gathered dict, the tensors of all gathered dicts holding
            that key. On the other ranks, ``a_dict`` unchanged.
        """
        all_dicts = all_gather(a_dict)

        if is_dist_avail_and_initialized() and dist.get_rank() != 0:
            return a_dict

        reference = all_dicts[0]
        return {
            # Move every piece to the device of the first dict's tensor
            # before concatenating.
            key: torch.cat([
                gathered[key].to(value.device) for gathered in all_dicts
                if key in gathered
            ])
            for key, value in reference.items()
        }
|
|
|
|
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    The object is pickled, shipped as a uint8 CUDA tensor padded to the
    largest payload among the ranks, and unpickled on the receiving side.

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialize to a flat uint8 tensor. A bytearray is used because
    # torch.frombuffer expects a writable buffer; .to("cuda") copies it.
    payload = bytearray(pickle.dumps(data))
    tensor = torch.frombuffer(payload, dtype=torch.uint8).to("cuda")

    # Exchange the payload length of every rank.
    local_size = tensor.numel()
    size_buffers = [
        torch.tensor([0], device="cuda") for _ in range(world_size)
    ]
    dist.all_gather(
        size_buffers, torch.tensor([local_size], device="cuda")
    )
    size_list = [int(size.item()) for size in size_buffers]
    max_size = max(size_list)

    # The tensors passed to dist.all_gather here all share one length,
    # so shorter payloads are padded up to the longest one.
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda")
        for _ in size_list
    ]
    if local_size != max_size:
        # Padding content is irrelevant: it is sliced off after gathering.
        padding = torch.empty(
            (max_size - local_size,), dtype=torch.uint8, device="cuda"
        )
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    # NOTE: pickle.loads can execute arbitrary code; this is acceptable
    # only because the bytes come from the other ranks of the same job.
    return [
        pickle.loads(gathered.cpu().numpy().tobytes()[:size])
        for size, gathered in zip(size_list, tensor_list)
    ]
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    # Short-circuit: is_initialized() is only queried when the distributed
    # package is available.
    return dist.is_available() and dist.is_initialized()
|
|
|
|
def get_world_size():
    """Return the distributed world size, or 1 outside a process group."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
|