| | |
| |
|
| | """Evaluation utilities.""" |
| |
|
| | import os |
| | from functools import partial |
| |
|
| | import torch |
| |
|
| | from megatron import get_args |
| | from megatron import print_rank_0, print_rank_last |
| | from megatron.core import mpu |
| | from megatron.schedules import get_forward_backward_func |
| | from tasks.vision.finetune_utils import build_data_loader |
| | from tasks.vision.finetune_utils import process_batch |
| | from torchvision import datasets, transforms |
| |
|
| |
|
def accuracy_func_provider():
    """Provide function that calculates accuracies."""
    args = get_args()
    crop = (args.img_h, args.img_w)

    # Validation images are read from the second entry of --data-path.
    # Pipeline: resize -> center-crop to the model resolution -> tensor ->
    # normalize each channel into roughly [-1, 1].
    eval_transform = transforms.Compose([
        transforms.Resize(crop),
        transforms.CenterCrop(crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    eval_dataset = datasets.ImageFolder(
        root=args.data_path[1], transform=eval_transform
    )

    # Drop the trailing partial batch only under data parallelism so every
    # rank processes the same number of batches; keep ordering deterministic.
    eval_loader = build_data_loader(
        eval_dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        drop_last=(mpu.get_data_parallel_world_size() > 1),
        shuffle=False,
    )

    def metrics_func(model, epoch):
        # Evaluate `model` on the held-out set and report top-1 accuracy.
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, eval_loader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )

    return metrics_func
| |
|
| |
|
def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers.

    Runs `model` (a list of pipeline-stage modules) over `dataloader` in
    eval mode, forward-only, counting top-1 correct predictions, then
    all-reduces the counts across the data-parallel group on the last
    pipeline stage.

    Args:
        model: list of modules (one per virtual pipeline stage).
        dataloader: iterable of (images, labels) batches.
        epoch: current epoch number (unused here; kept for interface parity).

    Returns:
        (correct, total): reduced counts on the last pipeline stage; the
        local, unreduced counts on earlier stages.
    """
    forward_backward_func = get_forward_backward_func()
    for m in model:
        m.eval()

    def loss_func(labels, output_tensor):
        # Accuracy bookkeeping only; the returned "loss" of 0 is unused
        # because the schedule runs forward-only.
        logits = output_tensor
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels).float()
        loss_dict = {
            'total': labels.size(0),
            'correct': corrects.sum().item(),
        }
        return 0, loss_dict

    def correct_answers_forward_step(batch, model):
        # Depending on the pipeline schedule, `batch` is either a raw batch
        # or an iterator over batches. next() raises TypeError on a
        # non-iterator (and StopIteration when exhausted); only those two
        # mean "use the object as-is" — anything else is a real error and
        # should propagate (previously a bare BaseException hid such bugs).
        try:
            batch_ = next(batch)
        except (TypeError, StopIteration):
            batch_ = batch
        images, labels = process_batch(batch_)
        output_tensor = model(images)
        return output_tensor, partial(loss_func, labels)

    with torch.no_grad():
        total = 0
        correct = 0
        for batch in dataloader:
            loss_dicts = forward_backward_func(
                correct_answers_forward_step, batch, model,
                optimizer=None, timers=None, forward_only=True)
            # Non-last pipeline stages may receive empty loss dicts.
            for loss_dict in loss_dicts:
                total += loss_dict['total']
                correct += loss_dict['correct']

    # Restore training mode before returning.
    for m in model:
        m.train()

    # Reduce across data-parallel ranks on the last pipeline stage, where
    # the logits (and hence the counts) live.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())
        return unreduced[0].item(), unreduced[1].item()

    # Bug fix: earlier pipeline stages previously fell off the end and
    # implicitly returned None, crashing callers that unpack
    # `correct, total = ...`. Return the local counts so every stage
    # yields a valid pair.
    return correct, total
| |
|