|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import time |
|
|
|
|
|
import torch |
|
|
from mmcv.cnn import get_model_complexity_info |
|
|
from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string |
|
|
from models.intern_vit_6b import InternViT6B |
|
|
from tqdm import tqdm |
|
|
|
|
|
parser = argparse.ArgumentParser(description='Hyperparams') |
|
|
parser.add_argument('config', nargs='?', type=str, default=None) |
|
|
args = parser.parse_args() |
|
|
|
|
|
# Model size variants selectable from the command line. Each row is
# (name, embed_dim, num_heads, mlp_ratio, depth); the resulting dicts keep the
# key order embed_dim, num_heads, mlp_ratio, depth.
configs = {
    name: dict(embed_dim=embed_dim,
               num_heads=num_heads,
               mlp_ratio=mlp_ratio,
               depth=depth)
    for name, embed_dim, num_heads, mlp_ratio, depth in (
        ('a', 3968, 62, 4, 32),
        ('e', 3200, 50, 4, 48),
        ('f', 3200, 25, 4, 48),
        ('g', 2496, 39, 8, 48),
        ('i', 2816, 44, 4, 64),
        ('m', 2496, 39, 4, 80),
    )
}
|
|
|
|
|
|
|
|
def sa_flops(h, w, dim):
    """Estimate the attention-matrix cost of one self-attention layer.

    With ``N = h * w`` tokens the result is ``2 * N * N * dim``. The factor 2
    presumably covers the two N x N products (Q @ K^T and attn @ V); whether
    the unit is MACs or FLOPs should match mmcv's counter -- verify.

    Args:
        h: number of token rows.
        w: number of token columns.
        dim: embedding width.

    Returns:
        The estimated operation count.
    """
    num_tokens = h * w
    return 2 * num_tokens * num_tokens * dim
|
|
|
|
|
|
|
|
def get_flops(model, input_shape):
    """Count FLOPs and parameters of ``model`` for a single input sample.

    mmcv's ``get_model_complexity_info`` is run first. The attention-matrix
    cost from ``sa_flops`` is then added once per layer (``model.depth``
    layers). This presumably compensates for attention kernels (e.g. flash
    attention) that mmcv's hooks cannot see -- verify against the model.

    Args:
        model: module exposing ``depth``, ``patch_size`` and ``embed_dim``.
        input_shape: ``(C, H, W)`` of one sample, without a batch dimension.

    Returns:
        tuple[str, str]: human-readable FLOPs and parameter counts.
    """
    flops, params = get_model_complexity_info(model,
                                              input_shape,
                                              as_strings=False)
    _, height, width = input_shape
    # Raw mmcv numbers before the attention correction, for debugging.
    print(flops, params)

    tokens_h = height // model.patch_size
    tokens_w = width // model.patch_size
    # Every layer has the same attention cost, so one multiplication replaces
    # the per-layer accumulation loop.
    flops += model.depth * sa_flops(tokens_h, tokens_w, model.embed_dim)
    return flops_to_string(flops), params_to_string(params)
|
|
|
|
|
|
|
|
def _time_forward(model, images, iters):
    """Time ``iters`` no-grad forward passes of ``model`` on ``images``.

    CUDA is synchronized before starting and after stopping the clock, so the
    measurement covers the queued GPU work and not just kernel launches.

    Returns:
        float: elapsed wall-clock seconds.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in tqdm(range(iters)):
            model(images)
    torch.cuda.synchronize()
    return time.perf_counter() - start


if __name__ == '__main__':
    # Benchmark settings. ``image_size`` feeds both the FLOPs input shape and
    # the model's img_size / pretrain_size so they cannot drift apart.
    image_size = 224
    batch_size = 128
    warmup_iters = 10
    timed_iters = 50
    input_shape = (3, image_size, image_size)

    # ``config`` is an optional positional that defaults to None. Reject a
    # missing or unknown key with a usage message instead of a bare KeyError.
    if args.config not in configs:
        parser.error(f'config must be one of {", ".join(configs)} '
                     f'(got {args.config!r})')
    config = configs[args.config]
    print(config)

    model = InternViT6B(in_chans=3,
                        patch_size=14,
                        img_size=image_size,
                        pretrain_size=image_size,
                        qkv_bias=False,
                        drop_path_rate=0.0,
                        embed_dim=config['embed_dim'],
                        num_heads=config['num_heads'],
                        mlp_ratio=config['mlp_ratio'],
                        init_values=0.1,
                        qk_normalization=True,
                        depth=config['depth'],
                        use_flash_attn=True,
                        with_cp=True,
                        freeze_vit=True,
                        cls_target='cls_patch_concat',
                        num_classes=0,
                        attn_pool_num_heads=16,
                        clip_embed_dim=768,
                        head_norm_type='bn').to(torch.bfloat16)

    # mmcv's parameter counter only counts tensors with requires_grad=True,
    # and freeze_vit=True presumably turns it off, so re-enable it on every
    # parameter before counting.
    model.requires_grad_(True)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
    model.eval()

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')

    if use_cuda:
        image = torch.rand(batch_size, *input_shape,
                           device='cuda').to(torch.bfloat16)
        # Warm-up passes are reported separately so one-time start-up costs
        # do not skew the throughput figure.
        warmup_time = _time_forward(model, image, warmup_iters)
        print('warmup time: ', warmup_time)

        elapsed = _time_forward(model, image, timed_iters)
        print('using time: ', elapsed)
        print('FPS: ', timed_iters * batch_size / elapsed)
    else:
        # The throughput benchmark synchronizes CUDA, so it needs a GPU.
        print('CUDA is not available; skipping the throughput benchmark.')
    print(config)
|
|
|