Spaces:
Runtime error
Runtime error
Update codeformer_arch.py
Browse files- codeformer_arch.py +3 -3
codeformer_arch.py
CHANGED
|
@@ -75,9 +75,9 @@ def _get_activation_fn(activation):
|
|
| 75 |
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
| 76 |
|
| 77 |
class TransformerSALayer(nn.Module):
|
| 78 |
-
def __init__(self, embed_dim,
|
| 79 |
super().__init__()
|
| 80 |
-
self.self_attn = nn.MultiheadAttention(embed_dim,
|
| 81 |
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
| 82 |
self.dropout = nn.Dropout(dropout)
|
| 83 |
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
@@ -159,7 +159,7 @@ class CodeFormer(VQAutoEncoder):
|
|
| 159 |
self.feat_emb = nn.Linear(256, self.dim_embd)
|
| 160 |
|
| 161 |
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, n_head=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
| 162 |
-
for _ in range(self.n_layers)])
|
| 163 |
|
| 164 |
self.idx_pred_layer = nn.Sequential(
|
| 165 |
nn.LayerNorm(dim_embd),
|
|
|
|
| 75 |
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
| 76 |
|
| 77 |
class TransformerSALayer(nn.Module):
|
| 78 |
+
def __init__(self, embed_dim, n_head=8, dim_mlp=2048, dropout=0.0, activation="gelu"): # Changed nhead to n_head
|
| 79 |
super().__init__()
|
| 80 |
+
self.self_attn = nn.MultiheadAttention(embed_dim, n_head, dropout=dropout) # Changed nhead to n_head
|
| 81 |
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
| 82 |
self.dropout = nn.Dropout(dropout)
|
| 83 |
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
|
|
| 159 |
self.feat_emb = nn.Linear(256, self.dim_embd)
|
| 160 |
|
| 161 |
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, n_head=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
| 162 |
+
for _ in range(self.n_layers)])
|
| 163 |
|
| 164 |
self.idx_pred_layer = nn.Sequential(
|
| 165 |
nn.LayerNorm(dim_embd),
|