import math
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from timm.utils import get_state_dict

try:
    from apex import amp
except ImportError:
    amp = None


def load_ema_checkpoint(config, model_ema, logger):
    logger.info(
        f'==============> Resuming from {config.MODEL.RESUME}....................'
    )
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME,
                                                        map_location='cpu',
                                                        check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    assert isinstance(checkpoint, dict)
    if 'model_ema' in checkpoint:
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_ema'].items():
            # align key prefixes with the EMA wrapper ('module.' prefix for DDP models)
            if model_ema.ema_has_module:
                name = 'module.' + k if not k.startswith('module') else k
            else:
                name = k
            new_state_dict[name] = v
        msg = model_ema.ema.load_state_dict(new_state_dict, strict=False)
        logger.info(msg)
        logger.info('Loaded state_dict_ema')
    else:
        logger.warning(
            'Failed to find state_dict_ema, starting from loaded model weights'
        )

    max_accuracy_ema = 0
    if 'max_accuracy_ema' in checkpoint:
        max_accuracy_ema = checkpoint['max_accuracy_ema']
    if 'ema_decay' in checkpoint:
        model_ema.decay = checkpoint['ema_decay']
    return max_accuracy_ema


def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    logger.info(
        f'==============> Resuming from {config.MODEL.RESUME}....................'
    )
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME,
                                                        map_location='cpu',
                                                        check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    print('resuming model')
    model_checkpoint = checkpoint['model']
    msg = model.load_state_dict(model_checkpoint, strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        if optimizer is not None:
            print('resuming optimizer')
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception:
                print('resume optimizer failed')
        if lr_scheduler is not None:
            print('resuming lr_scheduler')
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != 'O0' and checkpoint['config'].AMP_OPT_LEVEL != 'O0':
            scaler.load_state_dict(checkpoint['amp'])
        logger.info(
            f"=> loaded successfully {config.MODEL.RESUME} (epoch {checkpoint['epoch']})"
        )
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    del checkpoint
    torch.cuda.empty_cache()

    return max_accuracy


def load_pretrained(config, model, logger):
    logger.info(
        f'==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......'
    )
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')

    state_dict = checkpoint
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']

    first_key = list(state_dict.keys())[0]

    # distillation checkpoints: keep only the student weights and strip their prefix
    if 'student' in first_key or 'teacher' in first_key:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if 'student_proj' in k:
                continue
            if 'student' in k:
                new_k = k.replace('student.', '')
                new_state_dict[new_k] = v
        state_dict = new_state_dict

    # masked pre-training checkpoints: keep the dcnv3 backbone and clip_projector weights,
    # and remap the CLIP classifier layers onto the fine-tuning head
    if 'mask_token' in first_key:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if 'mm_dcnv3' in k:
                continue
            if 'dcnv3' not in k and 'clip_projector' not in k:
                continue
            new_k = k.replace('dcnv3.', '')
            new_state_dict[new_k] = v
        new_state_dict['fc_norm.weight'] = state_dict['clip.classifier_ln.weight']
        new_state_dict['fc_norm.bias'] = state_dict['clip.classifier_ln.bias']
        new_state_dict['head.weight'] = state_dict['clip.classifier.weight']
        new_state_dict['head.bias'] = state_dict['clip.classifier.bias']
        state_dict = new_state_dict

    # delete relative_position_index since it is always re-initialized
    relative_position_index_keys = [
        k for k in state_dict.keys() if 'relative_position_index' in k
    ]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since it is always re-initialized
    relative_coords_table_keys = [
        k for k in state_dict.keys() if 'relative_coords_table' in k
    ]
    for k in relative_coords_table_keys:
        del state_dict[k]

    # delete attn_mask since it is always re-initialized
    attn_mask_keys = [k for k in state_dict.keys() if 'attn_mask' in k]
    for k in attn_mask_keys:
        del state_dict[k]

    # bicubic interpolate relative_position_bias_table if the window size changed
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if 'relative_position_bias_table' in k
    ]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f'Error in loading {k}, passing......')
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                    size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if the input resolution changed
    absolute_pos_embed_keys = [
        k for k in state_dict.keys() if 'absolute_pos_embed' in k
    ]
    for k in absolute_pos_embed_keys:
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C2:
            logger.warning(f'Error in loading {k}, passing......')
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained,
                    size=(S2, S2),
                    mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    # check the classifier head: re-init it or remap 22K classes to 1K when the class count differs
    if 'head.bias' in state_dict:
        head_bias_pretrained = state_dict['head.bias']
        Nc1 = head_bias_pretrained.shape[0]
        Nc2 = model.head.bias.shape[0]
        if Nc1 != Nc2:
            if config.TRAIN.RAND_INIT_FT_HEAD:
                model.head.weight.data = model.head.weight.data * 0.001
                model.head.bias.data = model.head.bias.data * 0.001
                del state_dict['head.weight']
                del state_dict['head.bias']
                logger.warning('Error in loading classifier head, re-init classifier head to 0')
            elif Nc1 == 21841 and Nc2 == 1000:
                logger.info('loading ImageNet-22K weight to ImageNet-1K ......')
                map22kto1k_path = 'meta_data/map22kto1k.txt'
                logger.info(map22kto1k_path)
                with open(map22kto1k_path) as f:
                    map22kto1k = f.readlines()
                map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
                state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
                state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)

    logger.info(f'=> loaded successfully {config.MODEL.PRETRAINED}')

    del checkpoint
    torch.cuda.empty_cache()


def convert_22k_head_to_1k(model, logger):
    head_weight = model.module.head.weight
    head_bias = model.module.head.bias
    Nc1 = head_bias.shape[0]

    if Nc1 == 21841:
        logger.info('converting ImageNet-22K head to ImageNet-1K ......')
        map22kto1k_path = 'meta_data/map22kto1k.txt'
        logger.info(map22kto1k_path)
        with open(map22kto1k_path) as f:
            map22kto1k = f.readlines()
        map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
        model.module.head.weight = torch.nn.Parameter(head_weight[map22kto1k, :])
        model.module.head.bias = torch.nn.Parameter(head_bias[map22kto1k])
    else:
        logger.warning('Error in converting classifier head')

    return model


def save_checkpoint(config,
                    epoch,
                    model,
                    max_accuracy,
                    optimizer,
                    lr_scheduler,
                    scaler,
                    logger,
                    model_ema=None,
                    max_accuracy_ema=None,
                    ema_decay=None,
                    model_ems=None,
                    max_accuracy_ems=None,
                    ems_model_num=None,
                    best=None):
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config
    }
    if model_ema is not None:
        save_state['model_ema'] = get_state_dict(model_ema)
    if max_accuracy_ema is not None:
        save_state['max_accuracy_ema'] = max_accuracy_ema
    if ema_decay is not None:
        save_state['ema_decay'] = ema_decay
    if model_ems is not None:
        save_state['model_ems'] = get_state_dict(model_ems)
    if max_accuracy_ems is not None:
        save_state['max_accuracy_ems'] = max_accuracy_ems
    if ems_model_num is not None:
        save_state['ems_model_num'] = ems_model_num
    if config.AMP_OPT_LEVEL != 'O0':
        save_state['amp'] = scaler.state_dict()
    if best is None:
        save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    else:
        save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{best}.pth')
    logger.info(f'{save_path} saving......')
    torch.save(save_state, save_path)
    logger.info(f'{save_path} saved !!!')

    # keep only the most recent checkpoints: drop the one saved
    # config.SAVE_CKPT_NUM * config.SAVE_FREQ epochs ago
    if dist.get_rank() == 0 and isinstance(epoch, int):
        to_del = epoch - config.SAVE_CKPT_NUM * config.SAVE_FREQ
        old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{to_del}.pth')
        if os.path.exists(old_ckpt):
            os.remove(old_ckpt)


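# Typical call at the end of an epoch (illustrative only; `model_without_ddp`,
# `loss_scaler`, and config.TRAIN.EPOCHS are assumed names from a surrounding
# training loop, not defined in this file):
#
#   if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
#       save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
#                       optimizer, lr_scheduler, loss_scaler, logger)

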
def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm


def auto_resume_helper(output_dir):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f'All checkpoints found in {output_dir}: {checkpoints}')
    if len(checkpoints) > 0:
        latest_checkpoint = max(
            [os.path.join(output_dir, d) for d in checkpoints],
            key=os.path.getmtime)
        print(f'The latest checkpoint found: {latest_checkpoint}')
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


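# Typical use before training starts (illustrative sketch; the AUTO_RESUME flag
# name is an assumption, only config.OUTPUT and config.MODEL.RESUME appear in
# this file):
#
#   if config.TRAIN.AUTO_RESUME:
#       resume_file = auto_resume_helper(config.OUTPUT)
#       if resume_file:
#           config.defrost()
#           config.MODEL.RESUME = resume_file
#           config.freeze()

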
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


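# Typical use during distributed evaluation (illustrative; `acc1` and `loss` are
# assumed to be CUDA tensors computed independently on each rank):
#
#   acc1 = reduce_tensor(acc1)   # average the metric over all processes
#   loss = reduce_tensor(loss)

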
class NativeScalerWithGradNormCount:
    state_dict_key = 'amp_scaler'

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self,
                 loss,
                 optimizer,
                 clip_grad=None,
                 parameters=None,
                 create_graph=False,
                 update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


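# Minimal usage sketch for NativeScalerWithGradNormCount (illustrative;
# `criterion`, `samples`, `targets`, `model`, `optimizer`, and
# config.TRAIN.CLIP_GRAD are assumed to come from the training loop/config):
#
#   loss_scaler = NativeScalerWithGradNormCount()
#   with torch.cuda.amp.autocast():
#       loss = criterion(model(samples), targets)
#   grad_norm = loss_scaler(loss, optimizer,
#                           clip_grad=config.TRAIN.CLIP_GRAD,
#                           parameters=model.parameters(),
#                           update_grad=True)
#   optimizer.zero_grad()

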
class MyAverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, max_len=-1):
        self.val_list = []
        self.count = []
        self.max_len = max_len
        self.val = 0
        self.avg = 0
        self.var = 0

    def update(self, val):
        self.val = val
        self.avg = 0
        self.var = 0
        # ignore NaN/Inf values so a single bad step does not poison the window
        if not math.isnan(val) and not math.isinf(val):
            self.val_list.append(val)
            # keep only the most recent max_len values when a window size is set
            if self.max_len > 0 and len(self.val_list) > self.max_len:
                self.val_list = self.val_list[-self.max_len:]
            if len(self.val_list) > 0:
                self.avg = np.mean(np.array(self.val_list))
                # note: despite the name, `var` stores the standard deviation of the window
                self.var = np.std(np.array(self.val_list))


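# Minimal usage sketch for MyAverageMeter (illustrative; `loss` is assumed to be
# a scalar tensor from a training step):
#
#   loss_meter = MyAverageMeter(max_len=300)
#   loss_meter.update(loss.item())
#   print(loss_meter.val, loss_meter.avg, loss_meter.var)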
|