| """Finetune utilities.""" |
|
|
| from functools import partial |
| import sys |
| import torch |
|
|
| from megatron import get_args, get_num_microbatches |
| from megatron import print_rank_0 |
| from megatron import get_timers |
| from megatron import mpu |
| from megatron.checkpointing import load_checkpoint |
| from megatron.checkpointing import save_checkpoint |
| from megatron.model import ModelType |
| from megatron.training import evaluate_and_print_results |
| from megatron.training import setup_model_and_optimizer |
| from megatron.training import train_step |
| from megatron.training import training_log |
| from megatron.utils import average_losses_across_data_parallel_group |
| from megatron.utils import calc_params_l2_norm |
| from megatron.utils import check_adlr_autoresume_termination |
|
|
|
|


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()

    tokens = batch['text'].long().cuda().contiguous()
    types = batch['types'].long().cuda().contiguous()
    labels = batch['label'].long().cuda().contiguous()
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, labels, attention_mask
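
# Note: process_batch assumes (from the keys it reads) that task datasets
# yield dicts roughly of the following form, with b = batch size and
# s = sequence length; exact shapes depend on the task dataset:
#   {'text': LongTensor[b, s], 'types': LongTensor[b, s],
#    'label': LongTensor[b], 'padding_mask': Tensor[b, s]}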


def cross_entropy_loss_func(labels, output_tensor):
    logits = output_tensor

    # Cross-entropy loss.
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'training loss': averaged_loss[0]}


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch. `batch` may be an iterator (when called from the
    # training loop) or an already-materialized batch dict; handle both.
    timers('batch-generator').start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

    return output_tensor, partial(cross_entropy_loss_func, labels)
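
# The second return value is a closure over `labels`: the training loop is
# expected to call it as `loss_func(output_tensor)`, which resolves to
# cross_entropy_loss_func(labels, output_tensor) above.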


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                      task_collate_fn=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler: shard the dataset across the data-parallel group.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)

    # Data loader. Note that batch size is the per-GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=micro_batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=task_collate_fn)

    return data_loader
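
# Example (hypothetical values): a task-level script might build a loader as
#   loader = build_data_loader(dataset, micro_batch_size=8, num_workers=2,
#                              drop_last=True)
# Each data-parallel rank then sees a disjoint shard of `dataset`.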


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            # Restart from the beginning once the underlying loader is
            # exhausted, so callers can draw batches forever.
            iterator = iter(dataloader)
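
# Usage note: the generator above never raises StopIteration; e.g.
#   valid_iter = _build_infinite_size_dataloader(valid_dataloader_)
#   batch = next(valid_iter)  # restarts the loader when exhausted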


def _build_train_valid_dataloaders(train_dataset, valid_dataset,
                                   task_collate_fn=None):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, not args.keep_last,
                                         task_collate_fn)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling, so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last,
                                          task_collate_fn)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set the batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are,
    # and so the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        # If the dataset has a sample_multiplier attribute, each "sample"
        # from the dataset actually contains multiple samples that will
        # collapse into the batch dimension (for example, a multiple-choice
        # dataset with several options per question), so we need to account
        # for that when setting the micro batch size.
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier

    return train_dataloader, valid_dataloader
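
# Worked example (illustrative): with micro_batch_size == 8 and a dataset
# whose sample_multiplier is 4 (say, four answer options per question), the
# model actually sees 8 * 4 == 32 rows per micro batch.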


def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch ...
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset ...
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before the starting value.
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer,
                             opt_param_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration,
                                              optimizer.get_loss_scale().item(),
                                              report_memory_flag, skipped_iter,
                                              grad_norm, params_norm,
                                              num_zeros_in_grad, None)

            # Autoresume termination check.
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer,
                                                  opt_param_scheduler)

            # Checkpointing at the save interval.
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)
                saved_checkpoint = True

            # Evaluation (and callback) at the eval interval.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, None, False)
                if end_of_epoch_callback is not None:
                    end_of_epoch_callback(model, iteration)
                print_rank_0('-' * 72 + '\n')

            # Exiting based on iterations.
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer,
                                    opt_param_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(
                    iteration))
                sys.exit()

        # Checkpointing, evaluation, and callback at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        prefix = 'iteration {}'.format(iteration)
        evaluate_and_print_results(prefix, forward_step,
                                   valid_dataloader, model,
                                   iteration, None, False)
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, iteration)
        print_rank_0('-' * 72 + '\n')
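
# Resume arithmetic (illustrative): with args.train_iters_per_epoch == 1000
# and args.iteration == 2500, a resumed job computes start_epoch == 2 and
# start_iteration == 500, so the first 500 batches of epoch 2 are skipped
# because they were already seen before the checkpoint.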


def finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_or_decoder,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             task_collate_fn=None):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    assert args.rampup_batch_size is None, \
        'batch size scaling is not supported for finetuning'

    # Train and validation dataloaders.
    timers('train/valid/test dataset/dataloader').start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, task_collate_fn)
    else:
        args.train_iters = 0
    timers('train/valid/test dataset/dataloader').stop()

    # Build the end-of-epoch callback function.
    timers('callback function').start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers('callback function').stop()

    # Build the model, optimizer, and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider, model_type)
    timers('model and optimizer').stop()

    # If a pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint (weights only, no optimizer state or RNG state).
    timers('pretrained checkpoint').start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        _ = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only the model is loaded: make sure the
        # optimizer's copy of the parameters is updated as well.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloader', 'callback function',
                'model and optimizer', 'pretrained checkpoint'])
    print_rank_0('training ...')

    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, opt_param_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    else:
        # Evaluation-only mode (zero epochs) is not implemented.
        raise NotImplementedError(
            'finetuning with --epochs == 0 (evaluation only) is not '
            'implemented')

    print_rank_0('done :-)')
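
# Typical usage from a task's main script (sketch; the provider bodies are
# hypothetical, only the finetune() signature comes from this file):
#
#   def train_valid_datasets_provider():
#       return train_dataset, valid_dataset
#
#   def model_provider(pre_process=True, post_process=True):
#       return build_classification_model(pre_process, post_process)
#
#   finetune(train_valid_datasets_provider, model_provider,
#            forward_step=_cross_entropy_forward_step)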