lucky0146 committed
Commit c14ad09 · verified · 1 parent: 6b310a6

Update codeformer_arch.py

Files changed (1)
codeformer_arch.py: +4 -4
codeformer_arch.py CHANGED
@@ -138,7 +138,7 @@ class CodeFormer(VQAutoEncoder):
     def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                  codebook_size=1024, latent_size=256,
                  connect_list=['32', '64', '128', '256'],
-                 fix_modules=['quantize', 'decoder'], vqgan_path=None): # Changed 'generator' to 'decoder'
+                 fix_modules=['quantize', 'decoder'], vqgan_path=None):
         down_factor = [1, 2, 2, 4, 4, 8]
         super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
 
@@ -158,8 +158,8 @@ class CodeFormer(VQAutoEncoder):
         self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
         self.feat_emb = nn.Linear(256, self.dim_embd)
 
-        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=nhead, dim_mlp=self.dim_mlp, dropout=0.0)
-                                         for _ in range(self.n_layers)])
+        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, n_head=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+                                         for _ in range(self.n_layers)]) # Changed nhead to n_head
 
         self.idx_pred_layer = nn.Sequential(
             nn.LayerNorm(dim_embd),
@@ -225,7 +225,7 @@ class CodeFormer(VQAutoEncoder):
 
         x = quant_feat
         fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
-        for i, block in enumerate(self.decoder): # Changed 'generator' to 'decoder'
+        for i, block in enumerate(self.decoder):
             x = block(x)
             if i in fuse_list:
                 f_size = str(x.shape[-1])
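
For context on the second hunk: the surrounding __init__ declares the parameter as n_head, so (assuming no module-level nhead exists elsewhere in the file) the old keyword value nhead was an undefined name and building the model would fail with a NameError. Below is a minimal sketch of the corrected call site; this TransformerSALayer is a stand-in that accepts n_head to match the diff's new keyword, not the real implementation.

import torch.nn as nn

class TransformerSALayer(nn.Module):
    """Stand-in layer; assumes the fork's layer takes `n_head` like the new call site."""
    def __init__(self, embed_dim=512, n_head=8, dim_mlp=2048, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, n_head, dropout=dropout)
        self.mlp = nn.Linear(embed_dim, dim_mlp)

class Sketch(nn.Module):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9):
        super().__init__()
        self.n_layers = n_layers
        self.dim_mlp = dim_embd * 2  # placeholder value; the real file defines its own
        # Before the fix this read `nhead=nhead`; only `n_head` is in scope
        # here, so evaluating the keyword's value raised NameError.
        self.ft_layers = nn.Sequential(*[
            TransformerSALayer(embed_dim=dim_embd, n_head=n_head,
                               dim_mlp=self.dim_mlp, dropout=0.0)
            for _ in range(self.n_layers)
        ])

Sketch()  # builds cleanly with the corrected keyword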
 
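Similarly, a hypothetical sketch of the fuse-by-size pattern touched by the last hunk: the loop walks the decoder blocks (the removed comment notes this attribute was previously named generator), and whenever the current block index is registered for one of the spatial sizes in connect_list, an encoder feature of that size is blended back in. The module below is invented for illustration; only the attribute names follow the diff, and the fuse op is a plain addition rather than the real fuse conv.

import torch
import torch.nn as nn

class FuseBySizeSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # Two toy "decoder" blocks that just double the spatial size: 16 -> 32 -> 64.
        self.decoder = nn.ModuleList([nn.Upsample(scale_factor=2),
                                      nn.Upsample(scale_factor=2)])
        self.fuse_generator_block = {'32': 0, '64': 1}  # spatial size -> block index
        self.connect_list = ['32', '64']

    def forward(self, x, enc_feat_dict):
        fuse_list = [self.fuse_generator_block[s] for s in self.connect_list]
        for i, block in enumerate(self.decoder):
            x = block(x)
            if i in fuse_list:
                f_size = str(x.shape[-1])      # current width as the lookup key
                x = x + enc_feat_dict[f_size]  # stand-in for the real fuse conv
        return x

m = FuseBySizeSketch()
skips = {'32': torch.zeros(1, 8, 32, 32), '64': torch.zeros(1, 8, 64, 64)}
print(m(torch.zeros(1, 8, 16, 16), skips).shape)  # torch.Size([1, 8, 64, 64])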