Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from feature_networks.vit import forward_vit | |
| from feature_networks.pretrained_builder import _make_pretrained | |
| from feature_networks.constants import NORMALIZED_INCEPTION, NORMALIZED_IMAGENET, NORMALIZED_CLIP, VITS | |
| from pg_modules.blocks import FeatureFusionBlock | |
def get_backbone_normstats(backbone):
    """Return the input normalization statistics registered for ``backbone``.

    The backbone is looked up in NORMALIZED_INCEPTION, NORMALIZED_IMAGENET and
    NORMALIZED_CLIP, in that order; the first group that contains it decides
    the result.

    Args:
        backbone: Backbone identifier to look up.

    Returns:
        A newly created dict with the keys ``'mean'`` and ``'std'``, each
        mapping to a list of three floats. A fresh dict is built on every
        call, so callers may mutate the result safely.

    Raises:
        NotImplementedError: If ``backbone`` is in none of the three groups.
    """
    if backbone in NORMALIZED_INCEPTION:
        return {
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
        }

    if backbone in NORMALIZED_IMAGENET:
        return {
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
        }

    if backbone in NORMALIZED_CLIP:
        return {
            'mean': [0.48145466, 0.4578275, 0.40821073],
            'std': [0.26862954, 0.26130258, 0.27577711],
        }

    # Name the offending backbone so a misconfiguration is easy to diagnose.
    raise NotImplementedError(
        f"No normalization statistics are defined for backbone {backbone!r}"
    )
def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
    """Attach four 1x1 channel-mixing convolutions to ``scratch``.

    Args:
        scratch: Module that receives the new layers as attributes
            ``layer0_ccm`` ... ``layer3_ccm``.
        in_channels: Indexable holding the input channel count of each of
            the four levels.
        cout: Base number of output channels.
        expand: When true, level ``i`` outputs ``cout * 2**i`` channels;
            otherwise every level outputs ``cout`` channels.

    Returns:
        The same ``scratch`` object, with ``scratch.CHANNELS`` set to the
        list of per-level output channel counts.
    """
    if expand:
        widths = [cout, cout*2, cout*4, cout*8]
    else:
        widths = [cout]*4

    # Layers are created in ascending level order, one bias-enabled 1x1 conv each.
    for level in range(4):
        mixer = nn.Conv2d(
            in_channels[level],
            widths[level],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        setattr(scratch, f"layer{level}_ccm", mixer)

    scratch.CHANNELS = widths
    return scratch
def _make_scratch_csm(scratch, in_channels, cout, expand):
    """Attach four cross-scale-mixing fusion blocks to ``scratch``.

    Args:
        scratch: Module that receives the new layers as attributes
            ``layer3_csm`` ... ``layer0_csm``.
        in_channels: Indexable holding the channel count fed to each of the
            four fusion blocks.
        cout: Base channel count used to compute ``scratch.CHANNELS``.
        expand: Forwarded to the blocks of levels 3, 2 and 1, and selects
            which channel layout is recorded.

    Returns:
        The same ``scratch`` object, with ``scratch.CHANNELS`` overwritten.
    """
    # Blocks are created from level 3 down to level 0; only level 3 is
    # flagged as the lowest one.
    scratch.layer3_csm = FeatureFusionBlock(
        in_channels[3], nn.ReLU(False), expand=expand, lowest=True
    )
    for level in (2, 1):
        fusion = FeatureFusionBlock(in_channels[level], nn.ReLU(False), expand=expand)
        setattr(scratch, f"layer{level}_csm", fusion)

    # last refinenet does not expand to save channels in higher dimensions
    scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))

    if expand:
        scratch.CHANNELS = [cout, cout, cout*2, cout*4]
    else:
        scratch.CHANNELS = [cout]*4
    return scratch
