ProjectedGANCLC / feature_networks /pretrained_builder.py
ZJW666's picture
fist version
3619ba7
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as zoomodels
from torch.autograd import Function
import timm
from feature_networks.vit import _make_vit_b16_backbone, forward_vit
from feature_networks.constants import ALL_MODELS, VITS, EFFNETS, REGNETS
from pg_modules.blocks import Interpolate
def _feature_splitter(model, idcs):
    """Cut ``model.features`` into four consecutive stages.

    Stage ``k`` covers ``model.features[idcs[k-1]:idcs[k]]`` (stage 0 starts
    at the beginning) and is exposed as ``layer0`` .. ``layer3`` on a bare
    ``nn.Module``. Each slice is wrapped in an extra ``nn.Sequential``.
    """
    boundaries = (None, idcs[0], idcs[1], idcs[2], idcs[3])
    stages = nn.Module()
    for stage_no, (start, stop) in enumerate(zip(boundaries, boundaries[1:])):
        setattr(stages, f'layer{stage_no}', nn.Sequential(model.features[start:stop]))
    return stages
def _make_resnet(model):
    """Regroup a ResNet-style model into ``layer0`` .. ``layer3``.

    ``layer0`` chains ``conv1``, ``bn1``, ``relu``, ``maxpool`` and the
    model's ``layer1``; the remaining stages alias the model's ``layer2``,
    ``layer3`` and ``layer4`` respectively.
    """
    stem_and_first_stage = nn.Sequential(
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
    )

    wrapped = nn.Module()
    wrapped.layer0 = stem_and_first_stage
    # The wrapper's stage k is the model's stage k + 1 from here on.
    for stage_no in (1, 2, 3):
        setattr(wrapped, f'layer{stage_no}', getattr(model, f'layer{stage_no + 1}'))
    return wrapped
def _make_regnet(model):
    """Regroup a model exposing ``stem`` and ``s1`` .. ``s4`` into four stages.

    ``layer0`` is ``stem`` followed by ``s1``; ``layer1`` .. ``layer3`` alias
    ``s2`` .. ``s4``.
    """
    wrapped = nn.Module()
    wrapped.layer0 = nn.Sequential(model.stem, model.s1)
    wrapped.layer1, wrapped.layer2, wrapped.layer3 = model.s2, model.s3, model.s4
    return wrapped
def _make_nfnet(model):
    """Regroup a model exposing ``stem`` and indexable ``stages`` into four stages.

    ``layer0`` is ``stem`` followed by ``stages[0]``; ``layer1`` ..
    ``layer3`` alias ``stages[1]`` .. ``stages[3]``.
    """
    stages = model.stages
    wrapped = nn.Module()
    wrapped.layer0 = nn.Sequential(model.stem, stages[0])
    for stage_no in range(1, 4):
        setattr(wrapped, f'layer{stage_no}', stages[stage_no])
    return wrapped
def _make_resnet_v2(model):
    """Expose ``stem`` + ``stages[0]`` as ``layer0`` and ``stages[1..3]`` as ``layer1..3``."""
    first_stage = nn.Sequential(model.stem, model.stages[0])
    later_stages = (model.stages[1], model.stages[2], model.stages[3])

    wrapped = nn.Module()
    wrapped.layer0 = first_stage
    wrapped.layer1, wrapped.layer2, wrapped.layer3 = later_stages
    return wrapped
def _make_densenet(model):
    """Regroup ``model.features`` of a DenseNet-style model into four stages.

    ``layer0`` is ``features[:6]``. Each later stage is a two-element slice
    of ``features`` preceded by a fresh ``nn.AvgPool2d(2, 2)``. For the
    slices ``[6:8]`` and ``[8:10]`` the last child of the slice's last module
    is replaced by ``nn.Identity()``. The slices share their modules with
    ``model``, so this replacement also modifies ``model`` in place.
    """
    features = model.features
    wrapped = nn.Module()
    wrapped.layer0 = features[:6]

    # (attribute name, slice start, slice stop, replace trailing child?)
    stage_specs = (
        ('layer1', 6, 8, True),
        ('layer2', 8, 10, True),
        ('layer3', 10, 12, False),
    )
    for attr, start, stop, drop_tail in stage_specs:
        stage = features[start:stop]
        if drop_tail:
            stage[-1][-1] = nn.Identity()
        setattr(wrapped, attr, nn.Sequential(nn.AvgPool2d(2, 2), stage))
    return wrapped
