| |
|
|
| """Evaluation utilities.""" |
|
|
| import os |
| import time |
| from functools import partial |
|
|
| import torch |
|
|
| from megatron.training import get_args |
| from megatron.training import print_rank_last, is_last_rank |
| from megatron.core import mpu |
| from megatron.schedules import get_forward_backward_func |
| from tasks.finetune_utils import build_data_loader |
| from tasks.finetune_utils import process_batch |
|
|
|
|
def accuracy_func_provider(single_dataset_provider):
    """Provide a function that calculates accuracies over the validation sets.

    One data loader is built for every path in ``args.valid_data``, using
    ``single_dataset_provider(datapath)`` to construct each dataset.

    Args:
        single_dataset_provider: callable taking a data path and returning a
            dataset object that exposes a ``dataset_name`` attribute.

    Returns:
        ``metrics_func(model, epoch, output_predictions=False)``. It evaluates
        ``model`` on every validation loader and prints the overall accuracy
        on the last rank. When ``output_predictions`` is true, it also saves
        the collected predictions on the last rank to
        ``<args.load>/predictions_<name1>_<name2>....pt``.
    """
    args = get_args()

    # Build one data loader per validation path. Incomplete trailing batches
    # are dropped only when running with more than one data parallel rank.
    drop_last = mpu.get_data_parallel_world_size() > 1
    dataloaders = []
    for datapath in args.valid_data:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=drop_last)
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        """Evaluate ``model`` on every validation loader and report accuracy.

        Args:
            model: model (list of chunks) forwarded to
                ``calculate_correct_answers``.
            epoch: epoch number, used only in the printed report.
            output_predictions: when true, collect per-dataset predictions and
                save them with ``torch.save`` on the last rank. Requires a
                data parallel world size of 1 and ``args.load`` to be set.
        """
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        # Bound unconditionally so both names always exist; they are only
        # used when output_predictions is true.
        named_predictions = []
        names = 'predictions'
        if output_predictions:
            # Predictions are gathered locally, so only a single data
            # parallel rank is supported.
            assert mpu.get_data_parallel_world_size() == 1
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if output_predictions:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            else:
                correct_ans, total_count = output
            correct += correct_ans
            total += total_count
        if is_last_rank():
            # Guard against empty validation data (no paths, or every sample
            # dropped by drop_last) instead of dividing by zero.
            percent = float(correct) * 100.0 / float(total) if total else 0.0
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func
|
|
|
|
def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return predictions if
    ``output_predictions`` is true.

    Every module in ``model`` is switched to eval mode for the evaluation
    and back to train mode afterwards. ``args.micro_batch_size`` and
    ``args.global_batch_size`` are temporarily overridden to match each
    (possibly partial) batch. Both the module modes and the batch sizes are
    restored even if evaluation raises.

    Args:
        name: dataset name, used only in the printed report.
        model: iterable of model chunks handed to the forward/backward
            schedule.
        dataloader: loader yielding batches indexable by ``'label'``. Its
            ``dataset`` may define ``sample_multiplier``.
        epoch: epoch number, used only in the printed report.
        output_predictions: whether to also collect softmaxes, labels and ids.
            NOTE(review): the loss function currently asserts that this is
            false, so passing True raises AssertionError on the last stage.

    Returns:
        On the last pipeline stage, ``(correct, total)`` summed over the data
        parallel group, with ``(softmaxes, labels, ids)`` appended when
        ``output_predictions`` is true. On other pipeline stages, ``(0, 0)``
        or ``(0, 0, ())``.
    """
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()

    # A dataset "sample" may expand into several rows of the batch dimension
    # (e.g. one row per answer option), so the micro batch size is scaled by
    # the dataset's sample_multiplier when it defines one.
    sample_multiplier = getattr(dataloader.dataset, 'sample_multiplier', 1)
    micro_batch_size_times_data_parallel = (
        args.orig_micro_batch_size * args.data_parallel_size)
    num_micro_batches = (
        args.orig_global_batch_size // micro_batch_size_times_data_parallel)

    def loss_func(output_predictions, labels, micro_batch, output_tensor):
        """Count correct predictions for one micro batch; the loss is 0."""
        logits = output_tensor
        loss_dict = {}
        if output_predictions:
            # Prediction output is currently disabled.
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            # Take ids from the micro batch that was actually processed, not
            # from a variable in the enclosing scope.
            loss_dict['ids'] = micro_batch['uid'].cpu().numpy().tolist()
        predicted = torch.argmax(logits, dim=-1)
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = (predicted == labels).sum().item()
        return 0, loss_dict

    # Defined inside to capture output_predictions.
    def correct_answers_forward_step(batch, model):
        """Run the model on one batch and bind the matching loss function."""
        # The schedule passes our batch through as its "data iterator". A
        # plain batch dict is not an iterator, so next() raises TypeError and
        # the batch is used directly.
        try:
            batch_ = next(batch)
        except TypeError:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
        return output_tensor, partial(loss_func, output_predictions, labels,
                                      batch_)

    if output_predictions:
        # Predictions are collected locally, so this option is only possible
        # when the data parallel size is 1.
        assert mpu.get_data_parallel_world_size() == 1

    total = 0
    correct = 0
    softmaxes = []
    all_labels = []
    ids = []

    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size
    for m in model:
        m.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                # With drop_last=False the final batch may be partial, so size
                # the micro/global batch from the data actually present.
                actual_batch_size = len(batch['label'])
                args.micro_batch_size = actual_batch_size * sample_multiplier
                args.global_batch_size = (
                    args.micro_batch_size * num_micro_batches)

                loss_dicts = forward_backward_func(
                    correct_answers_forward_step, batch, model,
                    optimizer=None, timers=None, forward_only=True)

                for loss_dict in loss_dicts:
                    if output_predictions:
                        softmaxes.extend(loss_dict['softmaxes'])
                        all_labels.extend(loss_dict['labels'])
                        ids.extend(loss_dict['ids'])
                    total += loss_dict['total']
                    correct += loss_dict['correct']
    finally:
        # Restore training state even if evaluation fails part-way through.
        for m in model:
            m.train()
        args.micro_batch_size = saved_micro_batch_size
        args.global_batch_size = saved_global_batch_size

    if not mpu.is_pipeline_last_stage():
        return (0, 0, ()) if output_predictions else (0, 0)

    # Sum the counters over the data parallel group.
    unreduced = torch.tensor([correct, total], dtype=torch.long, device='cuda')
    torch.distributed.all_reduce(unreduced,
                                 group=mpu.get_data_parallel_group())

    correct_ans = unreduced[0].item()
    total_count = unreduced[1].item()
    # An empty loader yields total_count == 0; report 0 % instead of raising.
    if total_count:
        percent = float(correct_ans) * 100.0 / float(total_count)
    else:
        percent = 0.0
    elapsed_time = time.time() - start_time
    print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                    '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                        epoch, name, correct_ans, total_count,
                        percent, elapsed_time))

    if output_predictions:
        return correct_ans, total_count, (softmaxes, all_labels, ids)
    return correct_ans, total_count
|
|