def _make_projector(im_res, backbone, cout, proj_type, expand=False):
    """Build the pretrained feature network and its optional projection layers.

    Args:
        im_res: Accepted for interface compatibility; the value is not used
            because the resolution is pinned to 256 below.
        backbone: Identifier handed to ``_make_pretrained``.
        cout: Base channel count for the projection layers.
        proj_type: 0 returns the feature network alone, 1 adds channel
            mixing (CCM), 2 adds channel mixing and scale mixing (CSM).
        expand: Forwarded to the CCM and CSM builders.

    Returns:
        A ``(pretrained, scratch)`` pair; ``scratch`` is ``None`` when
        ``proj_type`` is 0. ``pretrained.CHANNELS`` and
        ``pretrained.RESOLUTIONS`` describe the maps produced by the deepest
        stage that was built.
    """
    assert proj_type in [0, 1, 2], "Invalid projection type"

    ### Build pretrained feature network
    feature_net = _make_pretrained(backbone)

    # Following Projected GAN: the reference resolution is fixed at 256.
    fixed_res = 256
    feature_net.RESOLUTIONS = [fixed_res//stride for stride in (4, 8, 16, 32)]

    if proj_type == 0:
        return feature_net, None

    ### Build CCM
    projection = _make_scratch_ccm(
        nn.Module(), in_channels=feature_net.CHANNELS, cout=cout, expand=expand
    )
    feature_net.CHANNELS = projection.CHANNELS

    if proj_type == 1:
        return feature_net, projection

    ### Build CSM
    projection = _make_scratch_csm(
        projection, in_channels=projection.CHANNELS, cout=cout, expand=expand
    )

    # CSM upsamples x2 so the feature map resolution doubles
    feature_net.RESOLUTIONS = [2*res for res in feature_net.RESOLUTIONS]
    feature_net.CHANNELS = projection.CHANNELS

    return feature_net, projection
class F_Identity(nn.Module):
    """Pass-through module whose forward pass returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x
class F_RandomProj(nn.Module):
    """Pretrained feature network followed by optional projection layers.

    ``proj_type`` selects how far the features are processed:
    0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing.
    """

    def __init__(
        self,
        backbone="tf_efficientnet_lite3",
        im_res=256,
        cout=64,
        expand=True,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        **kwargs,
    ):
        """Build the feature network and the projection layers.

        Args:
            backbone: Identifier of the pretrained backbone.
            im_res: Forwarded to ``_make_projector``.
            cout: Base channel count of the projection layers.
            expand: Forwarded to ``_make_projector``.
            proj_type: Projection depth (0, 1 or 2).
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.proj_type = proj_type
        self.backbone = backbone
        self.cout = cout
        self.expand = expand

        # Stored on the module only; forward() below does not apply them.
        self.normstats = get_backbone_normstats(backbone)

        # build pretrained feature network and random decoder (scratch)
        self.pretrained, self.scratch = _make_projector(
            im_res=im_res,
            backbone=self.backbone,
            cout=self.cout,
            proj_type=self.proj_type,
            expand=self.expand,
        )
        self.CHANNELS = self.pretrained.CHANNELS
        self.RESOLUTIONS = self.pretrained.RESOLUTIONS

    def forward(self, x):
        """Compute the four feature maps for ``x``.

        Returns:
            A dict with the keys ``'0'`` to ``'3'``. Depending on
            ``proj_type`` the values are the raw backbone features (0), the
            channel-mixed features (1) or the scale-mixed features (2).
        """
        # predict feature maps
        if self.backbone in VITS:
            feat0, feat1, feat2, feat3 = forward_vit(self.pretrained, x)
            feats = [feat0, feat1, feat2, feat3]
        else:
            # Each stage consumes the output of the previous one.
            feats = []
            current = x
            for level in range(4):
                current = getattr(self.pretrained, f"layer{level}")(current)
                feats.append(current)

        # start enumerating at the lowest layer (this is where we put the first discriminator)
        if self.proj_type == 0:
            return {str(level): feat for level, feat in enumerate(feats)}

        # cross channel mixing: one layer per level
        mixed = [
            getattr(self.scratch, f"layer{level}_ccm")(feat)
            for level, feat in enumerate(feats)
        ]
        if self.proj_type == 1:
            return {str(level): feat for level, feat in enumerate(mixed)}

        # cross scale mixing, from bottom to top: each block receives the
        # previous block's output together with its own channel-mixed map.
        fused3 = self.scratch.layer3_csm(mixed[3])
        fused2 = self.scratch.layer2_csm(fused3, mixed[2])
        fused1 = self.scratch.layer1_csm(fused2, mixed[1])
        fused0 = self.scratch.layer0_csm(fused1, mixed[0])

        return {
            '0': fused0,
            '1': fused1,
            '2': fused2,
            '3': fused3,
        }