# Uploaded by tqv06 using huggingface_hub ("Upload folder using huggingface_hub"),
# commit 866ee56 (verified).
# --------------------------------------------------------
# InternVL
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
import os
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from timm.utils import get_state_dict
try:
# noinspection PyUnresolvedReferences
from apex import amp
except ImportError:
amp = None
def load_ema_checkpoint(config, model_ema, logger):
    """Restore EMA weights and EMA bookkeeping from ``config.MODEL.RESUME``.

    ``config.MODEL.RESUME`` is either an ``https`` URL (downloaded through
    ``torch.hub.load_state_dict_from_url`` with hash checking) or a local
    path. If the checkpoint has a ``'model_ema'`` entry, its keys are adapted
    to the ``model_ema.ema_has_module`` flag and loaded non-strictly into
    ``model_ema.ema``. Otherwise a warning is logged and the EMA model keeps
    its current weights.

    Args:
        config: Config node; only ``config.MODEL.RESUME`` is read.
        model_ema: EMA wrapper exposing ``ema`` (the module to load into),
            ``ema_has_module`` (whether ``ema`` keys carry a ``module.``
            prefix) and a writable ``decay`` attribute.
        logger: Logger for progress and load-result messages.

    Returns:
        ``checkpoint['max_accuracy_ema']`` if present, else ``0``.
        ``model_ema.decay`` is also overwritten when the checkpoint stores
        ``'ema_decay'``.
    """
    resume_path = config.MODEL.RESUME
    logger.info(
        f'==============> Resuming from {resume_path}....................'
    )
    if resume_path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(resume_path,
                                                        map_location='cpu',
                                                        check_hash=True)
    else:
        checkpoint = torch.load(resume_path, map_location='cpu')
    assert isinstance(checkpoint, dict)

    if 'model_ema' in checkpoint:
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in checkpoint['model_ema'].items():
            if model_ema.ema_has_module and not k.startswith(prefix):
                # Match on the full 'module.' prefix. A bare 'module' test
                # would also accept keys such as 'module_list.*', and those
                # would then be skipped silently by strict=False below.
                k = prefix + k
            new_state_dict[k] = v
        msg = model_ema.ema.load_state_dict(new_state_dict, strict=False)
        logger.info(msg)
        logger.info('Loaded state_dict_ema')
    else:
        logger.warning(
            'Failed to find state_dict_ema, starting from loaded model weights'
        )

    max_accuracy_ema = checkpoint.get('max_accuracy_ema', 0)
    if 'ema_decay' in checkpoint:
        model_ema.decay = checkpoint['ema_decay']
    return max_accuracy_ema
