Spaces — build status: Runtime error
Update codeformer_arch.py
Browse files- codeformer_arch.py +4 -5
codeformer_arch.py
CHANGED
|
@@ -138,9 +138,8 @@ class CodeFormer(VQAutoEncoder):
|
|
| 138 |
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
| 139 |
codebook_size=1024, latent_size=256,
|
| 140 |
connect_list=['32', '64', '128', '256'],
|
| 141 |
-
fix_modules=['quantize', 'generator'], vqgan_path=None):
|
| 142 |
-
|
| 143 |
-
down_factor = [1, 2, 2, 4, 4, 8] # Ensure this matches the number of steps
|
| 144 |
super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
|
| 145 |
|
| 146 |
if vqgan_path is not None:
|
|
@@ -159,7 +158,7 @@ class CodeFormer(VQAutoEncoder):
|
|
| 159 |
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
| 160 |
self.feat_emb = nn.Linear(256, self.dim_embd)
|
| 161 |
|
| 162 |
-
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
| 163 |
for _ in range(self.n_layers)])
|
| 164 |
|
| 165 |
self.idx_pred_layer = nn.Sequential(
|
|
@@ -226,7 +225,7 @@ class CodeFormer(VQAutoEncoder):
|
|
| 226 |
|
| 227 |
x = quant_feat
|
| 228 |
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
| 229 |
-
for i, block in enumerate(self.decoder):
|
| 230 |
x = block(x)
|
| 231 |
if i in fuse_list:
|
| 232 |
f_size = str(x.shape[-1])
|
|
|
|
| 138 |
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
| 139 |
codebook_size=1024, latent_size=256,
|
| 140 |
connect_list=['32', '64', '128', '256'],
|
| 141 |
+
fix_modules=['quantize', 'decoder'], vqgan_path=None): # Changed 'generator' to 'decoder'
|
| 142 |
+
down_factor = [1, 2, 2, 4, 4, 8]
|
|
|
|
| 143 |
super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
|
| 144 |
|
| 145 |
if vqgan_path is not None:
|
|
|
|
| 158 |
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
| 159 |
self.feat_emb = nn.Linear(256, self.dim_embd)
|
| 160 |
|
| 161 |
+
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=nhead, dim_mlp=self.dim_mlp, dropout=0.0)
|
| 162 |
for _ in range(self.n_layers)])
|
| 163 |
|
| 164 |
self.idx_pred_layer = nn.Sequential(
|
|
|
|
| 225 |
|
| 226 |
x = quant_feat
|
| 227 |
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
| 228 |
+
for i, block in enumerate(self.decoder): # Changed 'generator' to 'decoder'
|
| 229 |
x = block(x)
|
| 230 |
if i in fuse_list:
|
| 231 |
f_size = str(x.shape[-1])
|