def _make_shufflenet(model):
    """Regroup a model exposing ``conv1``, ``maxpool`` and ``stage2`` .. ``stage4``.

    ``layer0`` is ``conv1`` followed by ``maxpool``; ``layer1`` .. ``layer3``
    alias ``stage2`` .. ``stage4``.
    """
    wrapped = nn.Module()
    wrapped.layer0 = nn.Sequential(model.conv1, model.maxpool)
    for layer_no, stage_name in enumerate(('stage2', 'stage3', 'stage4'), start=1):
        setattr(wrapped, f'layer{layer_no}', getattr(model, stage_name))
    return wrapped
def _make_cspresnet(model):
    """Regroup a ``stem`` / ``stages`` model into ``layer0`` .. ``layer3``.

    ``layer0`` runs ``stem`` then ``stages[0]``; every later ``layerK`` is
    ``stages[K]`` itself.
    """
    wrapped = nn.Module()
    for stage_no in range(4):
        stage = model.stages[stage_no]
        if stage_no == 0:
            # Only the first stage is fused with the stem.
            stage = nn.Sequential(model.stem, stage)
        setattr(wrapped, f'layer{stage_no}', stage)
    return wrapped
def _make_efficientnet(model):
    """Regroup a ``conv_stem`` / ``bn1`` / ``act1`` / ``blocks`` model into four stages.

    ``layer0`` is the stem followed by ``blocks[0:2]``; ``layer1``,
    ``layer2`` and ``layer3`` are ``blocks[2:3]``, ``blocks[3:5]`` and
    ``blocks[5:9]``. Each stage is a flat ``nn.Sequential`` of the selected
    children.
    """
    blocks = model.blocks
    stem = (model.conv_stem, model.bn1, model.act1)

    wrapped = nn.Module()
    wrapped.layer0 = nn.Sequential(*stem, *blocks[0:2])
    for stage_no, (start, stop) in enumerate(((2, 3), (3, 5), (5, 9)), start=1):
        setattr(wrapped, f'layer{stage_no}', nn.Sequential(*blocks[start:stop]))
    return wrapped
def _make_ghostnet(model):
    """Regroup a ``conv_stem`` / ``bn1`` / ``act1`` / ``blocks`` model into four stages.

    ``layer0`` is the stem followed by ``blocks[0:3]``; the later stages are
    ``blocks[3:5]``, ``blocks[5:7]`` and ``blocks[7:-1]`` (the final block is
    left out).
    """
    blocks = model.blocks
    stage_children = (
        [model.conv_stem, model.bn1, model.act1, *blocks[0:3]],
        [*blocks[3:5]],
        [*blocks[5:7]],
        [*blocks[7:-1]],
    )

    wrapped = nn.Module()
    for stage_no, children in enumerate(stage_children):
        setattr(wrapped, f'layer{stage_no}', nn.Sequential(*children))
    return wrapped
