import argparse
import datetime
import os
import random
import subprocess
import time
from contextlib import suppress

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from config import get_config
from dataset import build_loader
from logger import create_logger
from lr_scheduler import build_scheduler
from models import build_model
from optimizer import build_optimizer
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import ApexScaler, AverageMeter, ModelEma, accuracy
from utils import MyAverageMeter
from utils import NativeScalerWithGradNormCount as NativeScaler
from utils import (auto_resume_helper, get_grad_norm, load_checkpoint,
                   load_ema_checkpoint, load_pretrained, reduce_tensor,
                   save_checkpoint)

try:
    from apex import amp

    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])


def obsolete_torch_version(torch_version, version_threshold):
    return torch_version == 'parrots' or torch_version <= version_threshold
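

# NOTE: obsolete_torch_version(TORCH_VERSION, (1, 9)) is used below to gate
# autocast(dtype=...): the dtype argument is only passed on newer PyTorch
# releases, while older ones fall back to the default (float16) autocast.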


def parse_option():
    parser = argparse.ArgumentParser(
        'InternVL training and evaluation script', add_help=False)
    parser.add_argument('--cfg',
                        type=str,
                        required=True,
                        metavar='FILE',
                        help='path to config file')
    parser.add_argument(
        '--opts',
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+')

    parser.add_argument('--batch-size',
                        type=int,
                        help='batch size for single GPU')
    parser.add_argument('--dataset',
                        type=str,
                        help='dataset name',
                        default=None)
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip',
                        action='store_true',
                        help='use zipped dataset instead of folder dataset')
    parser.add_argument(
        '--cache-mode',
        type=str,
        default='part',
        choices=['no', 'full', 'part'],
        help='no: no cache, '
        'full: cache all data, '
        'part: shard the dataset into non-overlapping pieces and cache only one piece'
    )
    parser.add_argument(
        '--pretrained',
        help='path to pretrained checkpoint, e.g. an ImageNet-22K pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps',
                        type=int,
                        default=1,
                        help='gradient accumulation steps')
    parser.add_argument(
        '--use-checkpoint',
        action='store_true',
        help='whether to use gradient checkpointing to save memory')
    parser.add_argument(
        '--amp-opt-level',
        type=str,
        default='O1',
        choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level; if O0, no amp is used')
    parser.add_argument(
        '--output',
        default='work_dirs',
        type=str,
        metavar='PATH',
        help='root of output folder, the full path is <output>/<model_name>/<tag> '
        '(default: work_dirs)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval',
                        action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--throughput',
                        action='store_true',
                        help='Test throughput only')
    parser.add_argument('--save-ckpt-num', default=1, type=int)
    parser.add_argument(
        '--use-zero',
        action='store_true',
        help='whether to use ZeroRedundancyOptimizer (ZeRO) to save memory')

    parser.add_argument('--local-rank',
                        type=int,
                        required=True,
                        help='local rank for DistributedDataParallel')
    parser.add_argument('--launcher',
                        choices=['pytorch', 'slurm'],
                        default='pytorch')
    args, unparsed = parser.parse_known_args()
    config = get_config(args)

    return args, config
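

# Illustrative launch command (the config path is hypothetical); the launcher
# is expected to supply --local-rank for DistributedDataParallel:
#   python -m torch.distributed.launch --nproc_per_node 8 main.py \
#       --cfg configs/example.yaml --batch-size 32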


@torch.no_grad()
def throughput(data_loader, model, logger):
    model.eval()

    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        # warm up the GPU for 50 untimed iterations
        for i in range(50):
            model(images)
        torch.cuda.synchronize()
        logger.info('throughput averaged over 30 iterations')
        tic1 = time.time()
        for i in range(30):
            model(images)
        torch.cuda.synchronize()
        tic2 = time.time()
        logger.info(
            f'batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}'
        )
        # only the first batch is measured
        return
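

# NOTE: main() and the train/validate functions below rely on the
# module-level `logger` created in the __main__ block at the bottom.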


def main(config):
    dataset_train, dataset_val, dataset_test, data_loader_train, \
        data_loader_val, data_loader_test, mixup_fn = build_loader(config)

    logger.info(f'Creating model: {config.MODEL.TYPE}/{config.MODEL.NAME}')
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)

    # resolve the AMP backend: prefer native torch.cuda.amp, fall back to apex
    if config.AMP_OPT_LEVEL != 'O0':
        config.defrost()
        if has_native_amp:
            config.native_amp = True
            use_amp = 'native'
        elif has_apex:
            config.apex_amp = True
            use_amp = 'apex'
        else:
            use_amp = None
            logger.warning(
                'Neither APEX nor native Torch AMP is available, using float32. '
                'Install NVIDIA apex or upgrade to PyTorch 1.6+')
        config.freeze()

    amp_autocast = suppress
    loss_scaler = None
    if config.AMP_OPT_LEVEL != 'O0':
        if use_amp == 'apex':
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=config.AMP_OPT_LEVEL)
            loss_scaler = ApexScaler()
            if config.LOCAL_RANK == 0:
                logger.info(
                    'Using NVIDIA APEX AMP. Training in mixed precision.')
        if use_amp == 'native':
            amp_autocast = torch.cuda.amp.autocast
            loss_scaler = NativeScaler()
            if config.LOCAL_RANK == 0:
                logger.info(
                    'Using native Torch AMP. Training in mixed precision.')
    else:
        if config.LOCAL_RANK == 0:
            logger.info('AMP not enabled. Training in float32.')

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
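
    # broadcast_buffers=False skips re-broadcasting module buffers on every
    # forward pass, which saves communication; it assumes buffers such as BN
    # running statistics do not need to be kept in sync across ranks.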

    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info(f'number of params: {n_parameters}')
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f'number of GFLOPs: {flops / 1e9}')

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) \
        if not config.EVAL_MODE else None

    # mixup produces soft targets, so it needs the soft-target loss; label
    # smoothing alone uses the smoothed loss; otherwise plain cross-entropy
    if config.AUG.MIXUP > 0.:
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0
    max_ema_accuracy = 0.0

    if config.MODEL.RESUME == '' and config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f'auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}'
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, loss_scaler, logger)

        if data_loader_val is not None:
            if config.DATA.DATASET == 'imagenet-real':
                filenames = dataset_val.filenames()
                filenames = [os.path.basename(item) for item in filenames]
                from dataset.imagenet_real import RealLabelsImagenet
                real_labels = RealLabelsImagenet(filenames, real_json='meta_data/real.json')
                acc1, acc5, loss = validate_real(config, data_loader_val, model, real_labels, amp_autocast=amp_autocast)
                logger.info(
                    f'ReaL Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
                )
            else:
                acc1, acc5, loss = validate(config, data_loader_val, model, amp_autocast=amp_autocast)
                logger.info(
                    f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
                )
    elif config.MODEL.PRETRAINED:
        load_pretrained(config, model_without_ddp, logger)
        if data_loader_val is not None:
            acc1, acc5, loss = validate(config, data_loader_val, model, amp_autocast=amp_autocast)
            logger.info(
                f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
            )

    model_ema = None
    if config.TRAIN.EMA.ENABLE:
        model_ema = ModelEma(model, decay=config.TRAIN.EMA.DECAY)
        logger.info(f'Using EMA with decay = {config.TRAIN.EMA.DECAY:.8f}')
        if config.MODEL.RESUME:
            load_ema_checkpoint(config, model_ema, logger)
            if config.DATA.DATASET == 'imagenet-real':
                assert dist.get_world_size() == 1, 'imagenet-real should be tested with one gpu'
                filenames = dataset_val.filenames()
                filenames = [os.path.basename(item) for item in filenames]
                from dataset.imagenet_real import RealLabelsImagenet
                real_labels = RealLabelsImagenet(filenames, real_json='meta_data/real.json')
                acc1, acc5, loss = validate_real(config, data_loader_val, model_ema.ema, real_labels,
                                                 amp_autocast=amp_autocast)
                logger.info(
                    f'ReaL Accuracy of the ema network on the {len(dataset_val)} test images: {acc1:.1f}%'
                )
            else:
                acc1, acc5, loss = validate(config, data_loader_val, model_ema.ema, amp_autocast=amp_autocast)
                logger.info(
                    f'Accuracy of the ema network on the {len(dataset_val)} test images: {acc1:.1f}%'
                )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)

    if config.EVAL_MODE:
        return
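
    # Training loop: one pass per epoch; checkpoints are written every
    # SAVE_FREQ epochs (and at the final epoch), and the best raw and best
    # EMA checkpoints are tracked separately below.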

    logger.info('Start training')
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config,
                        model,
                        criterion,
                        data_loader_train,
                        optimizer,
                        epoch,
                        mixup_fn,
                        lr_scheduler,
                        amp_autocast,
                        loss_scaler,
                        model_ema=model_ema)
        if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)) and config.TRAIN.OPTIMIZER.USE_ZERO:
            optimizer.consolidate_state_dict(to=0)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config,
                            epoch,
                            model_without_ddp,
                            max_accuracy,
                            optimizer,
                            lr_scheduler,
                            loss_scaler,
                            logger,
                            model_ema=model_ema)
        if data_loader_val is not None and epoch % config.EVAL_FREQ == 0:
            acc1, acc5, loss = validate(config, data_loader_val, model, epoch, amp_autocast=amp_autocast)
            logger.info(
                f'Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%'
            )
            if dist.get_rank() == 0 and acc1 > max_accuracy:
                save_checkpoint(config,
                                epoch,
                                model_without_ddp,
                                max_accuracy,
                                optimizer,
                                lr_scheduler,
                                loss_scaler,
                                logger,
                                model_ema=model_ema,
                                best='best')
            max_accuracy = max(max_accuracy, acc1)
            logger.info(f'Max accuracy: {max_accuracy:.2f}%')

            if config.TRAIN.EMA.ENABLE:
                acc1, acc5, loss = validate(config, data_loader_val,
                                            model_ema.ema, epoch, amp_autocast=amp_autocast)
                logger.info(
                    f'Accuracy of the ema network on the {len(dataset_val)} test images: {acc1:.1f}%'
                )
                if dist.get_rank() == 0 and acc1 > max_ema_accuracy:
                    save_checkpoint(config,
                                    epoch,
                                    model_without_ddp,
                                    max_accuracy,
                                    optimizer,
                                    lr_scheduler,
                                    loss_scaler,
                                    logger,
                                    model_ema=model_ema,
                                    best='ema_best')
                max_ema_accuracy = max(max_ema_accuracy, acc1)
                logger.info(f'Max ema accuracy: {max_ema_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
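

# Gradient accumulation in train_one_epoch: the loss is divided by
# TRAIN.ACCUMULATION_STEPS and the optimizer only steps every
# ACCUMULATION_STEPS iterations, so e.g. 8 GPUs x batch 32 x 4 accumulation
# steps yield an effective batch size of 1024.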


def train_one_epoch(config,
                    model,
                    criterion,
                    data_loader,
                    optimizer,
                    epoch,
                    mixup_fn,
                    lr_scheduler,
                    amp_autocast=suppress,
                    loss_scaler=None,
                    model_ema=None):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    model_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = MyAverageMeter(300)

    start = time.time()
    end = time.time()

    amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
    for idx, (samples, targets) in enumerate(data_loader):
        iter_begin_time = time.time()
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        # autocast(dtype=...) is only supported on newer PyTorch releases
        if not obsolete_torch_version(TORCH_VERSION,
                                      (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
            with amp_autocast(dtype=amp_type):
                outputs = model(samples)
        else:
            with amp_autocast():
                outputs = model(samples)

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            if not obsolete_torch_version(
                    TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
                with amp_autocast(dtype=amp_type):
                    loss = criterion(outputs, targets)
                    loss = loss / config.TRAIN.ACCUMULATION_STEPS
            else:
                with amp_autocast():
                    loss = criterion(outputs, targets)
                    loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != 'O0':
                is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
                grad_norm = loss_scaler(loss,
                                        optimizer,
                                        clip_grad=config.TRAIN.CLIP_GRAD,
                                        parameters=model.parameters(),
                                        create_graph=is_second_order,
                                        update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
                if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                    optimizer.zero_grad()
                    if model_ema is not None:
                        model_ema.update(model)
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
                if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    if model_ema is not None:
                        model_ema.update(model)
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            if not obsolete_torch_version(
                    TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
                with amp_autocast(dtype=amp_type):
                    loss = criterion(outputs, targets)
            else:
                with amp_autocast():
                    loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != 'O0':
                is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
                grad_norm = loss_scaler(loss,
                                        optimizer,
                                        clip_grad=config.TRAIN.CLIP_GRAD,
                                        parameters=model.parameters(),
                                        create_graph=is_second_order,
                                        update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
                if model_ema is not None:
                    model_ema.update(model)
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
                optimizer.step()
                if model_ema is not None:
                    model_ema.update(model)

            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        if grad_norm is not None:
            norm_meter.update(grad_norm.item())
        batch_time.update(time.time() - end)
        model_time.update(time.time() - iter_begin_time)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'model_time {model_time.val:.4f} ({model_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f}/{norm_meter.var:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f'EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}'
    )
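

# ImageNet-ReaL evaluation: predictions are accumulated in `real_labels` and
# scored against the reassessed labels loaded from real.json, in addition to
# the standard top-1/top-5 computed against the original validation labels.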


@torch.no_grad()
def validate_real(config, data_loader, model, real_labels, amp_autocast=suppress):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if not obsolete_torch_version(TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
            with amp_autocast(dtype=amp_type):
                output = model(images)
        else:
            with amp_autocast():
                output = model(images)

        # a 21841-way head is an ImageNet-22K classifier; map its logits to
        # the 1k classes before scoring
        if output.size(-1) == 21841:
            convert_file = './meta_data/map22kto1k.txt'
            with open(convert_file, 'r') as f:
                convert_list = [int(line) for line in f.readlines()]
            output = output[:, convert_list]

        real_labels.add_result(output)

        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')

    top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)

    logger.info('* ReaL Acc@1 {:.3f} Acc@5 {:.3f} loss {losses:.3f}'
                .format(top1a, top5a, losses=loss_meter.avg))

    return top1a, top5a, loss_meter.avg


@torch.no_grad()
def validate(config, data_loader, model, epoch=None, amp_autocast=suppress):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    amp_type = torch.float16 if config.AMP_TYPE == 'float16' else torch.bfloat16
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if not obsolete_torch_version(TORCH_VERSION, (1, 9)) and config.AMP_OPT_LEVEL != 'O0':
            with amp_autocast(dtype=amp_type):
                output = model(images)
        else:
            with amp_autocast():
                output = model(images)

        # a 21841-way head is an ImageNet-22K classifier; map its logits to
        # the 1k classes before scoring
        if output.size(-1) == 21841:
            convert_file = './meta_data/map22kto1k.txt'
            with open(convert_file, 'r') as f:
                convert_list = [int(line) for line in f.readlines()]
            output = output[:, convert_list]

        # ImageNet-A/R only cover a subset of the 1k classes, so restrict the
        # logits to the corresponding class indices
        if config.DATA.DATASET == 'imagenet_a':
            from dataset.imagenet_a_r_indices import imagenet_a_mask
            output = output[:, imagenet_a_mask]
        elif config.DATA.DATASET == 'imagenet_r':
            from dataset.imagenet_a_r_indices import imagenet_r_mask
            output = output[:, imagenet_r_mask]

        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')
    if epoch is not None:
        logger.info(
            f'[Epoch:{epoch}] * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}'
        )
    else:
        logger.info(
            f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')

    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg


if __name__ == '__main__':
    args, config = parse_option()

    if config.AMP_OPT_LEVEL != 'O0':
        assert has_native_amp, 'Please update PyTorch (>= 1.6) to support native AMP!'

    if args.launcher == 'slurm':
        print('\nDist init: SLURM')
        rank = int(os.environ['SLURM_PROCID'])
        gpu = rank % torch.cuda.device_count()
        config.defrost()
        config.LOCAL_RANK = gpu
        config.freeze()

        world_size = int(os.environ['SLURM_NTASKS'])
        if 'MASTER_PORT' not in os.environ:
            os.environ['MASTER_PORT'] = '29501'
        node_list = os.environ['SLURM_NODELIST']
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = addr

        os.environ['RANK'] = str(rank)
        os.environ['LOCAL_RANK'] = str(gpu)
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        os.environ['WORLD_SIZE'] = str(world_size)
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f'RANK and WORLD_SIZE in environ: {rank}/{world_size}')
    else:
        rank = -1
        world_size = -1

    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://',
                                         world_size=world_size,
                                         rank=rank)
    torch.distributed.barrier()
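
    # Offset the seed by the process rank so data augmentation differs across
    # GPUs while the run stays reproducible for a fixed config.SEED.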
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # linear learning-rate scaling with the total batch size (reference 512)
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0

    # gradient accumulation also multiplies the effective batch size
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
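
    # Worked example (hypothetical values): with BASE_LR = 5e-4, batch 128 per
    # GPU, 8 GPUs and ACCUMULATION_STEPS = 1, the total batch is 1024 and the
    # scaled LR is 5e-4 * 1024 / 512 = 1e-3.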
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    print(config.AMP_OPT_LEVEL, args.amp_opt_level)

    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT,
                           dist_rank=dist.get_rank(),
                           name=f'{config.MODEL.NAME}')

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, 'config.json')
        with open(path, 'w') as f:
            f.write(config.dump())
        logger.info(f'Full config saved to {path}')

    logger.info(config.dump())

    main(config)