Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as zoomodels | |
| from torch.autograd import Function | |
| import timm | |
| from feature_networks.vit import _make_vit_b16_backbone, forward_vit | |
| from feature_networks.constants import ALL_MODELS, VITS, EFFNETS, REGNETS | |
| from pg_modules.blocks import Interpolate | |
def _feature_splitter(model, idcs):
    """Cut ``model.features`` into four consecutive stages.

    Args:
        model: object whose ``features`` attribute supports slicing
            (the callers in this file pass torchvision models).
        idcs: four increasing cut positions; stage ``k`` covers
            ``features[idcs[k-1]:idcs[k]]`` and stage 0 starts at the
            beginning of ``features``.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``, each an
        ``nn.Sequential`` wrapping the corresponding slice.
    """
    # Lower/upper bound of every stage; ``None`` means "from the start".
    lower_bounds = (None, idcs[0], idcs[1], idcs[2])
    upper_bounds = (idcs[0], idcs[1], idcs[2], idcs[3])

    stages = nn.Module()
    for stage_no, (lo, hi) in enumerate(zip(lower_bounds, upper_bounds)):
        setattr(stages, f"layer{stage_no}", nn.Sequential(model.features[lo:hi]))
    return stages
def _make_resnet(model):
    """Regroup a ResNet-style model into four stages.

    ``layer0`` chains ``conv1``, ``bn1``, ``relu``, ``maxpool`` and the
    model's ``layer1``; ``layer1`` .. ``layer3`` are the model's ``layer2`` ..
    ``layer4`` (the same module objects, not copies).

    Args:
        model: object exposing the attributes named above.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    stem_and_first_stage = (
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
    )

    stages = nn.Module()
    stages.layer0 = nn.Sequential(*stem_and_first_stage)
    # The remaining stages are shifted down by one index.
    for target_no, source_name in enumerate(("layer2", "layer3", "layer4"), start=1):
        setattr(stages, f"layer{target_no}", getattr(model, source_name))
    return stages
def _make_regnet(model):
    """Regroup a model exposing ``stem`` and ``s1`` .. ``s4`` into four stages.

    ``layer0`` chains ``stem`` and ``s1``; ``layer1`` .. ``layer3`` are the
    model's ``s2`` .. ``s4`` (the same module objects).

    Args:
        model: object exposing ``stem``, ``s1``, ``s2``, ``s3`` and ``s4``.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    stages = nn.Module()
    first_stage = [model.stem, model.s1]
    stages.layer0 = nn.Sequential(*first_stage)
    stages.layer1, stages.layer2, stages.layer3 = model.s2, model.s3, model.s4
    return stages
def _make_nfnet(model):
    """Regroup a model exposing ``stem`` and indexable ``stages`` into four stages.

    ``layer0`` chains ``stem`` and ``stages[0]``; ``layer1`` .. ``layer3`` are
    ``stages[1]`` .. ``stages[3]`` (the same module objects).

    Args:
        model: object exposing ``stem`` and ``stages`` (at least four entries).

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    wrapper = nn.Module()
    stem, stage_list = model.stem, model.stages
    wrapper.layer0 = nn.Sequential(stem, stage_list[0])
    for idx in (1, 2, 3):
        setattr(wrapper, f"layer{idx}", stage_list[idx])
    return wrapper
def _make_resnet_v2(model):
    """Regroup a model exposing ``stem`` and indexable ``stages`` into four stages.

    ``layer0`` chains ``stem`` and ``stages[0]``; ``layer1`` .. ``layer3`` are
    ``stages[1]`` .. ``stages[3]`` (the same module objects).

    Args:
        model: object exposing ``stem`` and ``stages`` (at least four entries).

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    stem = model.stem
    stage0, stage1, stage2, stage3 = (model.stages[k] for k in range(4))

    regrouped = nn.Module()
    regrouped.layer0 = nn.Sequential(stem, stage0)
    regrouped.layer1 = stage1
    regrouped.layer2 = stage2
    regrouped.layer3 = stage3
    return regrouped
