import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class TargetEmbedding(nn.Module):
    """Project a target/mix feature vector into the conditioning dimension.

    Args:
        dim: output embedding width (the network's ``tdim``).
        feature_channel_num: width of the incoming feature vector.
    """

    def __init__(self, dim, feature_channel_num):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(feature_channel_num, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, mix):
        """Return the embedded conditioning vector, shape ``(B, dim)``."""
        return self.linear(mix)


class DownSample(nn.Module):
    """Halve spatial resolution with two parallel strided convs (3x3 + 5x5)."""

    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1, padding_mode='reflect'),
        )
        self.c2 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2, padding_mode='reflect'),
        )

    def forward(self, x, target_emb):
        # target_emb is accepted only for call-site uniformity with ResBlock;
        # it is intentionally unused here.
        return self.c1(x) + self.c2(x)


class UpSample(nn.Module):
    """Double spatial resolution with a 5x5 transposed conv (stride 2)."""

    def __init__(self, in_ch):
        super().__init__()
        self.t = nn.Sequential(
            nn.GELU(),
            nn.GroupNorm(16, in_ch),
            # kernel=5, stride=2, padding=2, output_padding=1 -> exact 2x upsample.
            nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1),
        )

    def forward(self, x, target_emb):
        # target_emb unused; kept for call-site uniformity with ResBlock.
        return self.t(x)


