import contextlib
import math
from typing import Iterable, Optional

import torch
from scipy.special import softmax
from sklearn.metrics import accuracy_score, average_precision_score
from timm.data import Mixup
from timm.utils import ModelEma, accuracy

import utils
from utils import adjust_learning_rate


def _cuda_synchronize():
    """Wait for queued CUDA work; a no-op on hosts without CUDA."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    max_norm: float = 0,
    model_ema: Optional[ModelEma] = None,
    mixup_fn: Optional[Mixup] = None,
    log_writer=None,
    args=None,
):
    """Train ``model`` for one epoch with gradient accumulation.

    Gradients are accumulated over ``args.update_freq`` micro-batches before
    each optimizer step, and the learning rate is re-computed at the start of
    every accumulation window via ``adjust_learning_rate``.

    Args:
        model: Network to train; put in train mode here.
        criterion: Loss module called as ``criterion(output, targets)``.
        data_loader: Iterable of ``(samples, targets)`` batches; ``len()`` is
            used to compute the fractional epoch for the LR schedule.
        optimizer: Optimizer whose param groups provide ``lr`` and
            ``weight_decay`` for logging.
        device: Device the batches are moved to.
        epoch: Current epoch index.
        loss_scaler: Used only when ``args.use_amp`` is set; called as
            ``loss_scaler(loss, optimizer, clip_grad=..., parameters=...,
            create_graph=..., update_grad=...)`` and its return value is
            logged as ``grad_norm``.
        max_norm: Gradient-clipping threshold. Forwarded to ``loss_scaler``
            under AMP; applied with ``clip_grad_norm_`` in full precision
            when truthy.
        model_ema: Optional EMA copy updated after every optimizer step.
        mixup_fn: Optional mixup/cutmix transform; when set, per-batch
            classification accuracy is not computed.
        log_writer: Optional logger exposing ``update(..., head=...)`` and
            ``set_step()``.
        args: Namespace providing at least ``update_freq`` and ``use_amp``;
            also passed to ``adjust_learning_rate``.

    Returns:
        Dict mapping each logged meter name to its epoch-wide global average.

    Raises:
        RuntimeError: If the loss becomes non-finite (NaN or inf).
    """
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=' ')
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    update_freq = args.update_freq
    use_amp = args.use_amp
    # This attribute is set by timm on one optimizer (AdaHessian); such an
    # optimizer needs ``create_graph=True`` in the backward pass.
    is_second_order = bool(getattr(optimizer, 'is_second_order', False))

    optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # Per-iteration (instead of per-epoch) LR schedule, evaluated at the
        # start of each accumulation window.
        if data_iter_step % update_freq == 0:
            adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        autocast_ctx = torch.cuda.amp.autocast() if use_amp else contextlib.nullcontext()
        with autocast_ctx:
            output = model(samples)
            loss = criterion(output, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print('Loss is {}, stopping training'.format(loss_value))
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise RuntimeError('Loss is {}, stopping training'.format(loss_value))

        # True when this micro-batch closes an accumulation window.
        is_update_step = (data_iter_step + 1) % update_freq == 0
        loss /= update_freq

        if use_amp:
            grad_norm = loss_scaler(
                loss,
                optimizer,
                clip_grad=max_norm,
                parameters=model.parameters(),
                create_graph=is_second_order,
                update_grad=is_update_step,
            )
            if is_update_step:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
        else:  # full precision
            loss.backward()
            if is_update_step:
                # Previously ``max_norm`` was silently ignored on this path.
                if max_norm:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                optimizer.step()
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)

        _cuda_synchronize()

        # Hard-label accuracy is only meaningful without mixup (mixup turns
        # targets into soft labels).
        class_acc = None
        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)

        group_lrs = [group['lr'] for group in optimizer.param_groups]
        min_lr = min([10.0, *group_lrs])
        max_lr = max([0.0, *group_lrs])
        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)

        # Report the weight decay of the last param group that uses one.
        weight_decay_value = next(
            (
                group['weight_decay']
                for group in reversed(optimizer.param_groups)
                if group['weight_decay'] > 0
            ),
            None,
        )
        metric_logger.update(weight_decay=weight_decay_value)
        if use_amp:
            metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head='loss')
            log_writer.update(class_acc=class_acc, head='loss')
            log_writer.update(lr=max_lr, head='opt')
            log_writer.update(min_lr=min_lr, head='opt')
            log_writer.update(weight_decay=weight_decay_value, head='opt')
            if use_amp:
                log_writer.update(grad_norm=grad_norm, head='opt')
            log_writer.set_step()

    print('Averaged stats:', metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, use_amp=False):
    """Evaluate ``model`` on ``data_loader`` as a two-class classifier.

    Each batch is expected to carry images at index 0 and integer labels at
    index -1. If the model returns a dict, its ``'logits'`` entry is used.

    Args:
        data_loader: Iterable of batches.
        model: Network to evaluate; put in eval mode here.
        device: Device the batches are moved to.
        use_amp: If True, run the forward pass under bfloat16 autocast.

    Returns:
        Tuple ``(stats, acc, ap, r_acc, f_acc)`` where ``stats`` maps each
        meter name to its global average, ``acc`` is overall accuracy at a
        0.5 threshold on the class-1 probability, ``ap`` is average precision,
        and ``r_acc`` / ``f_acc`` are accuracies restricted to samples with
        label 0 / label 1 (presumably real / fake — confirm with the dataset).
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=' ')
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    predictions = []
    labels = []
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output
        autocast_ctx = (
            torch.cuda.amp.autocast(dtype=torch.bfloat16) if use_amp else contextlib.nullcontext()
        )
        with autocast_ctx:
            output = model(images)  # [bs, num_cls]
            if isinstance(output, dict):
                output = output['logits']
            loss = criterion(output, target)

        # Keep a CPU copy of the logits. Reduced-precision floats are upcast
        # because NumPy has no bfloat16 dtype, so ``.numpy()`` on the AMP
        # logits would fail; moving to CPU also avoids holding every batch on
        # the device until the end of evaluation.
        batch_logits = output.detach().cpu()
        if batch_logits.dtype in (torch.float16, torch.bfloat16):
            batch_logits = batch_logits.float()
        predictions.append(batch_logits)
        labels.append(target.detach().cpu())

        _cuda_synchronize()

        # NOTE: with topk=(1, 2) the 'acc5' meter actually tracks top-2
        # accuracy (name kept for compatibility with existing consumers).
        acc1, acc5 = accuracy(output, target, topk=(1, 2))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    print(
        '* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'.format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )
    )

    # Concatenate predictions and labels (already on CPU).
    predictions = torch.cat(predictions, dim=0)
    labels = torch.cat(labels, dim=0)

    # Column 1 of the softmax is treated as the positive-class probability.
    y_pred = softmax(predictions.numpy(), axis=1)[:, 1]
    y_true = labels.numpy().astype(int)

    acc = accuracy_score(y_true, y_pred > 0.5)
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > 0.5)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > 0.5)
    ap = average_precision_score(y_true, y_pred)

    return (
        {k: meter.global_avg for k, meter in metric_logger.meters.items()},
        acc,
        ap,
        r_acc,
        f_acc,
    )