def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    """Resume a training run from ``config.MODEL.RESUME``.

    Model weights are always loaded, non-strictly. The optimizer,
    LR-scheduler and AMP-scaler state and the start epoch are restored only
    when ``config.EVAL_MODE`` is false and the checkpoint contains
    ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'`` entries.

    Args:
        config: Config node. Reads ``MODEL.RESUME``, ``EVAL_MODE`` and
            ``AMP_OPT_LEVEL``. Writes ``TRAIN.START_EPOCH`` between
            ``defrost()`` and ``freeze()``.
        model: Module receiving ``checkpoint['model']``.
        optimizer: Optimizer to restore, or None to skip. Restoring is best
            effort: a failure is logged and training continues.
        lr_scheduler: Scheduler to restore, or None to skip.
        scaler: Object with ``load_state_dict`` for the AMP state, or None to
            skip it.
        logger: Logger for progress and load-result messages.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was restored and
        the key exists, otherwise ``0.0``.
    """
    resume_path = config.MODEL.RESUME
    logger.info(
        f'==============> Resuming from {resume_path}....................'
    )
    if resume_path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(resume_path,
                                                        map_location='cpu',
                                                        check_hash=True)
    else:
        # NOTE(review): checkpoints written by save_checkpoint also pickle the
        # config object. torch versions whose torch.load defaults to
        # weights_only=True may reject them; confirm the deployed torch version.
        checkpoint = torch.load(resume_path, map_location='cpu')

    print('resuming model')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    required_keys = ('optimizer', 'lr_scheduler', 'epoch')
    if not config.EVAL_MODE and all(key in checkpoint for key in required_keys):
        if optimizer is not None:
            print('resuming optimizer')
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except Exception as err:
                # Best effort (e.g. parameter groups changed since saving).
                # Unlike a bare except, this does not swallow
                # KeyboardInterrupt/SystemExit.
                logger.warning(f'resume optimizer failed: {err}')
        if lr_scheduler is not None:
            print('resuming lr_scheduler')
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        # A checkpoint without a saved config is treated as AMP-enabled,
        # since it still carries an 'amp' entry.
        saved_amp_level = getattr(checkpoint.get('config'), 'AMP_OPT_LEVEL', None)
        if ('amp' in checkpoint and scaler is not None
                and config.AMP_OPT_LEVEL != 'O0' and saved_amp_level != 'O0'):
            scaler.load_state_dict(checkpoint['amp'])
        logger.info(
            f"=> loaded successfully {config.MODEL.RESUME} (epoch {checkpoint['epoch']})"
        )
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
def _extract_pretrained_state_dict(checkpoint):
    """Return the weights nested under 'model' or 'module', else the checkpoint itself."""
    if 'model' in checkpoint:
        return checkpoint['model']
    if 'module' in checkpoint:
        return checkpoint['module']
    return checkpoint


def _keep_student_weights(state_dict):
    """Keep only 'student' entries (minus 'student_proj'), dropping the 'student.' prefix."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Teacher weights and the student projection are discarded.
        if 'student_proj' in k or 'student' not in k:
            continue
        new_state_dict[k.replace('student.', '')] = v
    return new_state_dict


def _convert_sim_weights(state_dict):
    """Remap 'sim' checkpoint keys (dcnv3 / clip_projector / clip classifier) to model names."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if 'mm_dcnv3' in k:
            continue
        if 'dcnv3' not in k and 'clip_projector' not in k:
            continue
        new_state_dict[k.replace('dcnv3.', '')] = v
    new_state_dict['fc_norm.weight'] = state_dict['clip.classifier_ln.weight']
    new_state_dict['fc_norm.bias'] = state_dict['clip.classifier_ln.bias']
    new_state_dict['head.weight'] = state_dict['clip.classifier.weight']
    new_state_dict['head.bias'] = state_dict['clip.classifier.bias']
    return new_state_dict


def _drop_reinitialized_entries(state_dict):
    """Delete, in place, index/coords/mask entries the model always re-initializes."""
    for pattern in ('relative_position_index', 'relative_coords_table', 'attn_mask'):
        for k in [k for k in state_dict if pattern in k]:
            del state_dict[k]


def _resize_relative_position_bias_tables(state_dict, model_state, logger):
    """Bicubic-resize (L, nH) relative_position_bias_table entries to the model's L."""
    for k in [k for k in state_dict if 'relative_position_bias_table' in k]:
        if k not in model_state:
            continue  # unexpected key; load_state_dict(strict=False) ignores it
        table_pretrained = state_dict[k]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = model_state[k].size()
        if nH1 != nH2:
            logger.warning(f'Error in loading {k}, passing......')
            # Drop the entry so the model keeps its own table; otherwise
            # load_state_dict raises a size-mismatch error.
            del state_dict[k]
        elif L1 != L2:
            # L is treated as a flattened square grid of side sqrt(L).
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            resized = torch.nn.functional.interpolate(
                table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                size=(S2, S2),
                mode='bicubic')
            state_dict[k] = resized.view(nH2, L2).permute(1, 0)


