Spaces:
Sleeping
Sleeping
| import math | |
| from typing import Iterable, Optional | |
| import torch | |
| import utils | |
| from scipy.special import softmax | |
| from sklearn.metrics import accuracy_score, average_precision_score | |
| from timm.data import Mixup | |
| from timm.utils import ModelEma, accuracy | |
| from utils import adjust_learning_rate | |
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    max_norm: float = 0,
    model_ema: Optional[ModelEma] = None,
    mixup_fn: Optional[Mixup] = None,
    log_writer=None,
    args=None,
):
    """Train ``model`` for one epoch with optional AMP, mixup and gradient accumulation.

    Gradients are accumulated over ``args.update_freq`` iterations before each
    optimizer step. The learning rate is re-set at the start of every
    accumulation window via ``adjust_learning_rate`` (per-iteration schedule).

    Args:
        model: Network to train; switched to training mode.
        criterion: Loss module, called as ``criterion(output, targets)``.
        data_loader: Iterable of ``(samples, targets)`` batches. Must support
            ``len()``, which is used for the fractional-epoch LR schedule.
        optimizer: Optimizer whose param groups carry ``'lr'`` and
            ``'weight_decay'`` entries.
        device: Device that samples and targets are moved to.
        epoch: Current epoch index (LR schedule and log header).
        loss_scaler: AMP loss scaler used only when ``args.use_amp`` is true.
            It is called with the scaled loss, ``clip_grad``, ``parameters``,
            ``create_graph`` and ``update_grad``, and its return value is
            logged as ``grad_norm``.
        max_norm: Gradient-clipping threshold forwarded to ``loss_scaler``
            (AMP path only).
        model_ema: Optional EMA model, updated after every optimizer step.
        mixup_fn: Optional batch transform applied to ``(samples, targets)``.
            When set, per-batch accuracy is not computed.
        log_writer: Optional logger exposing ``update(..., head=...)`` and
            ``set_step()``.
        args: Namespace providing at least ``update_freq`` and ``use_amp``.
            It is required despite the ``None`` default.

    Returns:
        Dict mapping each logged meter name to its global average over the epoch.

    Raises:
        AssertionError: if the loss becomes non-finite.
    """
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=' ')
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100
    update_freq = args.update_freq
    use_amp = args.use_amp
    # This attribute is added by timm on one optimizer (adahessian). It does not
    # change during the epoch, so look it up once instead of every iteration.
    is_second_order = getattr(optimizer, 'is_second_order', False)

    optimizer.zero_grad()
    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # True on the last micro-batch of each accumulation window.
        should_step = (data_iter_step + 1) % update_freq == 0

        # We use a per-iteration (instead of per-epoch) lr scheduler.
        if data_iter_step % update_freq == 0:
            adjust_learning_rate(
                optimizer, data_iter_step / len(data_loader) + epoch, args
            )

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = criterion(output, targets)
        else:  # full precision
            output = model(samples)
            loss = criterion(output, targets)

        # Record the unscaled loss before dividing for accumulation.
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print('Loss is {}, stopping training'.format(loss_value))
            assert math.isfinite(loss_value)

        # Scale so the gradient accumulated over the window is an average.
        loss /= update_freq
        if use_amp:
            # No explicit optimizer.step() on this path: the scaler presumably
            # steps when update_grad is true.
            grad_norm = loss_scaler(
                loss,
                optimizer,
                clip_grad=max_norm,
                parameters=model.parameters(),
                create_graph=is_second_order,
                update_grad=should_step,
            )
        else:
            loss.backward()
            if should_step:
                optimizer.step()
        if should_step:
            optimizer.zero_grad()
            if model_ema is not None:
                model_ema.update(model)

        # torch.cuda.synchronize() raises when CUDA is unavailable, which
        # would break CPU training.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            # Mixed targets are presumably soft labels, so hard-label accuracy
            # is not meaningful here.
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)

        min_lr, max_lr, weight_decay_value = _param_group_lr_and_wd(optimizer)
        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        metric_logger.update(weight_decay=weight_decay_value)
        if use_amp:
            metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head='loss')
            log_writer.update(class_acc=class_acc, head='loss')
            log_writer.update(lr=max_lr, head='opt')
            log_writer.update(min_lr=min_lr, head='opt')
            log_writer.update(weight_decay=weight_decay_value, head='opt')
            if use_amp:
                log_writer.update(grad_norm=grad_norm, head='opt')
            log_writer.set_step()

    print('Averaged stats:', metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def _param_group_lr_and_wd(optimizer):
    """Return ``(min_lr, max_lr, weight_decay)`` over ``optimizer.param_groups``.

    ``min_lr`` starts at 10.0 and ``max_lr`` at 0.0, so an optimizer without
    param groups yields those values. ``weight_decay`` is taken from the last
    group with a positive weight decay, or is ``None`` if no group has one.
    """
    min_lr = 10.0
    max_lr = 0.0
    weight_decay = None
    for group in optimizer.param_groups:
        min_lr = min(min_lr, group['lr'])
        max_lr = max(max_lr, group['lr'])
        if group['weight_decay'] > 0:
            weight_decay = group['weight_decay']
    return min_lr, max_lr, weight_decay
@torch.no_grad()
def evaluate(data_loader, model, device, use_amp=False):
    """Evaluate ``model`` and compute binary-classification metrics.

    Runs under ``torch.no_grad()``. Without it, every batch's autograd graph
    would be built and kept alive by the stored logits.

    Args:
        data_loader: Iterable of batches whose first element is the image
            tensor and whose last element is the integer label tensor.
        model: Classifier returning logits, or a dict with a ``'logits'`` key.
            Column 1 of the logits is treated as the positive class.
        device: Device that images and labels are moved to.
        use_amp: If true, run the forward pass under bfloat16 autocast.

    Returns:
        Tuple ``(stats, acc, ap, r_acc, f_acc)``:

        - ``stats``: dict of meter global averages (``loss``, ``acc1``, ``acc5``).
        - ``acc``: overall accuracy, thresholding the class-1 softmax probability at 0.5.
        - ``ap``: average precision of the class-1 probability.
        - ``r_acc``: accuracy restricted to samples labelled 0.
        - ``f_acc``: accuracy restricted to samples labelled 1.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=' ')
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    predictions = []
    labels = []
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output
        if use_amp:
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                output = model(images)
                if isinstance(output, dict):
                    output = output['logits']
                loss = criterion(output, target)
        else:
            output = model(images)  # [bs, num_cls]
            if isinstance(output, dict):
                output = output['logits']
            loss = criterion(output, target)

        # Store float32 CPU copies for two reasons: numpy cannot represent the
        # bfloat16 logits produced under autocast, and keeping every batch on
        # the device would grow device memory with the dataset size.
        predictions.append(output.detach().float().cpu())
        labels.append(target.detach().cpu())

        # torch.cuda.synchronize() raises when CUDA is unavailable.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # NOTE(review): the second meter is top-2 accuracy (the head is
        # presumably two-class), but it keeps the historical 'acc5' key.
        acc1, acc5 = accuracy(output, target, topk=(1, 2))
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    print(
        '* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'.format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )
    )

    # Concatenate predictions and labels
    predictions = torch.cat(predictions, dim=0)
    labels = torch.cat(labels, dim=0)
    # Softmax probability of class 1 for every sample.
    y_pred = softmax(predictions.numpy(), axis=1)[:, 1]
    y_true = labels.numpy().astype(int)

    acc = accuracy_score(y_true, y_pred > 0.5)
    # Per-class accuracies; label 0 / label 1 are presumably real / fake.
    is_class0 = y_true == 0
    is_class1 = y_true == 1
    r_acc = accuracy_score(y_true[is_class0], y_pred[is_class0] > 0.5)
    f_acc = accuracy_score(y_true[is_class1], y_pred[is_class1] > 0.5)
    ap = average_precision_score(y_true, y_pred)
    return (
        {k: meter.global_avg for k, meter in metric_logger.meters.items()},
        acc,
        ap,
        r_acc,
        f_acc,
    )