import torch
import torch.nn as nn

from feature_networks.vit import forward_vit
from feature_networks.pretrained_builder import _make_pretrained
from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS
from pg_modules.blocks import FeatureFusionBlock


def get_backbone_normstats(backbone):
    """Return the input normalization stats (mean/std) expected by the given backbone."""
    if backbone in NORMALIZED_INCEPTION:
        return {
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
        }

    elif backbone in NORMALIZED_IMAGENET:
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
        }

    elif backbone in NORMALIZED_CLIP:
        return {
            'mean': [0.48145466, 0.4578275, 0.40821073],
            'std': [0.26862954, 0.26130258, 0.27577711],
        }

    else:
        raise NotImplementedError(f"No normalization stats registered for backbone '{backbone}'")
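

# Added note (not part of the original module): these stats are typically applied to the
# input images before they reach the backbone. A minimal sketch, assuming a batch `x` of
# RGB images already scaled to [0, 1]:
#
#   stats = get_backbone_normstats('tf_efficientnet_lite3')
#   mean = torch.tensor(stats['mean']).view(1, 3, 1, 1).to(x.device)
#   std = torch.tensor(stats['std']).view(1, 3, 1, 1).to(x.device)
#   x = (x - mean) / std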


def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
    """Attach 1x1 convs (cross-channel mixing) that project each feature level to a common width."""
    # shapes
    out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4

    scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
    scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)

    scratch.CHANNELS = out_channels

    return scratch


def _make_scratch_csm(scratch, in_channels, cout, expand):
    """Attach FeatureFusionBlocks (cross-scale mixing) that fuse features from the coarsest to the finest level."""
    scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
    scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
    scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
    scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))

    # last refinenet does not expand to save channels in higher dimensions
    scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4

    return scratch


def _make_projector(im_res, backbone, cout, proj_type, expand=False):
    assert proj_type in [0, 1, 2], "Invalid projection type"

    ### Build pretrained feature network
    pretrained = _make_pretrained(backbone)

    # Following Projected GAN, the feature network always operates at 256x256
    im_res = 256
    pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]

    if proj_type == 0: return pretrained, None

    ### Build CCM
    scratch = nn.Module()
    scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
    pretrained.CHANNELS = scratch.CHANNELS

    if proj_type == 1: return pretrained, scratch

    ### Build CSM
    scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)

    # CSM upsamples x2, so the feature map resolutions double
    pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
    pretrained.CHANNELS = scratch.CHANNELS

    return pretrained, scratch
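

# Worked example (added note, not in the original file), following the defaults of
# F_RandomProj below (cout=64, expand=True, proj_type=2, im_res forced to 256):
#   after the CCM, the per-level channels become [64, 128, 256, 512];
#   after the CSM, CHANNELS is [64, 64, 128, 256] and RESOLUTIONS doubles from
#   [64, 32, 16, 8] to [128, 64, 32, 16].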


class F_Identity(nn.Module):
    def forward(self, x):
        return x


class F_RandomProj(nn.Module):
    """Pretrained feature network followed by randomly initialized cross-channel/cross-scale mixing (Projected GAN)."""

    def __init__(
        self,
        backbone="tf_efficientnet_lite3",
        im_res=256,
        cout=64,
        expand=True,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,
    ):
        super().__init__()
        self.proj_type = proj_type
        self.backbone = backbone
        self.cout = cout
        self.expand = expand
        self.normstats = get_backbone_normstats(backbone)

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(im_res=im_res, backbone=self.backbone, cout=self.cout,
                                                        proj_type=self.proj_type, expand=self.expand)
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS
    def forward(self, x):
        # predict feature maps
        if self.backbone in VITS:
            out0, out1, out2, out3 = forward_vit(self.pretrained, x)
        else:
            out0 = self.pretrained.layer0(x)
            out1 = self.pretrained.layer1(out0)
            out2 = self.pretrained.layer2(out1)
            out3 = self.pretrained.layer3(out2)

        # start enumerating at the lowest layer (this is where we put the first discriminator)
        out = {
            '0': out0,
            '1': out1,
            '2': out2,
            '3': out3,
        }

        if self.proj_type == 0: return out

        # cross channel mixing (CCM): 1x1 convs to a common channel width
        out0_channel_mixed = self.scratch.layer0_ccm(out['0'])
        out1_channel_mixed = self.scratch.layer1_ccm(out['1'])
        out2_channel_mixed = self.scratch.layer2_ccm(out['2'])
        out3_channel_mixed = self.scratch.layer3_ccm(out['3'])

        out = {
            '0': out0_channel_mixed,
            '1': out1_channel_mixed,
            '2': out2_channel_mixed,
            '3': out3_channel_mixed,
        }

        if self.proj_type == 1: return out

        # cross scale mixing (CSM): fuse from bottom to top (coarsest to finest)
        out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
        out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
        out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
        out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)

        out = {
            '0': out0_scale_mixed,
            '1': out1_scale_mixed,
            '2': out2_scale_mixed,
            '3': out3_scale_mixed,
        }

        return out
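

# Minimal usage sketch (added for illustration, not part of the original module). It assumes
# the weights for the default 'tf_efficientnet_lite3' backbone can be resolved by
# _make_pretrained and that inputs are 256x256 RGB tensors, matching the defaults above.
if __name__ == "__main__":
    projector = F_RandomProj(backbone="tf_efficientnet_lite3", im_res=256, cout=64,
                             expand=True, proj_type=2)
    projector.eval()
    with torch.no_grad():
        feats = projector(torch.randn(2, 3, 256, 256))
    for key, feat in feats.items():
        # four feature maps keyed '0'..'3', from the finest ('0') to the coarsest ('3') level
        print(key, tuple(feat.shape))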