def _resize_absolute_pos_embeds(state_dict, model_state, logger):
    """Bicubic-resize (N, L, C) absolute_pos_embed entries over a square grid."""
    for k in [k for k in state_dict if 'absolute_pos_embed' in k]:
        if k not in model_state:
            continue  # unexpected key; load_state_dict(strict=False) ignores it
        embed = state_dict[k]
        _, L1, C1 = embed.size()
        _, L2, C2 = model_state[k].size()
        if C1 != C2:
            logger.warning(f'Error in loading {k}, passing......')
            # Drop the entry so the model keeps its own embedding.
            del state_dict[k]
        elif L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            embed = embed.reshape(-1, S1, S1, C1).permute(0, 3, 1, 2)
            embed = torch.nn.functional.interpolate(embed, size=(S2, S2), mode='bicubic')
            state_dict[k] = embed.permute(0, 2, 3, 1).flatten(1, 2)


def _adapt_classifier_head(config, model, state_dict, logger):
    """Reconcile a pretrained 'head' whose class count differs from model.head."""
    if 'head.bias' not in state_dict:
        return
    Nc1 = state_dict['head.bias'].shape[0]
    Nc2 = model.head.bias.shape[0]
    if Nc1 == Nc2:
        return
    if config.TRAIN.RAND_INIT_FT_HEAD:
        # Shrink the model's freshly initialized head rather than loading.
        model.head.weight.data = model.head.weight.data * 0.001
        model.head.bias.data = model.head.bias.data * 0.001
        state_dict.pop('head.weight', None)
        del state_dict['head.bias']
        logger.warning('Error in loading classifier head, re-init classifier head to 0')
    elif Nc1 == 21841 and Nc2 == 1000:
        logger.info('loading ImageNet-22K weight to ImageNet-1K ......')
        map22kto1k_path = 'meta_data/map22kto1k.txt'
        logger.info(map22kto1k_path)
        with open(map22kto1k_path) as f:
            map22kto1k = [int(line.strip()) for line in f if line.strip()]
        state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
        state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
    else:
        # Any other mismatch would make load_state_dict raise a size-mismatch
        # error even with strict=False, so keep the model's own head.
        state_dict.pop('head.weight', None)
        del state_dict['head.bias']
        logger.warning('Error in loading classifier head, keeping the model head initialization')


def load_pretrained(config, model, logger):
    """Load weights from ``config.MODEL.PRETRAINED`` into ``model`` for fine-tuning.

    Steps, in order:
      * unwrap the weights from a 'model' or 'module' entry if present;
      * keep only student weights when the first key names a student/teacher;
      * remap 'sim' checkpoints when the first key contains 'mask_token';
      * drop relative_position_index / relative_coords_table / attn_mask entries;
      * bicubic-resize relative position bias tables and absolute position
        embeddings whose length differs (entries whose head/channel count
        differs are skipped with a warning);
      * adapt a classifier head whose class count differs;
      * load the result with ``strict=False`` and log the missing/unexpected keys.

    Args:
        config: Config node. Reads ``MODEL.PRETRAINED`` and ``TRAIN.RAND_INIT_FT_HEAD``.
        model: Target module. Must expose ``head`` when the checkpoint has ``head.bias``.
        logger: Logger for progress and warnings.
    """
    logger.info(
        f'==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......'
    )
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    state_dict = _extract_pretrained_state_dict(checkpoint)

    # Both layout checks look at the ORIGINAL first key.
    first_key = next(iter(state_dict), '')
    if 'student' in first_key or 'teacher' in first_key:
        state_dict = _keep_student_weights(state_dict)
    if 'mask_token' in first_key:
        state_dict = _convert_sim_weights(state_dict)

    _drop_reinitialized_entries(state_dict)
    # Computed once; only the target shapes are read from it.
    model_state = model.state_dict()
    _resize_relative_position_bias_tables(state_dict, model_state, logger)
    _resize_absolute_pos_embeds(state_dict, model_state, logger)
    _adapt_classifier_head(config, model, state_dict, logger)

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)
    logger.info(f'=> loaded successfully {config.MODEL.PRETRAINED}')
    del checkpoint
    torch.cuda.empty_cache()