def _make_densenet(model):
    """Regroup ``model.features`` of a DenseNet-style model into four stages.

    ``layer0`` is ``features[:6]``. Each later stage is an ``nn.AvgPool2d(2, 2)``
    followed by a two-element slice of ``features`` (``[6:8]``, ``[8:10]``,
    ``[10:12]``). For the first two of those slices, the last element of the
    slice's last entry is replaced by ``nn.Identity()``.

    Note that this replacement assigns into the entry object itself, so
    ``model`` is modified in place.

    Args:
        model: object whose ``features`` is a sliceable, indexable container
            with nested indexable entries at positions 7 and 9.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    feats = model.features
    # (start, stop, replace the trailing op of the slice's last entry?)
    later_stages = (
        (6, 8, True),
        (8, 10, True),
        (10, 12, False),
    )

    regrouped = nn.Module()
    regrouped.layer0 = feats[:6]
    for stage_no, (start, stop, strip_tail) in enumerate(later_stages, start=1):
        chunk = feats[start:stop]
        if strip_tail:
            chunk[-1][-1] = nn.Identity()
        # The stage starts with its own 2x2 average pool.
        setattr(regrouped, f"layer{stage_no}", nn.Sequential(nn.AvgPool2d(2, 2), chunk))
    return regrouped
def _make_shufflenet(model):
    """Regroup a model exposing ``conv1``, ``maxpool`` and ``stage2`` .. ``stage4``.

    ``layer0`` chains ``conv1`` and ``maxpool``; ``layer1`` .. ``layer3`` are
    the model's ``stage2`` .. ``stage4`` (the same module objects).

    Args:
        model: object exposing the attributes named above.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    regrouped = nn.Module()
    entry = (model.conv1, model.maxpool)
    regrouped.layer0 = nn.Sequential(*entry)
    for target_no, source_name in enumerate(("stage2", "stage3", "stage4"), start=1):
        setattr(regrouped, f"layer{target_no}", getattr(model, source_name))
    return regrouped
def _make_cspresnet(model):
    """Regroup a model exposing ``stem`` and indexable ``stages`` into four stages.

    ``layer0`` chains ``stem`` and ``stages[0]``; ``layer1`` .. ``layer3`` are
    ``stages[1]`` .. ``stages[3]`` (the same module objects).

    Args:
        model: object exposing ``stem`` and ``stages`` (at least four entries).

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``.
    """
    out = nn.Module()
    for stage_no in range(4):
        if stage_no == 0:
            # The stem is folded into the first stage.
            part = nn.Sequential(model.stem, model.stages[0])
        else:
            part = model.stages[stage_no]
        setattr(out, f"layer{stage_no}", part)
    return out
def _make_efficientnet(model):
    """Regroup a model exposing ``conv_stem``, ``bn1``, ``act1`` and ``blocks``.

    ``layer0`` chains the three stem modules with ``blocks[0:2]``; ``layer1``,
    ``layer2`` and ``layer3`` chain ``blocks[2:3]``, ``blocks[3:5]`` and
    ``blocks[5:9]`` respectively.

    Args:
        model: object exposing the attributes named above; ``blocks`` must be
            sliceable and iterable.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``, each an
        ``nn.Sequential``.
    """
    stem = [model.conv_stem, model.bn1, model.act1]
    block_ranges = ((2, 3), (3, 5), (5, 9))

    regrouped = nn.Module()
    regrouped.layer0 = nn.Sequential(*stem, *model.blocks[0:2])
    for stage_no, (start, stop) in enumerate(block_ranges, start=1):
        setattr(regrouped, f"layer{stage_no}", nn.Sequential(*model.blocks[start:stop]))
    return regrouped
def _make_ghostnet(model):
    """Regroup a model exposing ``conv_stem``, ``bn1``, ``act1`` and ``blocks``.

    ``layer0`` chains the three stem modules with ``blocks[0:3]``; ``layer1``,
    ``layer2`` and ``layer3`` chain ``blocks[3:5]``, ``blocks[5:7]`` and
    ``blocks[7:-1]`` respectively, so the final block is left out.

    Args:
        model: object exposing the attributes named above; ``blocks`` must be
            sliceable and iterable.

    Returns:
        A bare ``nn.Module`` with attributes ``layer0`` .. ``layer3``, each an
        ``nn.Sequential``.
    """
    regrouped = nn.Module()

    head = [model.conv_stem, model.bn1, model.act1]
    head.extend(model.blocks[0:3])
    regrouped.layer0 = nn.Sequential(*head)

    # Later stages, keyed by attribute name; the last block is excluded.
    later = {
        "layer1": slice(3, 5),
        "layer2": slice(5, 7),
        "layer3": slice(7, -1),
    }
    for attr_name, block_range in later.items():
        setattr(regrouped, attr_name, nn.Sequential(*model.blocks[block_range]))
    return regrouped
