| from abc import abstractmethod
|
| import os
|
| import time
|
| import json
|
| import copy
|
| import threading
|
| from functools import partial
|
| from contextlib import nullcontext
|
|
|
| import torch
|
| import torch.distributed as dist
|
| from torch.utils.data import DataLoader
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
| import numpy as np
|
|
|
| from torchvision import utils
|
| from torch.utils.tensorboard import SummaryWriter
|
|
|
| from .utils import *
|
| from ..utils.general_utils import *
|
| from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
|
| from ..utils.dist_utils import *
|
| from ..utils import grad_clip_utils, elastic_utils
|
|
|
|
|
class BasicTrainer:
    """
    Trainer for basic training loop.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config with keys 'name' and 'args'.
        lr_scheduler (dict): Learning rate scheduler config with keys 'name' and 'args'.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): If set, overrides mix_precision_mode and forces
            mix_precision_dtype to 'float16'.
        mix_precision_mode (str):
            - None: No mixed precision.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        mix_precision_dtype (str): Mixed precision dtype.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        parallel_mode (str): Parallel mode. Options are 'ddp'.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        prefetch_data (bool): Keep the next batch prefetched on the device.
        snapshot_batch_size (int): Batch size used when taking snapshots.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
    """

    def __init__(self,
        models,
        dataset,
        *,
        output_dir,
        load_dir,
        step,
        max_steps,
        batch_size=None,
        batch_size_per_gpu=None,
        batch_split=None,
        optimizer=None,
        lr_scheduler=None,
        elastic=None,
        grad_clip=None,
        ema_rate=0.9999,
        fp16_mode=None,
        mix_precision_mode='inflat_all',
        mix_precision_dtype='float16',
        fp16_scale_growth=1e-3,
        parallel_mode='ddp',
        finetune_ckpt=None,
        log_param_stats=False,
        prefetch_data=True,
        snapshot_batch_size=4,
        i_print=1000,
        i_log=500,
        i_sample=10000,
        i_save=10000,
        i_ddpcheck=10000,
        **kwargs
    ):
        assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'

        self.models = models
        self.dataset = dataset
        self.batch_split = batch_split if batch_split is not None else 1
        self.max_steps = max_steps
        # No shared mutable default; an empty config behaves exactly as before.
        self.optimizer_config = optimizer if optimizer is not None else {}
        self.lr_scheduler_config = lr_scheduler
        self.elastic_controller_config = elastic
        self.grad_clip = grad_clip
        self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
        # `fp16_mode` takes precedence over the mix_precision_* arguments.
        if fp16_mode is not None:
            mix_precision_dtype = 'float16'
            mix_precision_mode = fp16_mode
        self.mix_precision_mode = mix_precision_mode
        self.mix_precision_dtype = str_to_dtype(mix_precision_dtype)
        self.fp16_scale_growth = fp16_scale_growth
        self.parallel_mode = parallel_mode
        self.log_param_stats = log_param_stats
        self.prefetch_data = prefetch_data
        self.snapshot_batch_size = snapshot_batch_size
        self.log = []
        if self.prefetch_data:
            self._data_prefetched = None

        self.output_dir = output_dir
        self.i_print = i_print
        self.i_log = i_log
        self.i_sample = i_sample
        self.i_save = i_save
        self.i_ddpcheck = i_ddpcheck

        if dist.is_initialized():
            # Multi-GPU: local rank is derived from the global rank.
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            self.local_rank = dist.get_rank() % torch.cuda.device_count()
            self.is_master = self.rank == 0
        else:
            # Single process.
            self.world_size = 1
            self.rank = 0
            self.local_rank = 0
            self.is_master = True

        # batch_size_per_gpu wins over batch_size when both are given.
        self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
        self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
        assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
        assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'

        self.init_models_and_more(**kwargs)
        self.prepare_dataloader(**kwargs)

        # Resume from a checkpoint, or initialize from a finetune checkpoint.
        self.step = 0
        if load_dir is not None and step is not None:
            self.load(load_dir, step)
        elif finetune_ckpt is not None:
            self.finetune_from(finetune_ckpt)

        if self.is_master:
            os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
            os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
            self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))

        if self.parallel_mode == 'ddp' and self.world_size > 1:
            self.check_ddp()

        if self.is_master:
            print('\n\nTrainer initialized.')
            print(self)

    def __str__(self):
        lines = []
        lines.append(self.__class__.__name__)
        lines.append(f' - Models:')
        for name, model in self.models.items():
            lines.append(f' - {name}: {model.__class__.__name__}')
        lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
        lines.append(f' - Dataloader:')
        lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
        lines.append(f' - Num workers: {self.dataloader.num_workers}')
        lines.append(f' - Number of steps: {self.max_steps}')
        lines.append(f' - Number of GPUs: {self.world_size}')
        lines.append(f' - Batch size: {self.batch_size}')
        lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
        lines.append(f' - Batch split: {self.batch_split}')
        lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
        lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
        if self.lr_scheduler_config is not None:
            lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
        if self.elastic_controller_config is not None:
            lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
        if self.grad_clip is not None:
            lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
        lines.append(f' - EMA rate: {self.ema_rate}')
        lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}')
        lines.append(f' - Mixed precision mode: {self.mix_precision_mode}')
        if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
            lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}')
        lines.append(f' - Parallel mode: {self.parallel_mode}')
        return '\n'.join(lines)

    @property
    def device(self):
        """
        Device of the models: the first model exposing a ``device`` attribute,
        otherwise the device of the first parameter of the first model.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        return next(next(iter(self.models.values())).parameters()).device

    def init_models_and_more(self, **kwargs):
        """
        Initialize models and more: DDP wrappers, master/EMA params, optimizer,
        LR scheduler, elastic memory controller and gradient clipper.
        """
        if self.world_size > 1:
            # Wrap every model in DDP for gradient synchronization.
            self.training_models = {
                name: DDP(
                    model,
                    device_ids=[self.local_rank],
                    output_device=self.local_rank,
                    bucket_cap_mb=128,
                    find_unused_parameters=False
                )
                for name, model in self.models.items()
            }
        else:
            self.training_models = self.models

        # Master params: what the optimizer actually updates.
        self.model_params = sum(
            [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
        , [])
        if self.mix_precision_mode == 'amp':
            self.master_params = self.model_params
            if self.mix_precision_dtype == torch.float16:
                self.scaler = torch.GradScaler()
        elif self.mix_precision_mode == 'inflat_all':
            self.master_params = make_master_params(self.model_params)
            if self.mix_precision_dtype == torch.float16:
                # Static loss scale, stored as log2 and adapted each step.
                self.log_scale = 20.0
        elif self.mix_precision_mode is None:
            self.master_params = self.model_params
        else:
            raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.')

        # EMA copies are only kept (and updated) on the master process.
        if self.is_master:
            self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]

        # Optimizer: torch.optim class by name, else a name visible in this module.
        if hasattr(torch.optim, self.optimizer_config['name']):
            self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
        else:
            self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])

        # LR scheduler: same lookup strategy as the optimizer.
        if self.lr_scheduler_config is not None:
            if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
                self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
            else:
                self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])

        # Elastic memory controller, registered on every elastic model.
        if self.elastic_controller_config is not None:
            assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
                'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
            self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
            for model in self.models.values():
                if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
                    model.register_memory_controller(self.elastic_controller)

        # Gradient clipper: a plain max-norm, or a configured grad_clip_utils object.
        if self.grad_clip is not None:
            if isinstance(self.grad_clip, (float, int)):
                self.grad_clip = float(self.grad_clip)
            else:
                self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])

    def prepare_dataloader(self, **kwargs):
        """
        Prepare the resumable sampler, the dataloader and an endless iterator over it.
        """
        self.data_sampler = ResumableSampler(
            self.dataset,
            shuffle=True,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size_per_gpu,
            num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
            pin_memory=True,
            drop_last=True,
            persistent_workers=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
            sampler=self.data_sampler,
        )
        self.data_iterator = cycle(self.dataloader)

    def _master_params_to_state_dicts(self, master_params):
        """
        Convert master params to dict of state_dicts.

        Starts from each model's own ``state_dict()`` (so buffers are included)
        and overwrites every trainable parameter with the matching master param.
        """
        if self.mix_precision_mode == 'inflat_all':
            master_params = unflatten_master_params(self.model_params, master_params)
        state_dicts = {name: model.state_dict() for name, model in self.models.items()}
        master_params_names = sum(
            [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
        , [])
        for i, (model_name, param_name) in enumerate(master_params_names):
            state_dicts[model_name][param_name] = master_params[i]
        return state_dicts

    def _state_dicts_to_master_params(self, master_params, state_dicts):
        """
        Copy the trainable entries of ``state_dicts`` into ``master_params`` in place.
        """
        master_params_names = sum(
            [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
        , [])
        params = [state_dicts[name][param_name] for name, param_name in master_params_names]
        if self.mix_precision_mode == 'inflat_all':
            model_params_to_master_params(params, master_params)
        else:
            for i, param in enumerate(params):
                master_params[i].data.copy_(param.data)

    def load(self, load_dir, step=0):
        """
        Load a checkpoint.
        Should be called by all processes.
        """
        if self.is_master:
            print(f'\nLoading checkpoint from step {step}...', end='')

        # Model weights (all ranks).
        model_ckpts = {}
        for name, model in self.models.items():
            model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
            model_ckpts[name] = model_ckpt
            model.load_state_dict(model_ckpt)
        self._state_dicts_to_master_params(self.master_params, model_ckpts)
        del model_ckpts

        # EMA weights exist only on the master process.
        if self.is_master:
            for i, ema_rate in enumerate(self.ema_rate):
                ema_ckpts = {}
                for name, model in self.models.items():
                    ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
                    ema_ckpts[name] = ema_ckpt
                self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
                del ema_ckpts

        # Optimizer, step counter, sampler and other training state.
        misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
        self.optimizer.load_state_dict(misc_ckpt['optimizer'])
        self.step = misc_ckpt['step']
        self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
        if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
            self.scaler.load_state_dict(misc_ckpt['scaler'])
        elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
            self.log_scale = misc_ckpt['log_scale']
        if self.lr_scheduler_config is not None:
            self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
        if self.elastic_controller_config is not None:
            self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
        if self.grad_clip is not None and not isinstance(self.grad_clip, float):
            self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
        del misc_ckpt

        if self.world_size > 1:
            dist.barrier()
        if self.is_master:
            print(' Done.')

        if self.world_size > 1:
            self.check_ddp()

    @staticmethod
    def _detached_cpu_copy(obj):
        """
        Recursively copy every tensor in ``obj`` to a new, detached CPU tensor.

        Dicts, lists and tuples are rebuilt; any other leaf is returned as is.
        A real copy is made even for tensors already on the CPU, so the result
        never aliases state that training later updates in place.
        """
        if isinstance(obj, torch.Tensor):
            return obj.detach().to('cpu', copy=True)
        if isinstance(obj, dict):
            return {k: BasicTrainer._detached_cpu_copy(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(BasicTrainer._detached_cpu_copy(v) for v in obj)
        return obj

    def _write_ckpt(self, obj, filename, non_blocking):
        """
        ``torch.save`` ``obj`` to ``<output_dir>/ckpts/<filename>``.

        With ``non_blocking`` the write runs on a background thread, so ``obj``
        must not be mutated afterwards by the caller.
        """
        path = os.path.join(self.output_dir, 'ckpts', filename)
        if non_blocking:
            threading.Thread(target=torch.save, args=(obj, path)).start()
        else:
            torch.save(obj, path)

    def save(self, non_blocking=True):
        """
        Save a checkpoint.
        Should be called only by the rank 0 process.

        Args:
            non_blocking (bool): Write the files on background threads.
        """
        assert self.is_master, 'save() should be called only by the rank 0 process.'
        print(f'\nSaving checkpoint at step {self.step}...', end='')

        # Model weights.
        model_ckpts = self._master_params_to_state_dicts(self.master_params)
        for name, model_ckpt in model_ckpts.items():
            self._write_ckpt(self._detached_cpu_copy(model_ckpt), f'{name}_step{self.step:07d}.pt', non_blocking)

        # EMA weights.
        for i, ema_rate in enumerate(self.ema_rate):
            ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
            for name, ema_ckpt in ema_ckpts.items():
                self._write_ckpt(self._detached_cpu_copy(ema_ckpt), f'{name}_ema{ema_rate}_step{self.step:07d}.pt', non_blocking)

        # Optimizer and other training state.
        misc_ckpt = {
            'optimizer': self.optimizer.state_dict(),
            'step': self.step,
            'data_sampler': self.data_sampler.state_dict(),
        }
        if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16:
            misc_ckpt['scaler'] = self.scaler.state_dict()
        elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16:
            misc_ckpt['log_scale'] = self.log_scale
        if self.lr_scheduler_config is not None:
            misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
        if self.elastic_controller_config is not None:
            misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
        if self.grad_clip is not None and not isinstance(self.grad_clip, float):
            misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
        if non_blocking:
            # optimizer.state_dict() holds references to the live optimizer state
            # tensors, which the next optimizer.step() updates in place; snapshot
            # them before the background thread serializes the checkpoint.
            misc_ckpt = self._detached_cpu_copy(misc_ckpt)
        self._write_ckpt(misc_ckpt, f'misc_step{self.step:07d}.pt', non_blocking)
        print(' Done.')

    def finetune_from(self, finetune_ckpt):
        """
        Finetune from a checkpoint.
        Should be called by all processes.

        Args:
            finetune_ckpt (dict[str, str]): Model name -> checkpoint path. Keys
                missing from the model are dropped; shape-mismatched keys keep
                the model's current value.
        """
        if self.is_master:
            print('\nFinetuning from:')
            for name, path in finetune_ckpt.items():
                print(f' - {name}: {path}')

        model_ckpts = {}
        for name, model in self.models.items():
            model_state_dict = model.state_dict()
            if name in finetune_ckpt:
                model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
                for k, v in model_ckpt.items():
                    if k not in model_state_dict:
                        if self.is_master:
                            print(f'Warning: {k} not found in model_state_dict, skipped.')
                        model_ckpt[k] = None
                    elif model_ckpt[k].shape != model_state_dict[k].shape:
                        if self.is_master:
                            print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
                        model_ckpt[k] = model_state_dict[k]
                model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None}
                model_ckpts[name] = model_ckpt
                model.load_state_dict(model_ckpt)
            else:
                if self.is_master:
                    print(f'Warning: {name} not found in finetune_ckpt, skipped.')
                model_ckpts[name] = model_state_dict
        self._state_dicts_to_master_params(self.master_params, model_ckpts)
        if self.is_master:
            for ema_params in self.ema_params:
                self._state_dicts_to_master_params(ema_params, model_ckpts)
        del model_ckpts

        if self.world_size > 1:
            dist.barrier()
        if self.is_master:
            print('Done.')

        if self.world_size > 1:
            self.check_ddp()

    @abstractmethod
    def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
        """
        Run a snapshot of the model.
        """
        pass

    @torch.no_grad()
    def visualize_sample(self, sample):
        """
        Convert a sample to an image via ``dataset.visualize_sample`` when
        available; otherwise return the sample unchanged.
        """
        if hasattr(self.dataset, 'visualize_sample'):
            return self.dataset.visualize_sample(sample)
        else:
            return sample

    @torch.no_grad()
    def snapshot_dataset(self, num_samples=100, batch_size=4):
        """
        Sample images from the dataset and save them under ``samples/``.

        Args:
            num_samples (int): Number of samples to visualize.
            batch_size (int): Batch size of the temporary dataloader.
        """
        dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            num_workers=1,
            shuffle=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )
        # One iterator for all batches: consecutive batches are distinct and the
        # worker is not respawned per batch. Restart it if the dataset runs out.
        data_iter = iter(dataloader)
        save_cfg = {}
        for i in range(0, num_samples, batch_size):
            try:
                data = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                data = next(data_iter)
            data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()}
            data = recursive_to_device(data, self.device)
            vis = self.visualize_sample(data)
            if isinstance(vis, dict):
                for k, v in vis.items():
                    if f'dataset_{k}' not in save_cfg:
                        save_cfg[f'dataset_{k}'] = []
                    save_cfg[f'dataset_{k}'].append(v)
            else:
                if 'dataset' not in save_cfg:
                    save_cfg['dataset'] = []
                save_cfg['dataset'].append(vis)
        for name, image in save_cfg.items():
            utils.save_image(
                torch.cat(image, dim=0),
                os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
                nrow=int(np.sqrt(num_samples)),
                normalize=True,
                value_range=self.dataset.value_range,
            )

    @staticmethod
    def _minmax_normalize(values):
        """
        Linearly rescale ``values`` to [0, 1] using its global min and max.

        Returns:
            tuple: (normalized tensor, min, max). A constant input maps to all
                zeros instead of the NaNs a 0/0 division would produce.
        """
        lo = values.min()
        hi = values.max()
        span = hi - lo
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (values - lo) / span, lo, hi

    @torch.no_grad()
    def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
        """
        Sample images from the model.
        NOTE: This function should be called by all processes.
        """
        if self.is_master:
            print(f'\nSampling {num_samples} images...', end='')

        if suffix is None:
            suffix = f'step{self.step:07d}'

        # Each process generates its share of the samples.
        num_samples_per_process = int(np.ceil(num_samples / self.world_size))
        amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
        with amp_context():
            samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)

        # Turn 'sample' entries into 'image' entries (possibly several per key).
        for key in list(samples.keys()):
            if samples[key]['type'] == 'sample':
                vis = self.visualize_sample(samples[key]['value'])
                if isinstance(vis, dict):
                    for k, v in vis.items():
                        samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
                    del samples[key]
                else:
                    samples[key] = {'value': vis, 'type': 'image'}

        # Gather every entry onto rank 0.
        if self.world_size > 1:
            for key in samples.keys():
                samples[key]['value'] = samples[key]['value'].contiguous()
                if self.is_master:
                    all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
                else:
                    all_images = []
                dist.gather(samples[key]['value'], all_images, dst=0)
                if self.is_master:
                    samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]

        # Save on the master process.
        if self.is_master:
            os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
            for key in samples.keys():
                if samples[key]['type'] == 'image':
                    utils.save_image(
                        samples[key]['value'],
                        os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
                        nrow=int(np.sqrt(num_samples)),
                        normalize=True,
                        value_range=self.dataset.value_range,
                    )
                elif samples[key]['type'] == 'number':
                    images, lo, hi = self._minmax_normalize(samples[key]['value'])
                    images = utils.make_grid(
                        images,
                        nrow=int(np.sqrt(num_samples)),
                        normalize=False,
                    )
                    save_image_with_notes(
                        images,
                        os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
                        notes=f'{key} min: {lo}, max: {hi}',
                    )

        if self.is_master:
            print(' Done.')

    def update_ema(self):
        """
        Update exponential moving average:
        ``ema = rate * ema + (1 - rate) * master`` for every configured rate.
        Should only be called by the rank 0 process.
        """
        assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
        for i, ema_rate in enumerate(self.ema_rate):
            for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
                ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)

    def check_ddp(self):
        """
        Check if DDP is working properly, i.e. master params are bit-identical
        on every process. Raises AssertionError otherwise.
        Should be called by all process.
        """
        if self.is_master:
            print('\nPerforming DDP check...')

        if self.is_master:
            print('Checking if parameters are consistent across processes...')
        dist.barrier()
        try:
            for p in self.master_params:
                # Compare in chunks of 10M elements to bound all_gather memory.
                for i in range(0, p.numel(), 10000000):
                    sub_size = min(10000000, p.numel() - i)
                    sub_p = p.detach().view(-1)[i:i+sub_size]
                    sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
                    dist.all_gather(sub_p_gather, sub_p)
                    assert all([torch.equal(sub_p, sub_p_gather[r]) for r in range(self.world_size)]), 'parameters are not consistent across processes'
        except AssertionError as e:
            if self.is_master:
                print(f'\n\033[91mError: {e}\033[0m')
                print('DDP check failed.')
            raise e

        dist.barrier()
        if self.is_master:
            print('Done.')

    @abstractmethod
    def training_losses(self, **mb_data):
        """
        Compute training losses.

        Subclasses return ``(loss, status)`` dicts; ``loss['loss']`` is the
        scalar that is back-propagated by ``run_step``.
        """
        pass

    def load_data(self):
        """
        Load data: fetch the next batch (on device) and split it into
        ``batch_split`` micro-batches.

        Returns:
            list: Micro-batches. A dict batch is sliced along dim 0; a list
                batch is returned as is.
        """
        if self.prefetch_data:
            # Keep one batch in flight so the host-to-device copy overlaps compute.
            if self._data_prefetched is None:
                self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
            data = self._data_prefetched
            self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
        else:
            data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)

        if isinstance(data, dict):
            if self.batch_split == 1:
                data_list = [data]
            else:
                batch_size = list(data.values())[0].shape[0]
                data_list = [
                    {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
                    for i in range(self.batch_split)
                ]
        elif isinstance(data, list):
            data_list = data
        else:
            raise ValueError('Data must be a dict or a list of dicts.')

        return data_list

    def run_step(self, data_list):
        """
        Run a training step.

        Accumulates gradients over the micro-batches in ``data_list``, optionally
        clips them, applies the optimizer update for the configured mixed
        precision mode, steps the LR scheduler and, on the master process,
        updates the EMA params.

        Args:
            data_list (list[dict]): Micro-batches, e.g. from ``load_data``.

        Returns:
            dict: Step log with 'loss' and 'status', plus 'elastic', 'grad_clip',
                'param_norms' and 'param_grads' when those features are enabled.
        """
        step_log = {'loss': {}, 'status': {}}
        amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext
        elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
        amp_fp16 = self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16
        inflat_fp16 = self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16

        # Forward & backward, accumulating gradients over the micro-batches.
        losses = []
        statuses = []
        elastic_controller_logs = []
        zero_grad(self.model_params)
        for i, mb_data in enumerate(data_list):
            # DDP all-reduces gradients only on the last micro-batch.
            sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
            with nested_contexts(*sync_contexts), elastic_controller_context():
                with amp_context():
                    loss, status = self.training_losses(**mb_data)
                    l = loss['loss'] / len(data_list)
                if amp_fp16:
                    self.scaler.scale(l).backward()
                elif inflat_fp16:
                    scaled_l = l * (2 ** self.log_scale)
                    scaled_l.backward()
                else:
                    l.backward()
            losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
            statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
            if self.elastic_controller_config is not None:
                elastic_controller_logs.append(self.elastic_controller.log())

        # Unscale (master) gradients and clip them.
        if self.grad_clip is not None:
            if amp_fp16:
                self.scaler.unscale_(self.optimizer)
            elif self.mix_precision_mode == 'inflat_all':
                model_grads_to_master_grads(self.model_params, self.master_params)
                if self.mix_precision_dtype == torch.float16:
                    self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
            if isinstance(self.grad_clip, float):
                grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
            else:
                grad_norm = self.grad_clip(self.master_params)
            if torch.isfinite(grad_norm):
                statuses[-1]['grad_norm'] = grad_norm.item()

        # Optimizer step. `grad_scale` is the factor by which the model params'
        # .grad tensors are still scaled afterwards (used for param stats).
        if amp_fp16:
            self.scaler.step(self.optimizer)
            self.scaler.update()
            # GradScaler unscales the grads in place (inside step(), or earlier
            # via the explicit unscale_ for clipping), so nothing is left to undo.
            grad_scale = 1.0
        elif self.mix_precision_mode == 'inflat_all':
            if self.mix_precision_dtype == torch.float16:
                # Only the master grads are unscaled; model grads keep the loss scale.
                grad_scale = 2 ** self.log_scale
                if not any(not p.grad.isfinite().all() for p in self.model_params):
                    if self.grad_clip is None:
                        model_grads_to_master_grads(self.model_params, self.master_params)
                        self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
                    self.optimizer.step()
                    master_params_to_model_params(self.model_params, self.master_params)
                    self.log_scale += self.fp16_scale_growth
                else:
                    # Overflow: skip the update and halve the loss scale.
                    self.log_scale -= 1
            else:
                grad_scale = 1.0
                if self.grad_clip is None:
                    model_grads_to_master_grads(self.model_params, self.master_params)
                if not any(not p.grad.isfinite().all() for p in self.master_params):
                    self.optimizer.step()
                    master_params_to_model_params(self.model_params, self.master_params)
                else:
                    print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
        else:
            grad_scale = 1.0
            if not any(not p.grad.isfinite().all() for p in self.model_params):
                self.optimizer.step()
            else:
                print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')

        # Learning rate scheduler (logs the LR used for this step).
        if self.lr_scheduler_config is not None:
            statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()

        # Aggregate micro-batch logs.
        step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
        step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
        if self.elastic_controller_config is not None:
            step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
        if self.grad_clip is not None:
            step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()

        # Per-parameter norms and (unscaled) gradient norms.
        if self.log_param_stats:
            param_norms = {}
            param_grads = {}
            for model_name, model in self.models.items():
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        param_norms[f'{model_name}.{name}'] = param.norm().item()
                        if param.grad is not None and torch.isfinite(param.grad).all():
                            param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / grad_scale
            step_log['param_norms'] = param_norms
            step_log['param_grads'] = param_grads

        if self.is_master:
            self.update_ema()

        return step_log

    def save_logs(self):
        """
        Append buffered logs to ``log.txt``, write their NaN-free mean to
        TensorBoard at the current step, and clear the buffer.
        """
        log_str = '\n'.join([
            f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log
        ])
        with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
            log_file.write(log_str + '\n')

        log_show = [l for _, l in self.log if not dict_any(l, lambda x: np.isnan(x))]
        log_show = dict_reduce(log_show, lambda x: np.mean(x))
        log_show = dict_flatten(log_show, sep='/')
        for key, value in log_show.items():
            self.writer.add_scalar(key, value, self.step)
        self.log = []

    def check_abort(self):
        """
        Check if training should be aborted due to certain conditions.

        Currently: the fp16 'inflat_all' loss scale dropped below 2**0. The
        master process then saves a blocking checkpoint and the logs before
        every process raises ValueError.
        """
        if self.mix_precision_dtype == torch.float16 and \
            self.mix_precision_mode == 'inflat_all' and \
            self.log_scale < 0:
            if self.is_master:
                print ('\n\n\033[91m')
                print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.')
                print ('This indicates that the model is diverging. You should look into the model and the data.')
                print ('\033[0m')
                self.save(non_blocking=False)
                self.save_logs()
            if self.world_size > 1:
                dist.barrier()
            raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.')

    def run(self):
        """
        Run training until ``max_steps`` with periodic printing, DDP checks,
        snapshots, logging and checkpointing.
        """
        if self.is_master:
            print('\nStarting training...')
            self.snapshot_dataset(batch_size=self.snapshot_batch_size)
        if self.step == 0:
            self.snapshot(suffix='init', batch_size=self.snapshot_batch_size)
        else:
            self.snapshot(suffix=f'resume_step{self.step:07d}', batch_size=self.snapshot_batch_size)

        time_last_print = 0.0
        time_elapsed = 0.0
        while self.step < self.max_steps:
            time_start = time.time()

            data_list = self.load_data()
            step_log = self.run_step(data_list)

            time_end = time.time()
            time_elapsed += time_end - time_start

            self.step += 1

            # Progress (elapsed time counts training steps only).
            if self.is_master and self.step % self.i_print == 0:
                speed = self.i_print / (time_elapsed - time_last_print) * 3600
                columns = [
                    f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
                    f'Elapsed: {time_elapsed / 3600:.2f} h',
                    f'Speed: {speed:.2f} steps/h',
                    f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
                ]
                print(' | '.join([c.ljust(25) for c in columns]), flush=True)
                time_last_print = time_elapsed

            if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
                self.check_ddp()

            # Periodic snapshot, using the same batch size as the other snapshots.
            if self.step % self.i_sample == 0:
                self.snapshot(batch_size=self.snapshot_batch_size)

            if self.is_master:
                self.log.append((self.step, {}))

                self.log[-1][1]['time'] = {
                    'step': time_end - time_start,
                    'elapsed': time_elapsed,
                }

                if step_log is not None:
                    self.log[-1][1].update(step_log)

                if self.mix_precision_dtype == torch.float16:
                    if self.mix_precision_mode == 'amp':
                        self.log[-1][1]['scale'] = self.scaler.get_scale()
                    elif self.mix_precision_mode == 'inflat_all':
                        self.log[-1][1]['log_scale'] = self.log_scale

                if self.step % self.i_log == 0:
                    self.save_logs()

                if self.step % self.i_save == 0:
                    self.save()

            # Called on every rank: it may barrier and raise.
            self.check_abort()

        self.snapshot(suffix='final', batch_size=self.snapshot_batch_size)
        if self.world_size > 1:
            dist.barrier()
        if self.is_master:
            self.writer.close()
            print('Training finished.')

    def profile(self, wait=2, warmup=3, active=5):
        """
        Profile the training loop.

        Runs ``wait + warmup + active`` training steps on freshly loaded data
        under ``torch.profiler`` and writes a TensorBoard trace to
        ``<output_dir>/profile``. ``self.step`` is not advanced.
        """
        with torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
            profile_memory=True,
            with_stack=True,
        ) as prof:
            for _ in range(wait + warmup + active):
                self.run_step(self.load_data())
                prof.step()
|
|
|