def convert_22k_head_to_1k(model, logger, map22kto1k_path='meta_data/map22kto1k.txt'):
    """Slice a 21841-way classifier head down to the ImageNet-1K classes.

    The rows of ``head.weight`` and ``head.bias`` are selected with the
    integer indices listed one per line in ``map22kto1k_path``. Blank lines
    are ignored. A head with any other class count is left untouched and a
    warning is logged.

    Args:
        model: Network exposing ``head``, or a wrapper exposing it as
            ``model.module.head`` (e.g. DistributedDataParallel).
        logger: Logger for progress and warnings.
        map22kto1k_path: Text file with the 22K-to-1K index mapping.

    Returns:
        ``model`` itself; its head parameters are replaced in place.
    """
    # Wrappers expose the real network as `.module`; unwrapped models are
    # used directly.
    net = getattr(model, 'module', model)
    head_weight = net.head.weight
    head_bias = net.head.bias
    if head_bias.shape[0] == 21841:
        logger.info('converting ImageNet-22K head to ImageNet-1K ......')
        logger.info(map22kto1k_path)
        with open(map22kto1k_path) as f:
            map22kto1k = [int(line.strip()) for line in f if line.strip()]
        net.head.weight = torch.nn.Parameter(head_weight[map22kto1k, :])
        net.head.bias = torch.nn.Parameter(head_bias[map22kto1k])
    else:
        logger.warning('Error in converting classifier head')
    return model
def save_checkpoint(config,
                    epoch,
                    model,
                    max_accuracy,
                    optimizer,
                    lr_scheduler,
                    scaler,
                    logger,
                    model_ema=None,
                    max_accuracy_ema=None,
                    ema_decay=None,
                    model_ems=None,
                    max_accuracy_ems=None,
                    ems_model_num=None,
                    best=None):
    """Write a training checkpoint to ``config.OUTPUT`` and prune an old one.

    The file is ``ckpt_epoch_{epoch}.pth``, or ``ckpt_epoch_{best}.pth`` when
    ``best`` is given. It always stores the model, optimizer, lr_scheduler,
    max_accuracy, epoch and config. EMA/EMS entries are added only when the
    matching argument is not None. The AMP scaler state is added when
    ``config.AMP_OPT_LEVEL != 'O0'`` and a scaler is given.

    After saving, the main process (rank 0, or the only process when
    torch.distributed is not initialized) deletes
    ``ckpt_epoch_{epoch - SAVE_CKPT_NUM * SAVE_FREQ}.pth`` if ``epoch`` is an
    int and that file exists.
    """
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config
    }
    if model_ema is not None:
        save_state['model_ema'] = get_state_dict(model_ema)
    if max_accuracy_ema is not None:
        save_state['max_accuracy_ema'] = max_accuracy_ema
    if ema_decay is not None:
        save_state['ema_decay'] = ema_decay
    if model_ems is not None:
        save_state['model_ems'] = get_state_dict(model_ems)
    if max_accuracy_ems is not None:
        save_state['max_accuracy_ems'] = max_accuracy_ems
    if ems_model_num is not None:
        save_state['ems_model_num'] = ems_model_num
    if config.AMP_OPT_LEVEL != 'O0' and scaler is not None:
        save_state['amp'] = scaler.state_dict()

    tag = epoch if best is None else best
    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{tag}.pth')
    logger.info(f'{save_path} saving......')
    torch.save(save_state, save_path)
    logger.info(f'{save_path} saved !!!')

    # dist.get_rank() raises when no process group exists, so treat a
    # non-distributed run as the main process.
    distributed = dist.is_available() and dist.is_initialized()
    is_main_process = not distributed or dist.get_rank() == 0
    if is_main_process and isinstance(epoch, int):
        to_del = epoch - config.SAVE_CKPT_NUM * config.SAVE_FREQ
        old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{to_del}.pth')
        if os.path.exists(old_ckpt):
            os.remove(old_ckpt)
