lucky0146 committed on
Commit cf32411 · verified · 1 Parent(s): 862663e

Update codeformer_arch.py

Files changed (1)
  1. codeformer_arch.py +44 -88
codeformer_arch.py CHANGED
@@ -5,18 +5,12 @@ from torch import nn, Tensor
  import torch.nn.functional as F
  from typing import Optional, List
  
- from vqgan_arch import *
+ from vqgan_arch import * # Custom import from root
  from basicsr.utils import get_root_logger
  from basicsr.utils.registry import ARCH_REGISTRY
  
  def calc_mean_std(feat, eps=1e-5):
-     """Calculate mean and std for adaptive_instance_normalization.
- 
-     Args:
-         feat (Tensor): 4D tensor.
-         eps (float): A small value added to the variance to avoid
-             divide-by-zero. Default: 1e-5.
-     """
+     """Calculate mean and std for adaptive_instance_normalization."""
      size = feat.size()
      assert len(size) == 4, 'The input feature should be 4D tensor.'
      b, c = size[:2]
@@ -25,30 +19,15 @@ def calc_mean_std(feat, eps=1e-5):
      feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
      return feat_mean, feat_std
  
- 
  def adaptive_instance_normalization(content_feat, style_feat):
-     """Adaptive instance normalization.
- 
-     Adjust the reference features to have the similar color and illuminations
-     as those in the degradate features.
- 
-     Args:
-         content_feat (Tensor): The reference feature.
-         style_feat (Tensor): The degradate features.
-     """
+     """Adaptive instance normalization."""
      size = content_feat.size()
      style_mean, style_std = calc_mean_std(style_feat)
      content_mean, content_std = calc_mean_std(content_feat)
      normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
      return normalized_feat * style_std.expand(size) + style_mean.expand(size)
  
- 
  class PositionEmbeddingSine(nn.Module):
-     """
-     This is a more standard version of the position embedding, very similar to the one
-     used by the Attention is all you need paper, generalized to work on images.
-     """
- 
      def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
          super().__init__()
          self.num_pos_feats = num_pos_feats
@@ -93,14 +72,12 @@ def _get_activation_fn(activation):
          return F.gelu
      if activation == "glu":
          return F.glu
-     raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
- 
+     raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
  
  class TransformerSALayer(nn.Module):
      def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
          super().__init__()
          self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
-         # Implementation of Feedforward model - MLP
          self.linear1 = nn.Linear(embed_dim, dim_mlp)
          self.dropout = nn.Dropout(dropout)
          self.linear2 = nn.Linear(dim_mlp, embed_dim)
@@ -119,15 +96,11 @@ class TransformerSALayer(nn.Module):
                  tgt_mask: Optional[Tensor] = None,
                  tgt_key_padding_mask: Optional[Tensor] = None,
                  query_pos: Optional[Tensor] = None):
- 
-         # self attention
          tgt2 = self.norm1(tgt)
          q = k = self.with_pos_embed(tgt2, query_pos)
          tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                                key_padding_mask=tgt_key_padding_mask)[0]
          tgt = tgt + self.dropout1(tgt2)
- 
-         # ffn
          tgt2 = self.norm2(tgt)
          tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
          tgt = tgt + self.dropout2(tgt2)
@@ -136,17 +109,21 @@ class TransformerSALayer(nn.Module):
  class Fuse_sft_block(nn.Module):
      def __init__(self, in_ch, out_ch):
          super().__init__()
-         self.encode_enc = ResBlock(2*in_ch, out_ch)
- 
+         self.encode_enc = nn.Sequential(
+             nn.Conv2d(2 * in_ch, out_ch, 3, 1, 1, bias=False),
+             nn.BatchNorm2d(out_ch),
+             nn.LeakyReLU(0.2, inplace=True)
+         )
          self.scale = nn.Sequential(
-             nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
-             nn.LeakyReLU(0.2, True),
-             nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
- 
+             nn.Conv2d(in_ch, out_ch, 3, 1, 1),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(out_ch, out_ch, 3, 1, 1)
+         )
          self.shift = nn.Sequential(
-             nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
-             nn.LeakyReLU(0.2, True),
-             nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+             nn.Conv2d(in_ch, out_ch, 3, 1, 1),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(out_ch, out_ch, 3, 1, 1)
+         )
  
      def forward(self, enc_feat, dec_feat, w=1):
          enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
@@ -156,18 +133,18 @@ class Fuse_sft_block(nn.Module):
          out = dec_feat + residual
          return out
  
- 
  @ARCH_REGISTRY.register()
  class CodeFormer(VQAutoEncoder):
      def __init__(self, dim_embd=512, n_head=8, n_layers=9,
-                 codebook_size=1024, latent_size=256,
-                 connect_list=['32', '64', '128', '256'],
-                 fix_modules=['quantize','generator'], vqgan_path=None):
-         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+                 codebook_size=1024, latent_size=256,
+                 connect_list=['32', '64', '128', '256'],
+                 fix_modules=['quantize', 'generator'], vqgan_path=None):
+         # Adjust down_factor to ensure it works with channel scaling
+         down_factor = [1, 2, 2, 4, 4, 8] # Ensure this matches the number of steps
+         super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
  
          if vqgan_path is not None:
-             self.load_state_dict(
-                 torch.load(vqgan_path, map_location='cpu')['params_ema'])
+             self.load_state_dict(torch.load(vqgan_path, map_location='cpu')['params_ema'])
  
          if fix_modules is not None:
              for module in fix_modules:
@@ -177,19 +154,18 @@ class CodeFormer(VQAutoEncoder):
          self.connect_list = connect_list
          self.n_layers = n_layers
          self.dim_embd = dim_embd
-         self.dim_mlp = dim_embd*2
+         self.dim_mlp = dim_embd * 2
  
          self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
          self.feat_emb = nn.Linear(256, self.dim_embd)
  
-         # transformer
          self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
-                                          for _ in range(self.n_layers)])
+                                          for _ in range(self.n_layers)])
  
-         # logits_predict head
          self.idx_pred_layer = nn.Sequential(
              nn.LayerNorm(dim_embd),
-             nn.Linear(dim_embd, codebook_size, bias=False))
+             nn.Linear(dim_embd, codebook_size, bias=False)
+         )
  
          self.channels = {
              '16': 512,
@@ -200,12 +176,9 @@ class CodeFormer(VQAutoEncoder):
              '512': 64,
          }
  
-         # after second residual block for > 16, before attn layer for ==16
-         self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
-         # after first residual block for > 16, before attn layer for ==16
-         self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+         self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
+         self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}
  
-         # fuse_convs_dict
          self.fuse_convs_dict = nn.ModuleDict()
          for f_size in self.connect_list:
              in_ch = self.channels[f_size]
@@ -221,60 +194,43 @@ class CodeFormer(VQAutoEncoder):
                  module.weight.data.fill_(1.0)
  
      def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
-         # ################### Encoder #####################
          enc_feat_dict = {}
          out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
-         for i, block in enumerate(self.encoder.blocks):
-             x = block(x)
+         for i, block in enumerate(self.encoder):
+             x = block(x)
              if i in out_list:
                  enc_feat_dict[str(x.shape[-1])] = x.clone()
  
          lq_feat = x
-         # ################# Transformer ###################
-         # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
-         pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
-         # BCHW -> BC(HW) -> (HW)BC
-         feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+         pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
+         feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
          query_emb = feat_emb
-         # Transformer encoder
+ 
          for layer in self.ft_layers:
              query_emb = layer(query_emb, query_pos=pos_emb)
  
-         # output logits
-         logits = self.idx_pred_layer(query_emb) # (hw)bn
-         logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+         logits = self.idx_pred_layer(query_emb)
+         logits = logits.permute(1, 0, 2)
  
-         if code_only: # for training stage II
-             # logits doesn't need softmax before cross_entropy loss
+         if code_only:
              return logits, lq_feat
  
-         # ################# Quantization ###################
-         # if self.training:
-         #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
-         #     # b(hw)c -> bc(hw) -> bchw
-         #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
-         # ------------
          soft_one_hot = F.softmax(logits, dim=2)
          _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
-         quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
-         # preserve gradients
-         # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+         quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
  
          if detach_16:
-             quant_feat = quant_feat.detach() # for training stage III
+             quant_feat = quant_feat.detach()
          if adain:
              quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
  
-         # ################## Generator ####################
          x = quant_feat
          fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
- 
-         for i, block in enumerate(self.generator.blocks):
-             x = block(x)
-             if i in fuse_list: # fuse after i-th block
+         for i, block in enumerate(self.decoder):
+             x = block(x)
+             if i in fuse_list:
                  f_size = str(x.shape[-1])
-                 if w>0:
+                 if w > 0:
                      x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
          out = x
-         # logits doesn't need softmax before cross_entropy loss
          return out, logits, lq_feat
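
For reference, a minimal sketch of how the updated class might be exercised after this commit. It relies only on the constructor signature and forward return values visible in the diff above; the import path, the availability of VQAutoEncoder via vqgan_arch.py, the 512x512 input size, and the placeholder checkpoint argument are assumptions, not part of the commit.

# Sketch: smoke-test the updated CodeFormer (assumptions noted above).
import torch
from codeformer_arch import CodeFormer  # assumes this file and vqgan_arch.py are importable

model = CodeFormer(
    dim_embd=512, n_head=8, n_layers=9,
    codebook_size=1024, latent_size=256,
    connect_list=['32', '64', '128', '256'],
    fix_modules=['quantize', 'generator'],
    vqgan_path=None,  # or a path to a VQGAN checkpoint storing weights under 'params_ema'
).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)           # assumed 512x512 RGB face crop
    out, logits, lq_feat = model(x, w=0.5)    # w blends encoder features into the decoder via Fuse_sft_block
    print(out.shape, logits.shape, lq_feat.shape)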
 