Spaces:
Runtime error
Runtime error
Update codeformer_arch.py
Browse files- codeformer_arch.py +44 -88
codeformer_arch.py
CHANGED
|
@@ -5,18 +5,12 @@ from torch import nn, Tensor
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from typing import Optional, List
|
| 7 |
|
| 8 |
-
from vqgan_arch import *
|
| 9 |
from basicsr.utils import get_root_logger
|
| 10 |
from basicsr.utils.registry import ARCH_REGISTRY
|
| 11 |
|
| 12 |
def calc_mean_std(feat, eps=1e-5):
|
| 13 |
-
"""Calculate mean and std for adaptive_instance_normalization.
|
| 14 |
-
|
| 15 |
-
Args:
|
| 16 |
-
feat (Tensor): 4D tensor.
|
| 17 |
-
eps (float): A small value added to the variance to avoid
|
| 18 |
-
divide-by-zero. Default: 1e-5.
|
| 19 |
-
"""
|
| 20 |
size = feat.size()
|
| 21 |
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
| 22 |
b, c = size[:2]
|
|
@@ -25,30 +19,15 @@ def calc_mean_std(feat, eps=1e-5):
|
|
| 25 |
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
| 26 |
return feat_mean, feat_std
|
| 27 |
|
| 28 |
-
|
| 29 |
def adaptive_instance_normalization(content_feat, style_feat):
|
| 30 |
-
"""Adaptive instance normalization.
|
| 31 |
-
|
| 32 |
-
Adjust the reference features to have the similar color and illuminations
|
| 33 |
-
as those in the degradate features.
|
| 34 |
-
|
| 35 |
-
Args:
|
| 36 |
-
content_feat (Tensor): The reference feature.
|
| 37 |
-
style_feat (Tensor): The degradate features.
|
| 38 |
-
"""
|
| 39 |
size = content_feat.size()
|
| 40 |
style_mean, style_std = calc_mean_std(style_feat)
|
| 41 |
content_mean, content_std = calc_mean_std(content_feat)
|
| 42 |
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 43 |
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
| 44 |
|
| 45 |
-
|
| 46 |
class PositionEmbeddingSine(nn.Module):
|
| 47 |
-
"""
|
| 48 |
-
This is a more standard version of the position embedding, very similar to the one
|
| 49 |
-
used by the Attention is all you need paper, generalized to work on images.
|
| 50 |
-
"""
|
| 51 |
-
|
| 52 |
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
| 53 |
super().__init__()
|
| 54 |
self.num_pos_feats = num_pos_feats
|
|
@@ -93,14 +72,12 @@ def _get_activation_fn(activation):
|
|
| 93 |
return F.gelu
|
| 94 |
if activation == "glu":
|
| 95 |
return F.glu
|
| 96 |
-
raise RuntimeError(
|
| 97 |
-
|
| 98 |
|
| 99 |
class TransformerSALayer(nn.Module):
|
| 100 |
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
| 101 |
super().__init__()
|
| 102 |
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
| 103 |
-
# Implementation of Feedforward model - MLP
|
| 104 |
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
| 105 |
self.dropout = nn.Dropout(dropout)
|
| 106 |
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
@@ -119,15 +96,11 @@ class TransformerSALayer(nn.Module):
|
|
| 119 |
tgt_mask: Optional[Tensor] = None,
|
| 120 |
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 121 |
query_pos: Optional[Tensor] = None):
|
| 122 |
-
|
| 123 |
-
# self attention
|
| 124 |
tgt2 = self.norm1(tgt)
|
| 125 |
q = k = self.with_pos_embed(tgt2, query_pos)
|
| 126 |
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
| 127 |
key_padding_mask=tgt_key_padding_mask)[0]
|
| 128 |
tgt = tgt + self.dropout1(tgt2)
|
| 129 |
-
|
| 130 |
-
# ffn
|
| 131 |
tgt2 = self.norm2(tgt)
|
| 132 |
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 133 |
tgt = tgt + self.dropout2(tgt2)
|
|
@@ -136,17 +109,21 @@ class TransformerSALayer(nn.Module):
|
|
| 136 |
class Fuse_sft_block(nn.Module):
|
| 137 |
def __init__(self, in_ch, out_ch):
|
| 138 |
super().__init__()
|
| 139 |
-
self.encode_enc =
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
| 141 |
self.scale = nn.Sequential(
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
self.shift = nn.Sequential(
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
| 150 |
|
| 151 |
def forward(self, enc_feat, dec_feat, w=1):
|
| 152 |
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
@@ -156,18 +133,18 @@ class Fuse_sft_block(nn.Module):
|
|
| 156 |
out = dec_feat + residual
|
| 157 |
return out
|
| 158 |
|
| 159 |
-
|
| 160 |
@ARCH_REGISTRY.register()
|
| 161 |
class CodeFormer(VQAutoEncoder):
|
| 162 |
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
if vqgan_path is not None:
|
| 169 |
-
self.load_state_dict(
|
| 170 |
-
torch.load(vqgan_path, map_location='cpu')['params_ema'])
|
| 171 |
|
| 172 |
if fix_modules is not None:
|
| 173 |
for module in fix_modules:
|
|
@@ -177,19 +154,18 @@ class CodeFormer(VQAutoEncoder):
|
|
| 177 |
self.connect_list = connect_list
|
| 178 |
self.n_layers = n_layers
|
| 179 |
self.dim_embd = dim_embd
|
| 180 |
-
self.dim_mlp = dim_embd*2
|
| 181 |
|
| 182 |
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
| 183 |
self.feat_emb = nn.Linear(256, self.dim_embd)
|
| 184 |
|
| 185 |
-
# transformer
|
| 186 |
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
| 187 |
-
|
| 188 |
|
| 189 |
-
# logits_predict head
|
| 190 |
self.idx_pred_layer = nn.Sequential(
|
| 191 |
nn.LayerNorm(dim_embd),
|
| 192 |
-
nn.Linear(dim_embd, codebook_size, bias=False)
|
|
|
|
| 193 |
|
| 194 |
self.channels = {
|
| 195 |
'16': 512,
|
|
@@ -200,12 +176,9 @@ class CodeFormer(VQAutoEncoder):
|
|
| 200 |
'512': 64,
|
| 201 |
}
|
| 202 |
|
| 203 |
-
|
| 204 |
-
self.
|
| 205 |
-
# after first residual block for > 16, before attn layer for ==16
|
| 206 |
-
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
| 207 |
|
| 208 |
-
# fuse_convs_dict
|
| 209 |
self.fuse_convs_dict = nn.ModuleDict()
|
| 210 |
for f_size in self.connect_list:
|
| 211 |
in_ch = self.channels[f_size]
|
|
@@ -221,60 +194,43 @@ class CodeFormer(VQAutoEncoder):
|
|
| 221 |
module.weight.data.fill_(1.0)
|
| 222 |
|
| 223 |
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
| 224 |
-
# ################### Encoder #####################
|
| 225 |
enc_feat_dict = {}
|
| 226 |
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
| 227 |
-
for i, block in enumerate(self.encoder
|
| 228 |
-
x = block(x)
|
| 229 |
if i in out_list:
|
| 230 |
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
| 231 |
|
| 232 |
lq_feat = x
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
| 236 |
-
# BCHW -> BC(HW) -> (HW)BC
|
| 237 |
-
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
| 238 |
query_emb = feat_emb
|
| 239 |
-
|
| 240 |
for layer in self.ft_layers:
|
| 241 |
query_emb = layer(query_emb, query_pos=pos_emb)
|
| 242 |
|
| 243 |
-
|
| 244 |
-
logits =
|
| 245 |
-
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
| 246 |
|
| 247 |
-
if code_only:
|
| 248 |
-
# logits doesn't need softmax before cross_entropy loss
|
| 249 |
return logits, lq_feat
|
| 250 |
|
| 251 |
-
# ################# Quantization ###################
|
| 252 |
-
# if self.training:
|
| 253 |
-
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
| 254 |
-
# # b(hw)c -> bc(hw) -> bchw
|
| 255 |
-
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
| 256 |
-
# ------------
|
| 257 |
soft_one_hot = F.softmax(logits, dim=2)
|
| 258 |
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
| 259 |
-
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
| 260 |
-
# preserve gradients
|
| 261 |
-
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
| 262 |
|
| 263 |
if detach_16:
|
| 264 |
-
quant_feat = quant_feat.detach()
|
| 265 |
if adain:
|
| 266 |
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
| 267 |
|
| 268 |
-
# ################## Generator ####################
|
| 269 |
x = quant_feat
|
| 270 |
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
if i in fuse_list: # fuse after i-th block
|
| 275 |
f_size = str(x.shape[-1])
|
| 276 |
-
if w>0:
|
| 277 |
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
| 278 |
out = x
|
| 279 |
-
# logits doesn't need softmax before cross_entropy loss
|
| 280 |
return out, logits, lq_feat
|
|
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from typing import Optional, List
|
| 7 |
|
| 8 |
+
from vqgan_arch import * # Custom import from root
|
| 9 |
from basicsr.utils import get_root_logger
|
| 10 |
from basicsr.utils.registry import ARCH_REGISTRY
|
| 11 |
|
| 12 |
def calc_mean_std(feat, eps=1e-5):
|
| 13 |
+
"""Calculate mean and std for adaptive_instance_normalization."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
size = feat.size()
|
| 15 |
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
| 16 |
b, c = size[:2]
|
|
|
|
| 19 |
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
| 20 |
return feat_mean, feat_std
|
| 21 |
|
|
|
|
| 22 |
def adaptive_instance_normalization(content_feat, style_feat):
|
| 23 |
+
"""Adaptive instance normalization."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
size = content_feat.size()
|
| 25 |
style_mean, style_std = calc_mean_std(style_feat)
|
| 26 |
content_mean, content_std = calc_mean_std(content_feat)
|
| 27 |
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 28 |
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
| 29 |
|
|
|
|
| 30 |
class PositionEmbeddingSine(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
| 32 |
super().__init__()
|
| 33 |
self.num_pos_feats = num_pos_feats
|
|
|
|
| 72 |
return F.gelu
|
| 73 |
if activation == "glu":
|
| 74 |
return F.glu
|
| 75 |
+
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
|
|
|
| 76 |
|
| 77 |
class TransformerSALayer(nn.Module):
|
| 78 |
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
| 79 |
super().__init__()
|
| 80 |
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
|
|
|
| 81 |
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
| 82 |
self.dropout = nn.Dropout(dropout)
|
| 83 |
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
|
|
| 96 |
tgt_mask: Optional[Tensor] = None,
|
| 97 |
tgt_key_padding_mask: Optional[Tensor] = None,
|
| 98 |
query_pos: Optional[Tensor] = None):
|
|
|
|
|
|
|
| 99 |
tgt2 = self.norm1(tgt)
|
| 100 |
q = k = self.with_pos_embed(tgt2, query_pos)
|
| 101 |
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
| 102 |
key_padding_mask=tgt_key_padding_mask)[0]
|
| 103 |
tgt = tgt + self.dropout1(tgt2)
|
|
|
|
|
|
|
| 104 |
tgt2 = self.norm2(tgt)
|
| 105 |
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
| 106 |
tgt = tgt + self.dropout2(tgt2)
|
|
|
|
| 109 |
class Fuse_sft_block(nn.Module):
|
| 110 |
def __init__(self, in_ch, out_ch):
|
| 111 |
super().__init__()
|
| 112 |
+
self.encode_enc = nn.Sequential(
|
| 113 |
+
nn.Conv2d(2 * in_ch, out_ch, 3, 1, 1, bias=False),
|
| 114 |
+
nn.BatchNorm2d(out_ch),
|
| 115 |
+
nn.LeakyReLU(0.2, inplace=True)
|
| 116 |
+
)
|
| 117 |
self.scale = nn.Sequential(
|
| 118 |
+
nn.Conv2d(in_ch, out_ch, 3, 1, 1),
|
| 119 |
+
nn.LeakyReLU(0.2, True),
|
| 120 |
+
nn.Conv2d(out_ch, out_ch, 3, 1, 1)
|
| 121 |
+
)
|
| 122 |
self.shift = nn.Sequential(
|
| 123 |
+
nn.Conv2d(in_ch, out_ch, 3, 1, 1),
|
| 124 |
+
nn.LeakyReLU(0.2, True),
|
| 125 |
+
nn.Conv2d(out_ch, out_ch, 3, 1, 1)
|
| 126 |
+
)
|
| 127 |
|
| 128 |
def forward(self, enc_feat, dec_feat, w=1):
|
| 129 |
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
|
|
| 133 |
out = dec_feat + residual
|
| 134 |
return out
|
| 135 |
|
|
|
|
| 136 |
@ARCH_REGISTRY.register()
|
| 137 |
class CodeFormer(VQAutoEncoder):
|
| 138 |
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
| 139 |
+
codebook_size=1024, latent_size=256,
|
| 140 |
+
connect_list=['32', '64', '128', '256'],
|
| 141 |
+
fix_modules=['quantize', 'generator'], vqgan_path=None):
|
| 142 |
+
# Adjust down_factor to ensure it works with channel scaling
|
| 143 |
+
down_factor = [1, 2, 2, 4, 4, 8] # Ensure this matches the number of steps
|
| 144 |
+
super().__init__(512, 64, down_factor, 'nearest', len(down_factor) - 1, 16, codebook_size)
|
| 145 |
|
| 146 |
if vqgan_path is not None:
|
| 147 |
+
self.load_state_dict(torch.load(vqgan_path, map_location='cpu')['params_ema'])
|
|
|
|
| 148 |
|
| 149 |
if fix_modules is not None:
|
| 150 |
for module in fix_modules:
|
|
|
|
| 154 |
self.connect_list = connect_list
|
| 155 |
self.n_layers = n_layers
|
| 156 |
self.dim_embd = dim_embd
|
| 157 |
+
self.dim_mlp = dim_embd * 2
|
| 158 |
|
| 159 |
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
| 160 |
self.feat_emb = nn.Linear(256, self.dim_embd)
|
| 161 |
|
|
|
|
| 162 |
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
| 163 |
+
for _ in range(self.n_layers)])
|
| 164 |
|
|
|
|
| 165 |
self.idx_pred_layer = nn.Sequential(
|
| 166 |
nn.LayerNorm(dim_embd),
|
| 167 |
+
nn.Linear(dim_embd, codebook_size, bias=False)
|
| 168 |
+
)
|
| 169 |
|
| 170 |
self.channels = {
|
| 171 |
'16': 512,
|
|
|
|
| 176 |
'512': 64,
|
| 177 |
}
|
| 178 |
|
| 179 |
+
self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
|
| 180 |
+
self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}
|
|
|
|
|
|
|
| 181 |
|
|
|
|
| 182 |
self.fuse_convs_dict = nn.ModuleDict()
|
| 183 |
for f_size in self.connect_list:
|
| 184 |
in_ch = self.channels[f_size]
|
|
|
|
| 194 |
module.weight.data.fill_(1.0)
|
| 195 |
|
| 196 |
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
|
|
|
| 197 |
enc_feat_dict = {}
|
| 198 |
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
| 199 |
+
for i, block in enumerate(self.encoder):
|
| 200 |
+
x = block(x)
|
| 201 |
if i in out_list:
|
| 202 |
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
| 203 |
|
| 204 |
lq_feat = x
|
| 205 |
+
pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
|
| 206 |
+
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
|
|
|
|
|
|
|
|
|
|
| 207 |
query_emb = feat_emb
|
| 208 |
+
|
| 209 |
for layer in self.ft_layers:
|
| 210 |
query_emb = layer(query_emb, query_pos=pos_emb)
|
| 211 |
|
| 212 |
+
logits = self.idx_pred_layer(query_emb)
|
| 213 |
+
logits = logits.permute(1, 0, 2)
|
|
|
|
| 214 |
|
| 215 |
+
if code_only:
|
|
|
|
| 216 |
return logits, lq_feat
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
soft_one_hot = F.softmax(logits, dim=2)
|
| 219 |
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
| 220 |
+
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
|
|
|
|
|
|
|
| 221 |
|
| 222 |
if detach_16:
|
| 223 |
+
quant_feat = quant_feat.detach()
|
| 224 |
if adain:
|
| 225 |
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
| 226 |
|
|
|
|
| 227 |
x = quant_feat
|
| 228 |
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
| 229 |
+
for i, block in enumerate(self.decoder):
|
| 230 |
+
x = block(x)
|
| 231 |
+
if i in fuse_list:
|
|
|
|
| 232 |
f_size = str(x.shape[-1])
|
| 233 |
+
if w > 0:
|
| 234 |
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
| 235 |
out = x
|
|
|
|
| 236 |
return out, logits, lq_feat
|