def get_grad_norm(parameters, norm_type=2):
    """Return the global ``norm_type``-norm of all gradients as a Python float.

    Args:
        parameters: A tensor or an iterable of tensors. Entries whose
            ``.grad`` is None are ignored.
        norm_type: Order of the norm. ``math.inf`` gives the max-abs norm.

    Returns:
        ``(sum_i ||g_i||_p ** p) ** (1 / p)`` over all gradients, the largest
        per-gradient max-abs value for ``p = inf``, or ``0.0`` when no
        parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == math.inf:
        # The p-norm formula degenerates for p = inf (x ** (1 / inf) == 1),
        # so take the maximum of the per-gradient max-norms instead.
        return max((g.norm(norm_type).item() for g in grads if g.numel() > 0),
                   default=0.0)
    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1. / norm_type)
def auto_resume_helper(output_dir):
    """Return the path of the newest file in ``output_dir`` whose name ends with 'pth'.

    "Newest" means the largest modification time (``os.path.getmtime``).
    Both the candidate list and the chosen file are printed. Returns None
    when no such file exists.
    """
    candidates = [name for name in os.listdir(output_dir) if name.endswith('pth')]
    print(f'All checkpoints founded in {output_dir}: {candidates}')
    if not candidates:
        return None
    full_paths = (os.path.join(output_dir, name) for name in candidates)
    latest = max(full_paths, key=os.path.getmtime)
    print(f'The latest checkpoint founded: {latest}')
    return latest
def reduce_tensor(tensor):
    """Return a new tensor holding the mean of ``tensor`` across all ranks.

    The input is cloned first, so it is never modified. The clone is summed
    over the default process group with ``all_reduce`` and then divided by
    the world size.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    summed.div_(dist.get_world_size())
    return summed
# https://github.com/facebookresearch/ConvNeXt/blob/main/utils.py
class NativeScalerWithGradNormCount:
    """Loss-scaling helper around ``torch.cuda.amp.GradScaler`` that reports the grad norm.

    Calling the instance always runs the scaled backward pass. When
    ``update_grad`` is true it also unscales the gradients, measures them
    (clipping them to ``clip_grad`` if given), steps the optimizer through
    the scaler and updates the loss scale.
    """

    state_dict_key = 'amp_scaler'

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self,
                 loss,
                 optimizer,
                 clip_grad=None,
                 parameters=None,
                 create_graph=False,
                 update_grad=True):
        """Backpropagate ``loss`` and optionally step ``optimizer``.

        Returns the gradient norm: the result of ``clip_grad_norm_`` when
        ``clip_grad`` is set, ``get_grad_norm(parameters)`` otherwise, or
        None when ``update_grad`` is false (no optimizer step is taken).
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)
        if not update_grad:
            return None
        if clip_grad is not None:
            # Checked before unscale_ so a bad call leaves the scaler untouched.
            assert parameters is not None
        # Gradients of the optimizer's params are unscaled in place so the
        # measured norm (and any clipping) is in true units.
        self._scaler.unscale_(optimizer)
        if clip_grad is None:
            norm = get_grad_norm(parameters)
        else:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped GradScaler's state."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped GradScaler's state."""
        self._scaler.load_state_dict(state_dict)
class MyAverageMeter(object):
    """Track the latest value plus the mean and standard deviation of recent finite values.

    Args:
        max_len: If positive, only the most recent ``max_len`` finite values
            are kept for the statistics. Otherwise all finite values are kept.
    """

    def __init__(self, max_len=-1):
        self.max_len = max_len
        self.val_list = []  # finite values, windowed by max_len when positive
        self.count = []  # NOTE(review): never updated by this class
        self.val = 0  # last value passed to update(), even if NaN/inf
        self.avg = 0  # mean of val_list (0 while empty)
        self.var = 0  # NOTE: holds np.std of val_list (std dev, not variance)

    def update(self, val):
        """Record ``val``. NaN/inf values set ``val`` but are not accumulated."""
        self.val = val
        if math.isfinite(val):
            self.val_list.append(val)
        if 0 < self.max_len < len(self.val_list):
            self.val_list = self.val_list[-self.max_len:]
        if self.val_list:
            values = np.array(self.val_list)
            self.avg = np.mean(values)
            self.var = np.std(values)
        else:
            self.avg = 0
            self.var = 0