Spaces:
Runtime error
Runtime error
Update codeformer_arch.py
Browse files- codeformer_arch.py +3 -3
codeformer_arch.py
CHANGED
|
@@ -75,9 +75,9 @@ def _get_activation_fn(activation):
|
|
| 75 |
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
| 76 |
|
| 77 |
class TransformerSALayer(nn.Module):
|
| 78 |
-
def __init__(self, embed_dim, n_head=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
| 79 |
super().__init__()
|
| 80 |
-
self.self_attn = nn.MultiheadAttention(embed_dim, n_head, dropout=dropout)
|
| 81 |
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
| 82 |
self.dropout = nn.Dropout(dropout)
|
| 83 |
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
@@ -140,7 +140,7 @@ class CodeFormer(VQAutoEncoder):
|
|
| 140 |
connect_list=['32', '64', '128', '256'],
|
| 141 |
fix_modules=['quantize', 'decoder'], vqgan_path=None):
|
| 142 |
down_factor = [1, 2, 2, 4, 4, 8]
|
| 143 |
-
super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1,
|
| 144 |
|
| 145 |
if vqgan_path is not None:
|
| 146 |
self.load_state_dict(torch.load(vqgan_path, map_location='cpu')['params_ema'])
|
|
|
|
| 75 |
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
| 76 |
|
| 77 |
class TransformerSALayer(nn.Module):
|
| 78 |
+
def __init__(self, embed_dim, n_head=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
| 79 |
super().__init__()
|
| 80 |
+
self.self_attn = nn.MultiheadAttention(embed_dim, n_head, dropout=dropout)
|
| 81 |
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
| 82 |
self.dropout = nn.Dropout(dropout)
|
| 83 |
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
|
|
| 140 |
connect_list=['32', '64', '128', '256'],
|
| 141 |
fix_modules=['quantize', 'decoder'], vqgan_path=None):
|
| 142 |
down_factor = [1, 2, 2, 4, 4, 8]
|
| 143 |
+
super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 256, codebook_size) # Changed z_channels from 16 to 256
|
| 144 |
|
| 145 |
if vqgan_path is not None:
|
| 146 |
self.load_state_dict(torch.load(vqgan_path, map_location='cpu')['params_ema'])
|