def _make_vit(model, name):
    """Wrap a ViT ``model`` via ``_make_vit_b16_backbone`` using size-specific settings.

    The size is chosen by the first of ``'tiny'``, ``'small'``, ``'base'``,
    ``'large'`` that occurs as a substring of ``name``. ``start_index`` is 2
    when ``name`` contains ``'deit'`` and 1 otherwise.

    Raises:
        NotImplementedError: if ``name`` contains none of the size tags.
    """
    # (size tag, features, hooks, vit_features) -- matched in this order.
    presets = (
        ('tiny', [24, 48, 96, 192], [2, 5, 8, 11], 192),
        ('small', [48, 96, 192, 384], [2, 5, 8, 11], 384),
        ('base', [96, 192, 384, 768], [2, 5, 8, 11], 768),
        ('large', [256, 512, 1024, 1024], [5, 11, 17, 23], 1024),
    )
    for size_tag, features, hooks, vit_features in presets:
        if size_tag in name:
            break
    else:
        raise NotImplementedError('Invalid ViT backbone not available')

    start_index = 2 if 'deit' in name else 1
    return _make_vit_b16_backbone(
        model,
        features=features,
        size=[224, 224],
        hooks=hooks,
        vit_features=vit_features,
        start_index=start_index,
    )
def calc_dims(pretrained, is_vit=False):
    """Probe ``pretrained`` with a dummy image to get per-stage output sizes.

    A zero tensor of shape ``(1, 3, 256, 256)`` is pushed through
    ``layer0`` .. ``layer3`` in sequence, or through ``forward_vit`` when
    ``is_vit`` is true, and dimensions 1 and 2 of every stage output are
    recorded.

    The probe runs in eval mode and without autograd so that it cannot
    alter the network: in training mode the forward pass would update
    BatchNorm running statistics of the pretrained weights using the
    all-zero input. Every submodule's previous train/eval flag is restored
    afterwards.

    Args:
        pretrained: module exposing ``layer0`` .. ``layer3`` (or a module
            accepted by ``forward_vit`` when ``is_vit`` is true).
        is_vit: select the ``forward_vit`` code path.

    Returns:
        Tuple ``(channels, res_mult)`` of NumPy arrays with one entry per
        stage: dimension 1 of each output, and dimension 2 divided by the
        probe resolution (256).
    """
    inp_res = 256
    probe = torch.zeros(1, 3, inp_res, inp_res)

    # Record per-module flags so mixed train/eval setups are restored exactly.
    previous_modes = [(module, module.training) for module in pretrained.modules()]
    pretrained.eval()
    try:
        with torch.no_grad():
            if is_vit:
                dims = [out.shape[1:3] for out in forward_vit(pretrained, probe)]
            else:
                dims = []
                activation = probe
                for stage_name in ('layer0', 'layer1', 'layer2', 'layer3'):
                    activation = getattr(pretrained, stage_name)(activation)
                    dims.append(activation.shape[1:3])
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    # split to channels and resolution multiplier
    dims = np.array(dims)
    channels = dims[:, 0]
    res_mult = dims[:, 1] / inp_res
    return channels, res_mult
def _load_torchvision_backbone(backbone):
    """Build the four-stage wrapper for a torchvision backbone.

    Returns ``None`` when ``backbone`` is not one of the torchvision names
    handled here, so the caller can fall through to the timm models.
    """
    # Models split by index into ``model.features``.
    split_points = {
        'vgg11_bn': [7, 14, 21, 28],
        'vgg13_bn': [13, 20, 27, 34],
        'vgg16_bn': [13, 23, 33, 43],
        'vgg19_bn': [13, 26, 39, 52],
        'mobilenet_v2': [4, 7, 14, 18],  # same structure as vgg
        'mnasnet0_5': [9, 10, 12, 14],
        'mnasnet1_0': [9, 10, 12, 14],
    }
    # Models regrouped by a dedicated helper.
    builders = {
        'densenet121': _make_densenet,
        'densenet169': _make_densenet,
        'densenet201': _make_densenet,
        'resnet18': _make_resnet,
        'resnet34': _make_resnet,
        'resnet50': _make_resnet,
        'resnet101': _make_resnet,
        'resnet152': _make_resnet,
        'wide_resnet50_2': _make_resnet,
        'wide_resnet101_2': _make_resnet,
        'shufflenet_v2_x0_5': _make_shufflenet,
    }
    if backbone not in split_points and backbone not in builders:
        return None

    # Keyword form of the former positional ``True``.
    model = getattr(zoomodels, backbone)(pretrained=True)
    if backbone in builders:
        return builders[backbone](model)

    if backbone in ('mnasnet0_5', 'mnasnet1_0'):
        # These models keep their feature stack under ``layers``; alias it
        # so the generic splitter can read ``model.features``.
        model.features = model.layers
    return _feature_splitter(model, split_points[backbone])


