manbeast3b commited on
Commit
026d1ee
·
verified ·
1 Parent(s): c73798e

Update src/model.py

Browse files
Files changed (1) hide show
  1. src/model.py +88 -42
src/model.py CHANGED
@@ -1,46 +1,92 @@
1
- import torch as T, torch.nn as n, torch.nn.functional as f
2
- def C(v, w, **k): return n.Conv2d(v, w, 3, padding=1, **k)
3
- class Z(n.Module):
4
- def forward(s, x): return T.tanh(x / 3) * 3
5
- class A(n.Module):
6
- def __init__(s, i, o):
 
 
 
 
 
 
 
 
7
  super().__init__()
8
- s.a = n.Sequential(C(i, o), n.ReLU(), C(o, o), n.ReLU(), C(o, o))
9
- s.b = n.Conv2d(i, o, 1, bias=False) if i != o else n.Identity()
10
- s.c = n.ReLU()
11
- def forward(s, x): return s.c(s.a(x) + s.b(x))
12
- def E(c=4):
13
- return n.Sequential(
14
- C(3, 64), A(64, 64),
15
- C(64, 64, stride=2, bias=False), A(64, 64), A(64, 64), A(64, 64),
16
- C(64, 64, stride=2, bias=False), A(64, 64), A(64, 64), A(64, 64),
17
- C(64, 64, stride=2, bias=False), A(64, 64), A(64, 64), A(64, 64),
18
- C(64, c))
19
- def D(c=4):
20
- return n.Sequential(
21
- Z(), C(c, 64), n.ReLU(),
22
- A(64, 64), n.Upsample(scale_factor=2), C(64, 64, bias=False), n.ReLU(),
23
- A(64, 64), n.Upsample(scale_factor=2), C(64, 64, bias=False), n.ReLU(),
24
- A(64, 64), n.Upsample(scale_factor=2), C(64, 64, bias=False), n.ReLU(),
25
- A(64, 64), C(64, 3))
26
- class F(n.Module):
27
- M, N = 3, 0.5
28
- def __init__(s, p1="a.pth", p2="b.pth", c=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  super().__init__()
30
- c = c or s.H(str(p1))
31
- s.a, s.b = E(c), D(c)
32
- if p1: s.L(s.a, p1, 'encoder')
33
- if p2: s.L(s.b, p2, 'decoder')
34
- s.a.requires_grad_(False), s.b.requires_grad_(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @staticmethod
36
- def S(x): return x.div(2 * F.M).add(F.N).clamp(0, 1)
 
 
37
  @staticmethod
38
- def U(x): return x.sub(F.N).mul(2 * F.M)
39
- def L(s, m, p, q):
40
- sd = {k.strip(f"{q}."): v for k, v in T.load(p, map_location="cpu", weights_only=True).items() if k.strip(f"{q}.") in m.state_dict() and v.size() == m.state_dict()[k.strip(f"{q}.")].size()}
41
- # print(f" {len(sd)} filtered keys for {q}, total: {len(m.state_dict())}")
42
- m.load_state_dict(sd, strict=False)
43
- def forward(s, x, r=False):
44
- l = s.a(x)
45
- o = s.b(l)
46
- return (o.clamp(0, 1), l) if r else o.clamp(0, 1)
 
1
+ import torch
2
+ import torch as th
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
def conv(n_in, n_out, **kwargs):
    """3x3 'same' convolution (padding=1); extra kwargs are forwarded to Conv2d."""
    return nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, **kwargs)
8
+
9
class Clamp(nn.Module):
    """Soft clamp: tanh(x / 3) * 3 smoothly bounds activations to (-3, 3)."""

    def forward(self, x):
        return x.div(3).tanh().mul(3)
12
+
13
class Block(nn.Module):
    """Residual block: three 3x3 convs (ReLU between), a skip path (1x1
    projection only when channel counts differ), fused by a final ReLU.

    Attribute names `conv` / `skip` / `fuse` are part of the state_dict
    layout and must not be renamed.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        )
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.fuse = nn.ReLU()

    def forward(self, x):
        residual = self.skip(x)
        return self.fuse(self.conv(x) + residual)
21
+
22
def Encoder(latent_channels=4):
    """Build the convolutional encoder.

    Structure: a 3->64 stem, then three identical stride-2 downsampling
    stages (each: strided conv + three residual blocks), then a 3x3
    projection to `latent_channels`.
    """
    stem = [conv(3, 64), Block(64, 64)]
    stages = []
    for _ in range(3):  # three 2x downsampling stages
        stages += [
            conv(64, 64, stride=2, bias=False),
            Block(64, 64), Block(64, 64), Block(64, 64),
        ]
    return nn.Sequential(*stem, *stages, conv(64, latent_channels))
30
+
31
def Decoder(latent_channels=16):
    """Build the convolutional decoder (slimmed 48-channel trunk).

    Structure: soft-clamp on latents, latent->48 stem, then residual blocks
    interleaved with three 2x upsampling stages, then a 48->3 projection.
    NOTE(review): the default here (16) differs from Encoder's default (4);
    Model always passes an explicit value, so defaults only matter for
    standalone use.
    """
    layers = [Clamp(), conv(latent_channels, 48), nn.ReLU()]
    layers += [Block(48, 48), Block(48, 48)]
    layers += [nn.Upsample(scale_factor=2), conv(48, 48, bias=False)]
    layers += [Block(48, 48), Block(48, 48)]
    layers += [nn.Upsample(scale_factor=2), conv(48, 48, bias=False)]
    layers += [Block(48, 48)]
    layers += [nn.Upsample(scale_factor=2), conv(48, 48, bias=False)]
    layers += [Block(48, 48), conv(48, 3)]  # back to RGB
    return nn.Sequential(*layers)
45
+
46
+
47
+
48
class Model(nn.Module):
    """Frozen encoder/decoder pair loaded from pretrained checkpoints.

    Checkpoint weights are filtered by name and shape before loading, so a
    partially matching checkpoint loads whatever fits (strict=False).
    Both halves are frozen after loading (inference-only).
    """

    # Latents in [-latent_magnitude, latent_magnitude] map to [0, 1]
    # via scale_latents / unscale_latents.
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="encoder.pth", decoder_path="decoder.pth", latent_channels=None):
        """Build the encoder/decoder and optionally load checkpoints.

        encoder_path / decoder_path: checkpoint files, or None to skip loading.
        latent_channels: latent width; if None, guessed from the encoder filename.
        """
        super().__init__()
        if latent_channels is None:
            latent_channels = self.guess_latent_channels(str(encoder_path))
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

        if encoder_path is not None:
            self._load_filtered(self.encoder, encoder_path, "encoder")
        if decoder_path is not None:
            self._load_filtered(self.decoder, decoder_path, "decoder")

        # Inference-only: freeze both halves.
        self.encoder.requires_grad_(False)
        self.decoder.requires_grad_(False)

    @staticmethod
    def _load_filtered(module, path, prefix):
        """Load `path` into `module`, keeping only keys that (after removing
        the leading "<prefix>." qualifier) exist in the module with a matching
        tensor shape.

        BUGFIX: the original used str.strip(f"{prefix}."), which strips any
        run of the *characters* in that string from both ends of the key —
        not the literal prefix — so keys beginning or ending with those
        characters were silently mangled and dropped. str.removeprefix
        removes exactly the leading "prefix." string.
        """
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        target = module.state_dict()
        filtered = {}
        for key, value in state_dict.items():
            name = key.removeprefix(f"{prefix}.")
            if name in target and value.size() == target[name].size():
                filtered[name] = value
        # BUGFIX: the original encoder branch's message said "in decoder".
        print(f" num of keys in filtered: {len(filtered)} and in {prefix}: {len(target)}")
        module.load_state_dict(filtered, strict=False)

    def guess_latent_channels(self, encoder_path):
        """Infer the latent channel count from the checkpoint filename.

        "taef1"/"taesd3" checkpoints use 16 latent channels; otherwise 4.
        """
        if "taef1" in encoder_path or "taesd3" in encoder_path:
            return 16
        return 4

    @staticmethod
    def scale_latents(x):
        """Map raw latents [-latent_magnitude, latent_magnitude] -> [0, 1]."""
        return x.div(2 * Model.latent_magnitude).add(Model.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """Map [0, 1] latents back to [-latent_magnitude, latent_magnitude]."""
        return x.sub(Model.latent_shift).mul(2 * Model.latent_magnitude)

    def forward(self, x, return_latent=False):
        """Encode then decode `x`; output clamped to [0, 1].

        If return_latent, returns (output, latent) instead of just output.
        """
        latent = self.encoder(x)
        out = self.decoder(latent)
        if return_latent:
            return out.clamp(0, 1), latent
        return out.clamp(0, 1)