manbeast3b committed · verified
Commit 16b6d4a · Parent(s): 28019a8

Update src/model.py

Files changed (1):
  1. src/model.py (+10 -61)
src/model.py CHANGED
```diff
@@ -29,67 +29,15 @@ def Encoder(latent_channels=4):
         conv(64, latent_channels),
     )
 
-class DCAH(nn.Module):
-    def __init__(self, in_channels, embed_dim=64, dilation_rates=(1, 2, 4)):
-        super(DCAH, self).__init__()
-        self.in_channels = in_channels
-        self.embed_dim = embed_dim
-        self.dilated_convs = nn.ModuleList([
-            nn.Conv2d(in_channels, embed_dim, kernel_size=3, padding=rate, dilation=rate)
-            for rate in dilation_rates
-        ])
-        self.dilated_conv_merge = nn.Conv2d(embed_dim * len(dilation_rates), embed_dim, kernel_size=1)
-        self.query = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
-        self.key = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
-        self.value = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
-        self.refine = nn.Sequential(
-            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
-            nn.ReLU(),
-            nn.Conv2d(embed_dim, in_channels, kernel_size=1)
-        )
-
-    def forward(self, x):
-        dilated_features = [conv(x) for conv in self.dilated_convs]
-        concat_features = torch.cat(dilated_features, dim=1)
-        global_context = self.dilated_conv_merge(concat_features)
-        q = self.query(global_context)
-        k = self.key(global_context)
-        v = self.value(global_context)
-        attention = F.softmax(torch.matmul(q.flatten(2), k.flatten(2).transpose(-2, -1)), dim=-1)
-        attention_out = torch.matmul(attention, v.flatten(2)).view_as(global_context)
-        refined = self.refine(global_context + attention_out)
-        return refined
-
-def DecoderSeq(latent_channels=16):
+def Decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(),
-        conv(latent_channels, 48),
-        nn.ReLU(),
-        Block(48, 48), Block(48, 48),
-        nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
-        Block(48, 48), Block(48, 48),
-        nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
-        Block(48, 48),
-        nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
-        Block(48, 48),
-        conv(48, 3),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
+        Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), nn.ReLU(),
+        Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), nn.ReLU(),
+        Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), nn.ReLU(),
+        Block(64, 64), conv(64, 3),
     )
 
-
-class Decoder(nn.Module):
-    def __init__(self, latent_channels=16):
-        decoder = DecoderSeq(latent_channels=latent_channels)
-        refinement_head = DCAH(in_channels=3, embed_dim=64)
-        super(Decoder, self).__init__()
-        self.decoder = decoder
-        self.refinement_head = refinement_head
-
-    def forward(self, x):
-        decoded = self.decoder(x)
-        refined = self.refinement_head(decoded)
-        return refined
-
-
 class Model(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
```
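With this hunk the decoder becomes a plain fully convolutional stack: three `nn.Upsample(scale_factor=2)` stages take an H×W latent to an 8H×8W RGB image, and the removed `DCAH` refinement head (dilated convs feeding an `embed_dim × embed_dim` channel-attention map over the decoded image) no longer runs on top. Below is a minimal shape check of the new `Decoder`; the `Clamp`, `conv`, and `Block` stand-ins are assumptions modeled on TAESD-style helpers, not copied from `src/model.py`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the repo's helpers (assumed TAESD-like).
def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3  # soft-clamp latents to roughly [-3, 3]

class Block(nn.Module):  # residual block: 3 convs plus identity/1x1 skip
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(),
                                  conv(n_out, n_out), nn.ReLU(),
                                  conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))

def Decoder(latent_channels=4):  # same layout as the + lines in the hunk above
    return nn.Sequential(
        Clamp(), conv(latent_channels, 64), nn.ReLU(),
        Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), nn.ReLU(),
        Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), nn.ReLU(),
        Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), nn.ReLU(),
        Block(64, 64), conv(64, 3),
    )

out = Decoder(latent_channels=4)(torch.randn(1, 4, 32, 32))
print(out.shape)  # torch.Size([1, 3, 256, 256]) -- three 2x upsamples: 32 -> 256
```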
```diff
@@ -103,16 +51,17 @@ class Model(nn.Module):
         if encoder_path is not None:
             encoder_state_dict = torch.load(encoder_path, map_location="cpu", weights_only=True)
             filtered_state_dict = {k.strip('encoder.'): v for k, v in encoder_state_dict.items() if k.strip('encoder.') in self.encoder.state_dict() and v.size() == self.encoder.state_dict()[k.strip('encoder.')].size()}
+            print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(self.encoder.state_dict())}")
             self.encoder.load_state_dict(filtered_state_dict, strict=False)
 
         if decoder_path is not None:
             decoder_state_dict = torch.load(decoder_path, map_location="cpu", weights_only=True)
-            filtered_state_dict = {k: v for k, v in decoder_state_dict.items() if k in self.decoder.state_dict() and v.size() == self.decoder.state_dict()[k].size()}
+            filtered_state_dict = {k.strip('decoder.'): v for k, v in decoder_state_dict.items() if k.strip('decoder.') in self.decoder.state_dict() and v.size() == self.decoder.state_dict()[k.strip('decoder.')].size()}
+            print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(self.decoder.state_dict())}")
             self.decoder.load_state_dict(filtered_state_dict, strict=False)
 
         self.encoder.requires_grad_(False)
-        self.decoder.decoder.requires_grad_(False)
-        self.decoder.refinement_head.requires_grad_(False)
+        self.decoder.requires_grad_(False)
 
     def guess_latent_channels(self, encoder_path):
         if "taef1" in encoder_path:return 16
```
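One caveat in the key filtering on both sides of this hunk: `str.strip('encoder.')` / `str.strip('decoder.')` strips a *character set* from both ends of the string, not a literal prefix, so a key whose first segment happens to start with one of those letters gets mangled and is then silently dropped by the `in ...state_dict()` membership test. A short demonstration (the key name is a hypothetical example, not taken from the actual checkpoints):

```python
key = "encoder.conv_in.weight"

# strip() removes any of the characters {e, n, c, o, d, r, .} from both ends,
# so it also eats the leading "con" of "conv_in".
print(key.strip("encoder."))         # v_in.weight   -- mangled
print(key.removeprefix("encoder."))  # conv_in.weight -- Python 3.9+, exact prefix only
```

`k.removeprefix('encoder.')`, or `k[len('encoder.'):]` guarded by `k.startswith('encoder.')`, keeps the intent without that edge case.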
 
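Since both loads already pass `strict=False`, the return value of `load_state_dict` gives a more direct sanity check than the size-comparison prints (note the first print is labeled "in decoder" but measures the encoder's state dict). A self-contained sketch; the module and checkpoint keys are hypothetical:

```python
import torch
import torch.nn as nn

# Toy module and a checkpoint with one matching layer, one missing, one extra.
module = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 3, 3))
ckpt = {
    "0.weight": torch.zeros(8, 3, 3, 3),
    "0.bias": torch.zeros(8),
    "extra.bias": torch.zeros(1),
}

# strict=False reports exactly what was skipped on each side.
result = module.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias'] -- left at init values
print(result.unexpected_keys)  # ['extra.bias'] -- ignored from the checkpoint
```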