# TheKernel01's picture
# Sync from GitHub via hub-sync
# 0788e19 verified
import math
from typing import Iterable, Optional
import torch
import utils
from scipy.special import softmax
from sklearn.metrics import accuracy_score, average_precision_score
from timm.data import Mixup
from timm.utils import ModelEma, accuracy
from utils import adjust_learning_rate
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    max_norm: float = 0,
    model_ema: Optional[ModelEma] = None,
    mixup_fn: Optional[Mixup] = None,
    log_writer=None,
    args=None,
):
    """Train ``model`` for one epoch with optional AMP, gradient accumulation and EMA.

    Args:
        model: network to train (switched to train mode).
        criterion: loss module called as ``criterion(output, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches; must support ``len``.
        optimizer: optimizer whose param groups are re-scheduled every update step.
        device: device the batches are moved to.
        epoch: current epoch index (used for the per-iteration LR schedule).
        loss_scaler: callable used on the AMP path to backprop, optionally clip
            (``clip_grad=max_norm``) and step; its return value is logged as ``grad_norm``.
        max_norm: gradient-clipping threshold forwarded to ``loss_scaler``.
            NOTE(review): it is not applied on the full-precision path.
        model_ema: optional EMA model updated after every optimizer step.
        mixup_fn: optional mixup/cutmix transform; when set, hard-label
            ``class_acc`` is not computed.
        log_writer: optional logger exposing ``update(..., head=...)`` and ``set_step()``.
        args: namespace providing ``update_freq`` and ``use_amp``; also passed to
            ``adjust_learning_rate``. Required despite the ``None`` default.

    Returns:
        dict mapping every logged metric name to its epoch-wide global average.

    Raises:
        AssertionError: if the loss becomes non-finite.
    """
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=' ')
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100
    update_freq = args.update_freq
    use_amp = args.use_amp
    # timm sets this attribute on second-order optimizers (e.g. AdaHessian),
    # which need create_graph=True during backward. Loop-invariant, so hoisted.
    is_second_order = getattr(optimizer, 'is_second_order', False)

    optimizer.zero_grad()
    for data_iter_step, (samples, targets) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # Per-iteration (instead of per-epoch) LR schedule, applied once per
        # accumulation window.
        if data_iter_step % update_freq == 0:
            adjust_learning_rate(
                optimizer, data_iter_step / len(data_loader) + epoch, args
            )

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = criterion(output, targets)
        else:  # full precision
            output = model(samples)
            loss = criterion(output, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print('Loss is {}, stopping training'.format(loss_value))
            assert math.isfinite(loss_value)

        # Divide so gradients accumulated over `update_freq` micro-batches average out.
        loss = loss / update_freq
        is_update_step = (data_iter_step + 1) % update_freq == 0
        if use_amp:
            grad_norm = loss_scaler(
                loss,
                optimizer,
                clip_grad=max_norm,
                parameters=model.parameters(),
                create_graph=is_second_order,
                update_grad=is_update_step,
            )
            if is_update_step:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
        else:  # full precision
            loss.backward()
            if is_update_step:
                optimizer.step()
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)

        # synchronize() raises on hosts without CUDA, so only call it when available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Hard-label accuracy is only defined when mixup has not turned targets
        # into soft labels.
        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            class_acc = None

        min_lr, max_lr, weight_decay_value = _optimizer_lr_wd_stats(optimizer)

        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        metric_logger.update(weight_decay=weight_decay_value)
        if use_amp:
            metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head='loss')
            log_writer.update(class_acc=class_acc, head='loss')
            log_writer.update(lr=max_lr, head='opt')
            log_writer.update(min_lr=min_lr, head='opt')
            log_writer.update(weight_decay=weight_decay_value, head='opt')
            if use_amp:
                log_writer.update(grad_norm=grad_norm, head='opt')
            log_writer.set_step()

    print('Averaged stats:', metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def _optimizer_lr_wd_stats(optimizer):
    """Return ``(min_lr, max_lr, weight_decay)`` over ``optimizer.param_groups``.

    ``weight_decay`` is the value of the last param group with a positive weight
    decay, or ``None`` if no group uses weight decay. The LR range is taken over
    the real group values (no sentinel seeds), so LRs above 10 are reported correctly.
    """
    lrs = [group['lr'] for group in optimizer.param_groups]
    weight_decay_value = None
    for group in optimizer.param_groups:
        if group['weight_decay'] > 0:
            weight_decay_value = group['weight_decay']
    return min(lrs), max(lrs), weight_decay_value
@torch.no_grad()
def evaluate(data_loader, model, device, use_amp=False):
    """Evaluate a two-class classifier and compute threshold/ranking metrics.

    Each batch is indexed as ``batch[0]`` (inputs) and ``batch[-1]`` (integer
    labels). If the model returns a dict, its ``'logits'`` entry is used.

    Args:
        data_loader: iterable of batches as described above.
        model: network to evaluate (switched to eval mode).
        device: device the batches are moved to.
        use_amp: run the forward pass under CUDA autocast with bfloat16.

    Returns:
        Tuple ``(stats, acc, ap, r_acc, f_acc)``:
            stats: global averages of the ``'loss'``, ``'acc1'`` and ``'acc5'``
                meters. NOTE: ``'acc5'`` holds top-2 accuracy (``topk=(1, 2)``).
                The key is kept for backward compatibility.
            acc: accuracy of ``P(class 1) > 0.5`` over all samples.
            ap: average precision of ``P(class 1)`` against the labels.
            r_acc: accuracy on samples labelled 0 (presumably "real").
            f_acc: accuracy on samples labelled 1 (presumably "fake").
    """
    import contextlib

    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=' ')
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    predictions = []
    labels = []
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output (optionally under bfloat16 autocast)
        autocast_ctx = (
            torch.cuda.amp.autocast(dtype=torch.bfloat16)
            if use_amp
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            output = model(images)  # [bs, num_cls]
            if isinstance(output, dict):
                output = output['logits']
            loss = criterion(output, target)

        predictions.append(output)
        labels.append(target)

        # synchronize() raises on hosts without CUDA, so only call it when available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # topk=(1, 2) because the task has only two classes; the second value is
        # logged under the historical 'acc5' key.
        acc1, acc5 = accuracy(output, target, topk=(1, 2))
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    print(
        '* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'.format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )
    )

    # Concatenate predictions and labels
    logits = torch.cat(predictions, dim=0).detach().cpu()
    # NumPy has no bfloat16 dtype, so bf16 logits from the AMP path would make
    # .numpy() raise; upcast only in that case to leave other dtypes untouched.
    if logits.dtype == torch.bfloat16:
        logits = logits.float()
    y_pred = softmax(logits.numpy(), axis=1)[:, 1]  # probability of class 1
    y_true = torch.cat(labels, dim=0).detach().cpu().numpy().astype(int)

    pred_positive = y_pred > 0.5
    is_neg = y_true == 0
    is_pos = y_true == 1
    acc = accuracy_score(y_true, pred_positive)
    r_acc = accuracy_score(y_true[is_neg], pred_positive[is_neg])
    f_acc = accuracy_score(y_true[is_pos], pred_positive[is_pos])
    ap = average_precision_score(y_true, y_pred)
    return (
        {k: meter.global_avg for k, meter in metric_logger.meters.items()},
        acc,
        ap,
        r_acc,
        f_acc,
    )