Update src/model.py
Browse files- src/model.py +10 -61
src/model.py
CHANGED
|
@@ -29,67 +29,15 @@ def Encoder(latent_channels=4):
|
|
| 29 |
conv(64, latent_channels),
|
| 30 |
)
|
| 31 |
|
| 32 |
-
|
| 33 |
-
def __init__(self, in_channels, embed_dim=64, dilation_rates=(1, 2, 4)):
    """Build the DCAH refinement module.

    Args:
        in_channels: channel count of the feature map to refine; the final
            projection maps back to this width so the module is shape-preserving
            in channels.
        embed_dim: internal width used by every branch and the attention convs.
        dilation_rates: one 3x3 dilated conv branch is created per rate.
    """
    super(DCAH, self).__init__()
    self.in_channels = in_channels
    self.embed_dim = embed_dim
    # One 3x3 conv per dilation rate; padding == dilation keeps spatial size.
    branches = [
        nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=r, dilation=r)
        for r in dilation_rates
    ]
    self.dilated_convs = nn.ModuleList(branches)
    # 1x1 conv that merges the concatenated branch outputs back to embed_dim.
    self.dilated_conv_merge = nn.Conv2d(
        embed_dim * len(dilation_rates), embed_dim, kernel_size=1
    )
    # 1x1 projections feeding the attention step in forward().
    self.query = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
    self.key = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
    self.value = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
    # Final refinement stack; last conv restores the caller's channel count.
    self.refine = nn.Sequential(
        nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(embed_dim, in_channels, kernel_size=1),
    )
|
| 50 |
-
|
| 51 |
-
def forward(self, x):
    """Refine x using multi-dilation context plus channel attention.

    Output channel count is set by the last conv in self.refine; spatial size
    is preserved assuming the dilated branches pad to keep H, W — confirm
    against __init__.
    """
    # One feature map per dilated branch, all computed from the same input.
    dilated_features = [conv(x) for conv in self.dilated_convs]
    # Concatenate the branches along channels, then 1x1-merge to one map.
    concat_features = torch.cat(dilated_features, dim=1)
    global_context = self.dilated_conv_merge(concat_features)
    # 1x1 projections for the attention step.
    q = self.query(global_context)
    k = self.key(global_context)
    v = self.value(global_context)
    # flatten(2) -> (B, C, H*W); q @ k^T -> (B, C, C), i.e. channel-to-channel
    # attention weights, softmax-normalised over the last dim.
    # NOTE(review): there is no 1/sqrt(d) scaling before the softmax — confirm
    # this is intentional.
    attention = F.softmax(torch.matmul(q.flatten(2), k.flatten(2).transpose(-2, -1)), dim=-1)
    # Re-weight the flattened values, then restore the (B, C, H, W) layout.
    attention_out = torch.matmul(attention, v.flatten(2)).view_as(global_context)
    # Residual connection into the refinement stack.
    refined = self.refine(global_context + attention_out)
    return refined