def _make_vit(model, name):
    """Build a ViT backbone via ``_make_vit_b16_backbone`` using size-specific settings.

    The size is chosen by the first of ``'tiny'``, ``'small'``, ``'base'``,
    ``'large'`` that occurs as a substring of ``name``. ``start_index`` is 2
    when ``name`` contains ``'deit'`` and 1 otherwise.

    Args:
        model: the model forwarded to ``_make_vit_b16_backbone``.
        name: backbone name used to select the settings.

    Returns:
        Whatever ``_make_vit_b16_backbone`` returns.

    Raises:
        NotImplementedError: if ``name`` contains none of the size tags.
    """
    # (size tag, features, hooks, vit_features); tested in this order.
    size_table = (
        ('tiny', [24, 48, 96, 192], [2, 5, 8, 11], 192),
        ('small', [48, 96, 192, 384], [2, 5, 8, 11], 384),
        ('base', [96, 192, 384, 768], [2, 5, 8, 11], 768),
        ('large', [256, 512, 1024, 1024], [5, 11, 17, 23], 1024),
    )

    for size_tag, features, hooks, vit_features in size_table:
        if size_tag in name:
            break
    else:
        raise NotImplementedError('Invalid ViT backbone not available')

    start_index = 2 if 'deit' in name else 1
    return _make_vit_b16_backbone(
        model,
        features=features,
        size=[224, 224],
        hooks=hooks,
        vit_features=vit_features,
        start_index=start_index,
    )
def calc_dims(pretrained, is_vit=False, inp_res=256):
    """Probe a four-stage backbone with a dummy image and report its output dims.

    A zero tensor of shape ``(1, 3, inp_res, inp_res)`` is passed through
    ``pretrained.layer0`` .. ``pretrained.layer3`` in sequence, or through
    ``forward_vit(pretrained, ...)`` when ``is_vit`` is true. Dimensions 1 and
    2 of every stage output are recorded.

    Args:
        pretrained: module exposing callable ``layer0`` .. ``layer3``
            (non-ViT path) or accepted by ``forward_vit`` (ViT path).
        is_vit: select the ``forward_vit`` path.
        inp_res: side length of the square probe image. Defaults to 256, the
            value that used to be hard-coded.

    Returns:
        ``(channels, res_mult)``: two NumPy arrays with one entry per stage.
        ``channels`` is dimension 1 of each output; ``res_mult`` is
        dimension 2 of each output divided by ``inp_res``.
    """
    probe = torch.zeros(1, 3, inp_res, inp_res)

    # Only the output shapes are needed, so skip autograd bookkeeping.
    # NOTE(review): the module is run in whatever train/eval mode it is
    # currently in; layers that track statistics in training mode would see
    # this probe - confirm whether callers should switch to eval() first.
    with torch.no_grad():
        if is_vit:
            outputs = forward_vit(pretrained, probe)
        else:
            outputs = []
            activation = probe
            for stage_name in ("layer0", "layer1", "layer2", "layer3"):
                activation = getattr(pretrained, stage_name)(activation)
                outputs.append(activation)

    # split to channels and resolution multiplier
    dims = np.array([tuple(out.shape[1:3]) for out in outputs])
    channels = dims[:, 0]
    res_mult = dims[:, 1] / inp_res
    return channels, res_mult
# torchvision models whose ``features`` container is cut at fixed indices.
_TV_FEATURE_SPLITS = {
    'vgg11_bn': (7, 14, 21, 28),
    'vgg13_bn': (13, 20, 27, 34),
    'vgg16_bn': (13, 23, 33, 43),
    'vgg19_bn': (13, 26, 39, 52),
    'mobilenet_v2': (4, 7, 14, 18),  # same structure as vgg
}

# torchvision models that keep their stack under ``layers``; it is aliased to
# ``features`` before being cut at the given indices.
_TV_LAYERS_SPLITS = {
    'mnasnet0_5': (9, 10, 12, 14),
    'mnasnet1_0': (9, 10, 12, 14),
}

# torchvision models handled by ``_make_densenet``.
_TV_DENSENETS = frozenset((
    'densenet121', 'densenet169', 'densenet201',
))

# torchvision models handled by ``_make_resnet``.
_TV_RESNETS = frozenset((
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'wide_resnet50_2', 'wide_resnet101_2',
))

# timm models handled by ``_make_resnet`` after aliasing ``act1`` to ``relu``.
_TIMM_RESNETS = frozenset((
    'resnet50d', 'resnet26', 'resnet26d', 'seresnet50', 'resnetblur50',
    'resnetrs50', 'res2next50',
))

# timm models (outside EFFNETS) handled by ``_make_efficientnet``.
_TIMM_EFFICIENTNET_LIKE = frozenset((
    'fbnetc_100', 'spnasnet_100', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l',
))