class AttnBlock(nn.Module):
    """Linear (channel-factorized) self-attention over the spatial positions.

    Computes a C x C context from softmax-normalized keys and values, then
    redistributes it over positions via the queries — O(N*C^2) instead of the
    O(N^2*C) of dense attention.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.norm = nn.GroupNorm(16, in_ch)
        self.to_kv = nn.Conv2d(in_ch, in_ch * 2, 1)
        self.out = nn.Sequential(
            nn.GroupNorm(16, in_ch),
            nn.GELU(),
            nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.norm(x)
        q_scale = int(C) ** (-0.5)
        k, v = self.to_kv(x).chunk(2, dim=1)
        k = k.reshape(B, C, H * W)
        v = v.reshape(B, C, H * W)
        # NOTE(review): queries are derived from the *pre-softmax* keys rather
        # than a separate projection (there is no to_q layer). This looks
        # intentional (shared k/q projection) — confirm with the author.
        q = F.softmax(k, dim=-2)     # normalize each channel over positions
        k = F.softmax(k, dim=-1)     # normalize each position over... (dim=-1: over positions per channel)
        q = q * q_scale
        # context: (B, C, C) channel-to-channel summary; out: (B, C, H*W).
        context = torch.einsum('b d n, b e n -> b d e', k, v)
        out = torch.einsum('b d e, b d n -> b e n', context, q)
        out = out.reshape(B, C, H, W)
        out = self.out(out)
        return x + out


class ResBlock(nn.Module):
    """Residual block conditioned on the target embedding, with optional attention.

    Args:
        in_ch / out_ch: input / output channel counts.
        tdim: width of the conditioning vector fed to ``target_proj``.
        attn: append an AttnBlock after the residual sum when True.
    """

    def __init__(self, in_ch, out_ch, tdim, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GELU(),
            nn.GroupNorm(16, in_ch),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1, padding_mode='reflect'),
        )
        self.target_proj = nn.Sequential(
            nn.GELU(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GELU(),
            nn.GroupNorm(16, out_ch),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, padding_mode='reflect'),
        )
        # 1x1 projection only when the channel count changes.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()

    def forward(self, x, target):
        h = self.block1(x)
        # Broadcast the conditioning vector over the spatial dimensions.
        h = h + self.target_proj(target)[:, :, None, None]
        h = self.block2(h)
        # BUG FIX: the original built a fresh nn.LayerNorm([C, H, W]) on every
        # forward call, re-initializing (never training) its affine parameters
        # and allocating a module + .to(device) per step. A fresh LayerNorm's
        # affine params are ones/zeros, so the functional form below is
        # numerically identical and allocation-free.
        h = F.layer_norm(h, h.shape[1:])
        h = h + self.shortcut(x)
        return self.attn(h)


class GCT(nn.Module):
    """Gated Channel Transformation: rescale channels by a learned tanh gate.

    ``tdim`` is accepted for signature compatibility with the other blocks but
    is not used by the transform itself.
    """

    def __init__(self, num_channels, tdim, epsilon=1e-5, mode='l2', after_relu=False):
        super(GCT, self).__init__()
        self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        if self.mode == 'l2':
            # Per-channel L2 energy over the spatial dims, scaled by alpha.
            embedding = (x.pow(2).sum((2, 3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / (
                embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon
            ).pow(0.5)
        elif self.mode == 'l1':
            # abs() is redundant when the input is known non-negative (post-ReLU).
            _x = x if self.after_relu else torch.abs(x)
            embedding = _x.sum((2, 3), keepdim=True) * self.alpha
            norm = self.gamma / (
                torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon
            )
        else:
            # Explicit failure instead of the original's incidental NameError.
            raise ValueError(f'unknown GCT mode: {self.mode!r}')
        # gamma/beta start at zero, so the gate starts at exactly 1 (identity).
        gate = 1. + torch.tanh(embedding * norm + self.beta)
        return x * gate


class Generator(nn.Module):
    """U-Net style generator conditioned on a target/mix feature vector.

    Args:
        num_target: kept for interface compatibility (unused here).
        feature_channel_num: width of the raw conditioning vector.
        ch: base channel count (must be a multiple of 16 for the GroupNorms).
        ch_mult: per-resolution channel multipliers.
        num_res_blocks: ResBlocks per resolution on the encoder side.
        inception: when True, crop odd-sized decoder maps to match skips.
    """

    def __init__(self, num_target, feature_channel_num, ch, ch_mult, num_res_blocks, inception=False):
        super().__init__()
        tdim = ch * 4
        self.target_embedding = TargetEmbedding(tdim, feature_channel_num)
        self.head = nn.Sequential(
            nn.Conv2d(3, ch, kernel_size=5, padding=2, padding_mode='reflect'),
            nn.GELU(),
            nn.GroupNorm(16, ch),
        )
        # Negative padding = crop one pixel off the right and top edges; used
        # to reconcile odd skip-connection sizes when `inception` is set.
        self.crop = nn.ConstantPad2d((0, -1, -1, 0), 0)
        self.inception = inception

        # --- encoder ---
        self.downblocks = nn.ModuleList()
        chs = [ch]                      # skip-connection channel stack
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:   # no downsample at the deepest level
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # --- bottleneck ---
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, attn=True),
            ResBlock(now_ch, now_ch, tdim, attn=False),
        ])

        # --- decoder (mirrors the encoder, consuming the skip stack) ---
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(
                    ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.gct = GCT(num_channels=now_ch, tdim=tdim)
        self.tail = nn.Sequential(
            nn.GroupNorm(16, now_ch),
            nn.GELU(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1, padding_mode='reflect'),
        )

    def weight_init(self):
        """Xavier init for convs, N(0, 0.01) for linears, unit GroupNorms."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.1)
            elif isinstance(m, torch.nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, mix):
        """Map image ``x`` (B,3,H,W) + conditioning ``mix`` to a (0,1) image."""
        targetemb = self.target_embedding(mix)
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, targetemb)
            hs.append(h)
        for layer in self.middleblocks:
            h = layer(h, targetemb)
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                hs_pop = hs.pop()
                # Odd input sizes make the transposed conv overshoot by one
                # pixel; crop before concatenating with the skip tensor.
                if hs_pop.size(2) != h.size(2) and self.inception:
                    h = self.crop(h)
                h = torch.cat([h, hs_pop], dim=1)
            h = layer(h, targetemb)
        h = self.gct(h)
        h = self.tail(h)
        assert len(hs) == 0
        # tanh maps to (-1, 1); rescale to (0, 1).
        return (torch.tanh(h) + 1) / 2