| """ |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from ..misc import (MetricLogger, SmoothedValue, reduce_dict) |
|
|
|
|
def train_one_epoch(model: nn.Module, criterion: nn.Module, dataloader, optimizer, ema, epoch, device):
    """Train ``model`` for a single pass over ``dataloader``.

    For every batch, the images and labels are moved to ``device`` and the loss
    is computed as ``criterion(preds, labels, epoch)``. One optimizer step is
    then taken. If ``ema`` is not ``None``, ``ema.update(model)`` is called
    after each step.

    Args:
        model: network to optimize; put into train mode on entry.
        criterion: loss module, called with ``(preds, labels, epoch)``.
        dataloader: iterable yielding ``(imgs, labels)`` pairs.
        optimizer: optimizer driving ``model``'s parameters. The learning rate
            of its first param group is logged every batch.
        ema: optional object exposing ``update(model)``; skipped when ``None``.
        epoch: current epoch index, used in the log header and forwarded to
            ``criterion``.
        device: device each batch is moved to.

    Returns:
        dict mapping each logged meter name (``'lr'``, ``'loss'``) to its
        ``global_avg`` after ``synchronize_between_processes()``.
    """
    model.train()

    logger = MetricLogger(delimiter=" ")
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    log_interval = 100  # passed to log_every as its print frequency
    header = f'Epoch: [{epoch}]'

    batches = logger.log_every(dataloader, log_interval, header)
    for batch_imgs, batch_labels in batches:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_imgs)
        loss: torch.Tensor = criterion(outputs, batch_labels, epoch)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if ema is not None:
            ema.update(model)

        # NOTE(review): reduce_dict presumably averages across distributed
        # ranks; confirm in ..misc.
        reduced = reduce_dict({'loss': loss})
        logger.update(**{name: value.item() for name, value in reduced.items()})
        current_lr = optimizer.param_groups[0]["lr"]
        logger.update(lr=current_lr)

    logger.synchronize_between_processes()
    print("Averaged stats:", logger)

    return dict((name, meter.global_avg) for name, meter in logger.meters.items())
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and return averaged metrics.

    Top-1 accuracy is the fraction of samples whose ``argmax`` over the last
    dimension of the predictions equals the label. The loss is
    ``criterion(preds, labels)``. Both values go through ``reduce_dict``.
    They are accumulated into their meters weighted by batch size, so the
    returned ``global_avg`` values are per-sample means. Previously each batch
    had weight 1, which let a smaller final batch skew the result.

    Args:
        model: classifier; put into eval mode on entry. ``torch.no_grad``
            disables gradient tracking for the whole call.
        criterion: loss callable invoked as ``criterion(preds, labels)``.
        dataloader: iterable yielding ``(imgs, labels)`` pairs. ``labels`` are
            assumed to be integer class indices (TODO confirm; one-hot labels
            would break the accuracy comparison).
        device: device each batch is moved to.

    Returns:
        dict mapping meter name (``'acc'``, ``'loss'``) to its ``global_avg``
        after ``synchronize_between_processes()``.
    """
    model.eval()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('acc', SmoothedValue(window_size=1))
    metric_logger.add_meter('loss', SmoothedValue(window_size=1))

    header = 'Test:'
    for imgs, labels in metric_logger.log_every(dataloader, 10, header):
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)

        batch_size = preds.shape[0]
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        loss = criterion(preds, labels)

        dict_reduced = reduce_dict({'acc': acc, 'loss': loss})
        # Weight each batch by its sample count so global_avg is a per-sample
        # mean rather than a mean of per-batch means.
        # NOTE(review): relies on the DETR-style SmoothedValue.update(value, n)
        # signature; confirm in ..misc.
        for name, value in dict_reduced.items():
            metric_logger.meters[name].update(value.item(), n=batch_size)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return stats
|
|