|
| 62 |
-
|
| 63 |
-
def DecoderSeq(latent_channels=16):
    # NOTE(review): this definition is garbled in the captured view — the bare
    # `Block(` below is missing its arguments/closing paren, and the initial
    # projection from `latent_channels` into the 48-wide trunk appears to have
    # been dropped between `Clamp()` and `nn.ReLU()`. Reproduced verbatim; it
    # does NOT parse as-is. TODO: recover the true body from version control.
    #
    # What is visible: a Clamp + ReLU stem, then 48-channel residual Blocks
    # with two 2x Upsample stages, ending in a conv down to 3 channels.
    return nn.Sequential(
        Clamp(),
        nn.ReLU(),
        Block(
        Block(48, 48), Block(48, 48),
        nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
        Block(48, 48),
        nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
        Block(48, 48),
        conv(48, 3),
    )
|
| 77 |
|
| 78 |
-
|
| 79 |
-
class Decoder(nn.Module):
    """Decoder that runs DecoderSeq, then refines its 3-channel output with DCAH."""

    def __init__(self, latent_channels=16):
        super(Decoder, self).__init__()
        # Base sequential decoder producing the initial reconstruction.
        self.decoder = DecoderSeq(latent_channels=latent_channels)
        # Attention head operating on the decoder's 3-channel output.
        self.refinement_head = DCAH(in_channels=3, embed_dim=64)

    def forward(self, x):
        # Decode the latent, then refine — one chained pass.
        return self.refinement_head(self.decoder(x))
|
| 91 |
-
|
| 92 |
-
|
| 93 |
class Model(nn.Module):
|
| 94 |
latent_magnitude = 3
|
| 95 |
latent_shift = 0.5
|
|
@@ -103,16 +51,17 @@ class Model(nn.Module):
|
|
| 103 |
if encoder_path is not None:
|
| 104 |
encoder_state_dict = torch.load(encoder_path, map_location="cpu", weights_only=True)
|
| 105 |
filtered_state_dict = {k.strip('encoder.'): v for k, v in encoder_state_dict.items() if k.strip('encoder.') in self.encoder.state_dict() and v.size() == self.encoder.state_dict()[k.strip('encoder.')].size()}
|
|
|
|
| 106 |
self.encoder.load_state_dict(filtered_state_dict, strict=False)
|
| 107 |
|
| 108 |
if decoder_path is not None:
|
| 109 |
decoder_state_dict = torch.load(decoder_path, map_location="cpu", weights_only=True)
|
| 110 |
-
filtered_state_dict = {k: v for k, v in decoder_state_dict.items() if k in self.decoder.state_dict() and v.size() == self.decoder.state_dict()[k].size()}
|
|
|
|
| 111 |
self.decoder.load_state_dict(filtered_state_dict, strict=False)
|
| 112 |
|
| 113 |
self.encoder.requires_grad_(False)
|
| 114 |
-
self.decoder.
|
| 115 |
-
self.decoder.refinement_head.requires_grad_(False)
|
| 116 |
|
| 117 |
def guess_latent_channels(self, encoder_path):
|
| 118 |
if "taef1" in encoder_path:return 16
|
|
|
|
| 29 |
conv(64, latent_channels),
|
| 30 |
)
|
| 31 |
|
| 32 |
+
def Decoder(latent_channels=4):
    """Build the decoder network: clamp + stem, three 2x-upsampling stages,
    and a final 3-channel projection.

    Layer-for-layer identical to writing the nn.Sequential literal inline.
    """
    # Stem: clamp the latent, project to the 64-wide trunk.
    layers = [Clamp(), conv(latent_channels, 64), nn.ReLU()]
    # Three identical stages, each doubling the spatial resolution.
    for _ in range(3):
        layers.extend([
            Block(64, 64),
            nn.Upsample(scale_factor=2),
            conv(64, 64, bias=False),
            nn.ReLU(),
        ])
    # Head: one more residual block, then down to 3 output channels.
    layers.extend([Block(64, 64), conv(64, 3)])
    return nn.Sequential(*layers)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
class Model(nn.Module):
|
| 42 |
latent_magnitude = 3
|
| 43 |
latent_shift = 0.5
|
|
|
|
| 51 |
if encoder_path is not None:
|
| 52 |
encoder_state_dict = torch.load(encoder_path, map_location="cpu", weights_only=True)
|
| 53 |
filtered_state_dict = {k.strip('encoder.'): v for k, v in encoder_state_dict.items() if k.strip('encoder.') in self.encoder.state_dict() and v.size() == self.encoder.state_dict()[k.strip('encoder.')].size()}
|
| 54 |
+
print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(self.encoder.state_dict())}")
|
| 55 |
self.encoder.load_state_dict(filtered_state_dict, strict=False)
|
| 56 |
|
| 57 |
if decoder_path is not None:
|
| 58 |
decoder_state_dict = torch.load(decoder_path, map_location="cpu", weights_only=True)
|
| 59 |
+
filtered_state_dict = {k.strip('decoder.'): v for k, v in decoder_state_dict.items() if k.strip('decoder.') in self.decoder.state_dict() and v.size() == self.decoder.state_dict()[k.strip('decoder.')].size()}
|
| 60 |
+
print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(self.decoder.state_dict())}")
|
| 61 |
self.decoder.load_state_dict(filtered_state_dict, strict=False)
|
| 62 |
|
| 63 |
self.encoder.requires_grad_(False)
|
| 64 |
+
self.decoder.requires_grad_(False)
|
|
|
|
| 65 |
|
| 66 |
def guess_latent_channels(self, encoder_path):
|
| 67 |
if "taef1" in encoder_path:return 16
|