# NOTE: the three lines below were non-Python artifacts from the paste
# ("Spaces:" / "Runtime error" x2); commented out so the file parses.
# Spaces:
# Runtime error
# Runtime error
| import functools | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import init | |
| from diffusers import ModelMixin | |
| from diffusers.configuration_utils import (ConfigMixin, | |
| register_to_config) | |
def proj(x, y):
    """Return the projection of row vector ``x`` onto row vector ``y``."""
    numerator = torch.mm(y, x.t()) * y
    denominator = torch.mm(y, y.t())
    return numerator / denominator
def gram_schmidt(x, ys):
    """Subtract from ``x`` its projection onto every vector in ``ys``.

    Leaves ``x`` (approximately) orthogonal to each vector already in ``ys``.
    """
    residual = x
    for basis_vec in ys:
        residual = residual - proj(residual, basis_vec)
    return residual
def power_iteration(W, u_, update=True, eps=1e-12):
    """One power-iteration step estimating the top ``len(u_)`` singular values of W.

    Args:
        W: 2-D weight matrix (already flattened by the caller).
        u_: list of persistent left-singular-vector buffers (each 1 x out);
            refreshed in place when ``update`` is True.
        update: write the new ``u`` estimates back into ``u_`` (training mode).
        eps: numerical floor passed to ``F.normalize``.

    Returns:
        (svs, us, vs): singular-value estimates and the updated left/right
        singular-vector lists.
    """
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # The vector updates run without autograd; only the sv estimate below
        # is computed outside no_grad so gradients flow through W.
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Gram-Schmidt against previously found right vectors keeps the
            # estimated singular directions mutually orthogonal.
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            vs += [v]
            u = torch.matmul(v, W.t())
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            us += [u]
            if update:
                # In-place refresh of the caller's persistent buffer.
                u_[i][:] = u
        # sigma_i ~= u_i W v_i^T, kept in the autograd graph on purpose.
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
    return svs, us, vs
class LinearBlock(nn.Module):
    """Linear layer followed by optional normalization and activation.

    Args:
        in_dim: input feature dimension.
        out_dim: output feature dimension.
        norm: one of 'bn' | 'in' | 'none'.
        act: one of 'relu' | 'lrelu' | 'tanh' | 'none'.
        use_sn: wrap the linear layer in spectral normalization.

    Raises:
        ValueError: if ``norm`` or ``act`` is not a supported option.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        norm='none',
        act='relu',
        use_sn=False
    ):
        super(LinearBlock, self).__init__()
        use_bias = True
        self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)
        if use_sn:
            self.fc = nn.utils.spectral_norm(self.fc)
        # Initialize normalization over the output features.
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            # Raise instead of `assert 0`: asserts are stripped under -O.
            raise ValueError("Unsupported normalization: {}".format(norm))
        # Initialize activation.
        if act == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif act == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif act == 'tanh':
            self.activation = nn.Tanh()
        elif act == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(act))

    def forward(self, x):
        """Apply fc, then the optional norm and activation, in that order."""
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
class MLP(nn.Module):
    """Multi-layer perceptron built from LinearBlocks.

    Layout: nf_in -> nf_mlp (repeated for num_blocks - 1 layers, each with
    the given norm/act) -> nf_out (final layer with no norm/activation).
    Input is flattened to (batch, -1) before the first layer.
    """

    def __init__(
        self,
        nf_in,
        nf_out,
        nf_mlp,
        num_blocks,
        norm,
        act,
        use_sn=False
    ):
        super(MLP, self).__init__()
        hidden = nf_mlp
        layers = [LinearBlock(nf_in, hidden, norm=norm, act=act, use_sn=use_sn)]
        layers.extend(
            LinearBlock(hidden, hidden, norm=norm, act=act, use_sn=use_sn)
            for _ in range(num_blocks - 2)
        )
        # Final projection keeps raw outputs: no norm, no activation.
        layers.append(LinearBlock(hidden, nf_out, norm='none', act='none', use_sn=use_sn))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        flat = x.view(x.size(0), -1)
        return self.model(flat)
class SN(object):
    """Spectral-normalization mixin (BigGAN style).

    Must be mixed into an ``nn.Module`` subclass: ``register_buffer``,
    ``self.weight`` and ``self.training`` are provided by the module side.
    Tracks ``num_svs`` singular-vector estimates as buffers and exposes
    ``W_()`` = weight / sigma_0.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        # Persistent buffers: u{i} holds the i-th left singular vector
        # estimate, sv{i} the i-th singular value estimate.
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    @property
    def u(self):
        """Singular-vector buffers as a list.

        Fix: ``W_`` iterates ``self.u`` and indexes ``self.sv``, so these
        must be properties returning lists — as plain methods, ``self.u``
        would be a bound method and ``power_iteration`` would raise.
        """
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    @property
    def sv(self):
        """Singular-value buffers as a list."""
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    def W_(self):
        """Return the weight divided by its (estimated) top singular value."""
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Refresh the singular vector/value estimates num_itrs times.
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        if self.training:
            with torch.no_grad():
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]
class SNConv2d(nn.Conv2d, SN):
    """Conv2d whose weight is spectrally normalized on every forward pass."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride, padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        # Convolve with the spectrally normalized weight.
        normalized_weight = self.W_()
        return F.conv2d(x, normalized_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward_wo_sn(self, x):
        # Plain convolution with the raw weight (spectral norm bypassed).
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
class SNLinear(nn.Linear, SN):
    """Linear layer whose weight is spectrally normalized on forward."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, self.bias)
