lucky0146 committed on
Commit
6b310a6
·
verified ·
1 Parent(s): cf32411

Update codeformer_arch.py

Browse files
Files changed (1) hide show
  1. codeformer_arch.py +4 -5
codeformer_arch.py CHANGED
@@ -138,9 +138,8 @@ class CodeFormer(VQAutoEncoder):
138
  def __init__(self, dim_embd=512, n_head=8, n_layers=9,
139
  codebook_size=1024, latent_size=256,
140
  connect_list=['32', '64', '128', '256'],
141
- fix_modules=['quantize', 'generator'], vqgan_path=None):
142
- # Adjust down_factor to ensure it works with channel scaling
143
- down_factor = [1, 2, 2, 4, 4, 8] # Ensure this matches the number of steps
144
  super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
145
 
146
  if vqgan_path is not None:
@@ -159,7 +158,7 @@ class CodeFormer(VQAutoEncoder):
159
  self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
160
  self.feat_emb = nn.Linear(256, self.dim_embd)
161
 
162
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
163
  for _ in range(self.n_layers)])
164
 
165
  self.idx_pred_layer = nn.Sequential(
@@ -226,7 +225,7 @@ class CodeFormer(VQAutoEncoder):
226
 
227
  x = quant_feat
228
  fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
229
- for i, block in enumerate(self.decoder):
230
  x = block(x)
231
  if i in fuse_list:
232
  f_size = str(x.shape[-1])
 
138
  def __init__(self, dim_embd=512, n_head=8, n_layers=9,
139
  codebook_size=1024, latent_size=256,
140
  connect_list=['32', '64', '128', '256'],
141
+ fix_modules=['quantize', 'decoder'], vqgan_path=None): # Changed 'generator' to 'decoder'
142
+ down_factor = [1, 2, 2, 4, 4, 8]
 
143
  super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
144
 
145
  if vqgan_path is not None:
 
158
  self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
159
  self.feat_emb = nn.Linear(256, self.dim_embd)
160
 
161
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=nhead, dim_mlp=self.dim_mlp, dropout=0.0)
162
  for _ in range(self.n_layers)])
163
 
164
  self.idx_pred_layer = nn.Sequential(
 
225
 
226
  x = quant_feat
227
  fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
228
+ for i, block in enumerate(self.decoder): # Changed 'generator' to 'decoder'
229
  x = block(x)
230
  if i in fuse_list:
231
  f_size = str(x.shape[-1])