lucky0146 committed
Commit 56f7c71 · verified · 1 parent: 0c96f3e

Create codeformer_arch.py

Files changed (1)
  1. codeformer_arch.py +280 -0
codeformer_arch.py ADDED
@@ -0,0 +1,280 @@
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

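# The wildcard import above is what brings VQAutoEncoder and ResBlock, both
# used below, into scope from basicsr's VQGAN architecture module.
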
def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be a 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


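# Illustrative shapes for calc_mean_std: for feat of shape (2, 64, 16, 16) it
# returns per-channel statistics, each of shape (2, 64, 1, 1), which broadcast
# against the full feature map:
#   mean, std = calc_mean_std(torch.randn(2, 64, 16, 16))
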
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have color and illumination similar
    to those of the degraded features.

    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


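# This is AdaIN (Huang & Belongie, 2017): the content feature is normalized
# per channel and re-scaled with the style feature's statistics,
#     out = std(style) * (content - mean(content)) / std(content) + mean(style)
# In this file it lets the code-decoded feature inherit the global color and
# illumination statistics of the low-quality input feature (see forward below).
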
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to
    the one used in the "Attention Is All You Need" paper, generalized to work
    on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


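# Usage sketch (illustrative): the module returns a (B, 2 * num_pos_feats, H, W)
# embedding for a (B, C, H, W) input, so num_pos_feats is typically set to half
# of the channel dimension the embedding will be added to:
#   pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(
#       torch.randn(1, 256, 16, 16))  # -> (1, 256, 16, 16)
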
def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of the feedforward model (MLP)
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # self-attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt


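# Note: this is a pre-norm ("norm first") transformer encoder layer, i.e.
# LayerNorm runs before attention and the MLP while the residual path stays
# unnormalized, which tends to be more stable for deeper stacks.
# nn.MultiheadAttention is used in its default sequence-first layout, so a
# minimal call (illustrative) looks like:
#   layer = TransformerSALayer(embed_dim=512, nhead=8)
#   out = layer(torch.randn(256, 4, 512))  # (seq_len=h*w, batch, embed_dim)
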
class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out


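# The scale/shift pair follows SFT-style (spatial feature transform)
# conditioning: the concatenated encoder/decoder features predict per-pixel
# affine parameters that modulate the decoder feature, and w scales the whole
# residual. In CodeFormer this w is the fidelity weight: w=0 keeps the purely
# code-based reconstruction, larger w blends in more of the degraded input.
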
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize', 'generator'], vqgan_path=None):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits prediction head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after the second residual block for > 16, before the attn layer for == 16
        self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
        # after the first residual block for > 16, before the attn layer for == 16
        self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

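    # The two index maps above are tied to the fixed VQAutoEncoder layout passed
    # to super().__init__: each key is a spatial resolution (as a string) and
    # each value is the index of the matching block inside the encoder/generator
    # block lists, letting forward() tap and fuse features at the same scale on
    # both sides.
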
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

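    # Note: _init_weights is not invoked anywhere in this file; if fresh
    # transformer weights are desired, a caller could apply it manually, e.g.
    # net.apply(net._init_weights) (illustrative; loading pretrained weights
    # would overwrite this anyway).
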
    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only:  # for training stage II
            # logits do not need a softmax before the cross-entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0, 2, 1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list:  # fuse after the i-th block
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits do not need a softmax before the cross-entropy loss
        return out, logits, lq_feat
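

# ################## Usage sketch ####################
# Minimal, illustrative example; assumes basicsr is installed and a 512x512
# input; the checkpoint path is hypothetical.
if __name__ == '__main__':
    net = CodeFormer(dim_embd=512, n_head=8, n_layers=9, codebook_size=1024)
    # state = torch.load('codeformer.pth', map_location='cpu')['params_ema']
    # net.load_state_dict(state)
    net.eval()
    with torch.no_grad():
        face = torch.randn(1, 3, 512, 512)  # a normalized 512x512 face crop
        out, logits, lq_feat = net(face, w=0.5, adain=True)
    # out:     (1, 3, 512, 512) restored image
    # logits:  (1, 256, 1024) code logits over the 16x16 latent grid
    # lq_feat: (1, 256, 16, 16) encoder feature at code resolution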