class DBlock(nn.Module):
    """Discriminator residual block: (pre-)activation, two convs, optional
    downsample, plus a 1x1-conv shortcut when the shape changes.
    """

    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None,):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # "wide" places the hidden layer at the (larger) output width.
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample
        # Main-path conv layers.
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        # The shortcut needs a 1x1 conv whenever channels change or we downsample.
        self.learnable_sc = bool((in_channels != out_channels) or downsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        """Identity/1x1-conv path; op order depends on preactivation mode."""
        if self.preactivation:
            # Pre-activation ordering: 1x1 conv first, then downsample.
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
            return x
        # Post-activation ordering: downsample first, then 1x1 conv.
        if self.downsample:
            x = self.downsample(x)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return x

    def forward(self, x):
        h = F.relu(x) if self.preactivation else x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)
        return h + self.shortcut(x)
class GBlock(nn.Module):
    """Generator residual block: BN -> act -> (upsample) -> conv, twice,
    with a 1x1-conv shortcut when channels change or we upsample.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=nn.BatchNorm2d, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Main-path conv layers.
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # Shortcut 1x1 conv only when the shape changes.
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers.
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x):
        h = self.activation(self.bn1(x))
        if self.upsample:
            # Upsample both the residual path and the shortcut input.
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv2(self.activation(self.bn2(self.conv1(h))))
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
class GBlock2(nn.Module):
    """Generator residual block without batchnorm; the skip connection can be
    disabled, in which case only the conv path is returned.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, activation=None,
                 upsample=None, skip_connection=True):
        super(GBlock2, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv = which_conv
        self.activation = activation
        self.upsample = upsample
        # Main-path conv layers.
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # Shortcut 1x1 conv only when the shape changes.
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        self.skip_connection = skip_connection

    def forward(self, x):
        h = self.activation(x)
        if self.upsample:
            # Upsample both the residual path and the shortcut input.
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv2(self.activation(self.conv1(h)))
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x if self.skip_connection else h
def style_encoder_textedit_addskip_arch(ch=64, out_channel_multiplier=1, input_nc=3):
    """Per-resolution channel/resolution tables for the style encoder.

    Args:
        ch: base channel width.
        out_channel_multiplier: kept for interface compatibility (unused here).
        input_nc: number of input image channels.

    Returns:
        dict keyed by input resolution (96 / 128 / 256); each entry lists the
        per-stage ``in_channels``, ``out_channels`` and feature-map
        ``resolution`` after each downsampling stage.
    """
    # (Removed an unused local `n = 2` from the original.)
    arch = {}
    arch[96] = {
        'in_channels': [input_nc] + [ch * m for m in [1, 2, 4, 8]],
        'out_channels': [ch * m for m in [1, 2, 4, 8, 16]],
        'resolution': [48, 24, 12, 6, 3],
    }
    arch[128] = {
        'in_channels': [input_nc] + [ch * m for m in [1, 2, 4, 8]],
        'out_channels': [ch * m for m in [1, 2, 4, 8, 16]],
        'resolution': [64, 32, 16, 8, 4],
    }
    arch[256] = {
        'in_channels': [input_nc] + [ch * m for m in [1, 2, 4, 8, 8]],
        'out_channels': [ch * m for m in [1, 2, 4, 8, 8, 16]],
        'resolution': [128, 64, 32, 16, 8, 4],
    }
    return arch
class StyleEncoder(ModelMixin, ConfigMixin):
    """
    This class is to encode the style image to image embedding.
    Downsample scale is 32.
    For example:
        Input: Shape[Batch, 3, 128, 128]
        Output: Shape[Batch, 255, 4, 4]
    """

    def __init__(
        self,
        G_ch=64,
        G_wide=True,
        resolution=128,
        G_kernel_size=3,
        G_attn='64_32_16_8',
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        # NOTE(review): a module instance as a default argument is shared
        # across all StyleEncoder instances — confirm that is intended.
        G_activation=nn.ReLU(inplace=False),
        SN_eps=1e-12,
        output_dim=1,
        G_fp16=False,
        G_init='N02',
        G_param='SN',
        nf_mlp=512,
        nEmbedding=256,
        input_nc=3,
        output_nc=3
    ):
        # G_ch: base channel width; G_wide: forwarded to DBlock's `wide`.
        # resolution: input size (96/128/256), selects the arch table below.
        # G_init: weight-init scheme ('ortho' | 'N02' | 'glorot'/'xavier').
        # Several parameters (G_attn, n_classes, output_dim, nf_mlp,
        # nEmbedding, output_nc) are stored or unused here — presumably kept
        # for config compatibility with the wider project; verify at callers.
        super(StyleEncoder, self).__init__()
        self.ch = G_ch
        self.G_wide = G_wide
        self.resolution = resolution
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        self.activation = G_activation
        self.init = G_init
        self.G_param = G_param
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        # Indices of stages whose outputs are kept as residual features
        # (attribute keeps the original's "featrues" spelling — callers use it).
        if self.resolution == 96:
            self.save_featrues = [0, 1, 2, 3, 4]
        if self.resolution == 128:
            self.save_featrues = [0, 1, 2, 3, 4]
        elif self.resolution == 256:
            self.save_featrues = [0, 1, 2, 3, 4, 5]
        self.out_channel_nultipiler = 1
        # Pick the stage table for this resolution.
        self.arch = style_encoder_textedit_addskip_arch(
            self.ch,
            self.out_channel_nultipiler,
            input_nc
        )[resolution]
        # NOTE(review): if G_param != 'SN', self.which_conv is never set and
        # block construction below fails — confirm 'SN' is the only mode used.
        if self.G_param == 'SN':
            # Spectrally normalized 3x3 conv used by every DBlock.
            self.which_conv = functools.partial(
                SNConv2d,
                kernel_size=3, padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps
            )
            self.which_linear = functools.partial(
                SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps
            )
        # One downsampling DBlock per arch stage; each halves spatial size.
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[DBlock(
                in_channels=self.arch['in_channels'][index],
                out_channels=self.arch['out_channels'][index],
                which_conv=self.which_conv,
                wide=self.G_wide,
                activation=self.activation,
                preactivation=(index > 0),
                downsample=nn.AvgPool2d(2)
            )]]
        self.blocks = nn.ModuleList([
            nn.ModuleList(block) for block in self.blocks
        ])
        # Final IN -> act -> 1x1 conv head, appended as the last blocks entry.
        last_layer = nn.Sequential(
            nn.InstanceNorm2d(self.arch['out_channels'][-1]),
            self.activation,
            nn.Conv2d(
                self.arch['out_channels'][-1],
                self.arch['out_channels'][-1],
                kernel_size=1,
                stride=1
            )
        )
        self.blocks.append(last_layer)
        self.init_weights()

    def init_weights(self):
        """Initialize conv/linear/embedding weights per ``self.init`` and
        count the initialized parameters."""
        self.param_count = 0
        for module in self.modules():
            if (isinstance(module, nn.Conv2d)
                    or isinstance(module, nn.Linear)
                    or isinstance(module, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        print('Param count for D''s initialized parameters: %d' % self.param_count)

    def forward(self, x):
        """Encode a style image.

        Returns:
            style_emd: the final spatial feature map.
            h: globally average-pooled style vector, shape [B, C].
            residual_features: input plus selected intermediate stage outputs.
        """
        h = x
        residual_features = []
        residual_features.append(h)
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
            if index in self.save_featrues[:-1]:
                residual_features.append(h)
        # NOTE(review): self.blocks also contains `last_layer` (appended in
        # __init__), so the loop above already ran its submodules once; the
        # call below applies it a second time. Confirm this double application
        # is intended before changing it.
        h = self.blocks[-1](h)
        style_emd = h
        # Global average pool to a [B, C] style vector.
        h = F.adaptive_avg_pool2d(h, (1, 1))
        h = h.view(h.size(0), -1)
        return style_emd, h, residual_features