def _load_timm_backbone(backbone):
    """Build the four-stage wrapper for a timm backbone.

    Explicitly listed names take precedence; otherwise membership in
    ``REGNETS``, ``EFFNETS`` and ``VITS`` is checked in that order.

    Raises:
        NotImplementedError: if ``backbone`` matches none of the above.
    """
    builders = {
        'ghostnet_100': _make_ghostnet,
        'cspresnet50': _make_cspresnet,
        'fbnetc_100': _make_efficientnet,
        'spnasnet_100': _make_efficientnet,
        'resnet50d': _make_resnet,
        'resnet26': _make_resnet,
        'resnet26d': _make_resnet,
        'seresnet50': _make_resnet,
        'resnetblur50': _make_resnet,
        'resnetrs50': _make_resnet,
        'res2next50': _make_resnet,
        'tf_mixnet_s': _make_efficientnet,
        'tf_mixnet_m': _make_efficientnet,
        'tf_mixnet_l': _make_efficientnet,
        # A second, unreachable 'dm_nfnet_f1' branch (using _make_nfnet) used
        # to follow the first one; the mapping that actually ran is kept.
        'dm_nfnet_f0': _make_cspresnet,
        'dm_nfnet_f1': _make_cspresnet,
        'ese_vovnet19b_dw': _make_cspresnet,
        'ese_vovnet39b': _make_cspresnet,
        'gernet_s': _make_cspresnet,
        'gernet_m': _make_cspresnet,
        'repvgg_a2': _make_cspresnet,
        'repvgg_b0': _make_cspresnet,
        'repvgg_b1': _make_cspresnet,
        'repvgg_b1g4': _make_cspresnet,
        'nfnet_l0': _make_nfnet,
    }

    if backbone in builders:
        build = builders[backbone]
    elif backbone in REGNETS:
        build = _make_regnet
    elif backbone in EFFNETS:
        build = _make_efficientnet
    elif backbone in VITS:
        def build(vit_model):
            return _make_vit(vit_model, backbone)
    else:
        raise NotImplementedError('Wrong model name?')

    model = timm.create_model(backbone, pretrained=True)
    if builders.get(backbone) is _make_resnet:
        # _make_resnet reads ``model.relu``; for these timm models the stem
        # activation is aliased from ``act1``.
        model.relu = model.act1
    return build(model)


def _make_pretrained(backbone, verbose=False):
    """Load a pretrained backbone and regroup it into four feature stages.

    The returned module additionally carries ``CHANNELS`` and ``RES_MULT``
    as computed by ``calc_dims``.

    Args:
        backbone: model name; must be a member of ``ALL_MODELS``.
        verbose: print the detected channel counts and resolution
            multipliers after loading.

    Returns:
        The regrouped module for ``backbone``.

    Raises:
        AssertionError: if ``backbone`` is not in ``ALL_MODELS``.
        NotImplementedError: if no loader is defined for ``backbone``.
    """
    assert backbone in ALL_MODELS

    pretrained = _load_torchvision_backbone(backbone)
    if pretrained is None:
        pretrained = _load_timm_backbone(backbone)

    pretrained.CHANNELS, pretrained.RES_MULT = calc_dims(pretrained, is_vit=backbone in VITS)

    if verbose:
        print(f"Succesfully loaded: {backbone}")
        print(f"Channels: {pretrained.CHANNELS}")
        print(f"Resolution Multiplier: {pretrained.RES_MULT}")
        print(f"Out Res for 256 : {pretrained.RES_MULT*256}")
    return pretrained