lucky0146 committed on
Commit
cc76843
·
verified ·
1 Parent(s): 8fd5dd4

Update codeformer_arch.py

Browse files
Files changed (1) hide show
  1. codeformer_arch.py +3 -3
codeformer_arch.py CHANGED
@@ -75,9 +75,9 @@ def _get_activation_fn(activation):
75
  raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
76
 
77
  class TransformerSALayer(nn.Module):
78
- def __init__(self, embed_dim, n_head=8, dim_mlp=2048, dropout=0.0, activation="gelu"): # Changed nhead to n_head
79
  super().__init__()
80
- self.self_attn = nn.MultiheadAttention(embed_dim, n_head, dropout=dropout) # Changed nhead to n_head
81
  self.linear1 = nn.Linear(embed_dim, dim_mlp)
82
  self.dropout = nn.Dropout(dropout)
83
  self.linear2 = nn.Linear(dim_mlp, embed_dim)
@@ -140,7 +140,7 @@ class CodeFormer(VQAutoEncoder):
140
  connect_list=['32', '64', '128', '256'],
141
  fix_modules=['quantize', 'decoder'], vqgan_path=None):
142
  down_factor = [1, 2, 2, 4, 4, 8]
143
- super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
144
 
145
  if vqgan_path is not None:
146
  self.load_state_dict(torch.load(vqgan_path, map_location='cpu')['params_ema'])
 
75
  raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
76
 
77
  class TransformerSALayer(nn.Module):
78
+ def __init__(self, embed_dim, n_head=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
79
  super().__init__()
80
+ self.self_attn = nn.MultiheadAttention(embed_dim, n_head, dropout=dropout)
81
  self.linear1 = nn.Linear(embed_dim, dim_mlp)
82
  self.dropout = nn.Dropout(dropout)
83
  self.linear2 = nn.Linear(dim_mlp, embed_dim)
 
140
  connect_list=['32', '64', '128', '256'],
141
  fix_modules=['quantize', 'decoder'], vqgan_path=None):
142
  down_factor = [1, 2, 2, 4, 4, 8]
143
+ super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 256, codebook_size) # Changed z_channels from 16 to 256
144
 
145
  if vqgan_path is not None:
146
  self.load_state_dict(torch.load(vqgan_path, map_location='cpu')['params_ema'])