# timm models handled by ``_make_cspresnet`` (``stem`` + ``stages``).
# 'dm_nfnet_f1' belongs here: the original if/elif chain listed it twice and
# only this first occurrence was ever reachable.
_TIMM_STEM_STAGES = frozenset((
    'cspresnet50', 'dm_nfnet_f0', 'dm_nfnet_f1', 'ese_vovnet19b_dw',
    'ese_vovnet39b', 'gernet_s', 'gernet_m', 'repvgg_a2', 'repvgg_b0',
    'repvgg_b1', 'repvgg_b1g4',
))


def _load_torchvision_model(name):
    """Instantiate the torchvision model called ``name`` with pretrained weights."""
    return getattr(zoomodels, name)(pretrained=True)


def _load_timm_model(name):
    """Instantiate the timm model called ``name`` with pretrained weights."""
    return timm.create_model(name, pretrained=True)


def _make_pretrained(backbone, verbose=False):
    """Load a pretrained backbone and regroup it into four feature stages.

    The model is created through torchvision or timm depending on
    ``backbone`` and handed to the matching ``_make_*`` / ``_feature_splitter``
    helper. Explicitly listed names take precedence over membership in
    ``REGNETS``, ``EFFNETS`` and ``VITS``, which are tested in that order.
    The result is then probed with ``calc_dims`` and annotated with
    ``CHANNELS`` and ``RES_MULT``.

    Args:
        backbone: model name; must be a member of ``ALL_MODELS``.
        verbose: when true, print the loaded name, channels and resolution
            multipliers.

    Returns:
        The regrouped module with ``CHANNELS`` and ``RES_MULT`` attributes set.

    Raises:
        AssertionError: if ``backbone`` is not in ``ALL_MODELS``.
        NotImplementedError: if ``backbone`` matches no known family.
    """
    assert backbone in ALL_MODELS, f"Unknown backbone: {backbone!r}"

    # --- torchvision families -------------------------------------------
    if backbone in _TV_FEATURE_SPLITS:
        model = _load_torchvision_model(backbone)
        pretrained = _feature_splitter(model, _TV_FEATURE_SPLITS[backbone])
    elif backbone in _TV_LAYERS_SPLITS:
        model = _load_torchvision_model(backbone)
        # Expose the stack under the attribute name _feature_splitter reads.
        model.features = model.layers
        pretrained = _feature_splitter(model, _TV_LAYERS_SPLITS[backbone])
    elif backbone in _TV_DENSENETS:
        pretrained = _make_densenet(_load_torchvision_model(backbone))
    elif backbone in _TV_RESNETS:
        pretrained = _make_resnet(_load_torchvision_model(backbone))
    elif backbone == 'shufflenet_v2_x0_5':
        pretrained = _make_shufflenet(_load_torchvision_model(backbone))

    # --- timm families, explicit names ----------------------------------
    elif backbone == 'ghostnet_100':
        pretrained = _make_ghostnet(_load_timm_model(backbone))
    elif backbone in _TIMM_EFFICIENTNET_LIKE:
        pretrained = _make_efficientnet(_load_timm_model(backbone))
    elif backbone in _TIMM_RESNETS:
        model = _load_timm_model(backbone)
        # _make_resnet reads ``relu``; these models provide ``act1`` instead.
        model.relu = model.act1
        pretrained = _make_resnet(model)
    elif backbone in _TIMM_STEM_STAGES:
        pretrained = _make_cspresnet(_load_timm_model(backbone))
    elif backbone == 'nfnet_l0':
        pretrained = _make_nfnet(_load_timm_model(backbone))

    # --- timm families, by membership in the imported name lists --------
    elif backbone in REGNETS:
        pretrained = _make_regnet(_load_timm_model(backbone))
    elif backbone in EFFNETS:
        pretrained = _make_efficientnet(_load_timm_model(backbone))
    elif backbone in VITS:
        pretrained = _make_vit(_load_timm_model(backbone), backbone)
    else:
        raise NotImplementedError('Wrong model name?')

    pretrained.CHANNELS, pretrained.RES_MULT = calc_dims(pretrained, is_vit=backbone in VITS)

    if verbose:
        print(f"Succesfully loaded: {backbone}")
        print(f"Channels: {pretrained.CHANNELS}")
        print(f"Resolution Multiplier: {pretrained.RES_MULT}")
        print(f"Out Res for 256 : {pretrained.RES_MULT*256}")
    return pretrained