manbeast3b committed on
Commit
447b1a5
·
verified ·
1 Parent(s): af5f8fc

Update src/model.py

Browse files
Files changed (1) hide show
  1. src/model.py +16 -44
src/model.py CHANGED
@@ -1,62 +1,35 @@
1
  import torch as t, torch.nn as nn, torch.nn.functional as F
2
- def cv(n_i, n_o, **kw): return nn.Conv2d(n_i, n_o, 3, padding=1, **kw)
3
- class C(nn.Module):
4
- def forward(self, x): return t.tanh(x / 3) * 3
5
- class B(nn.Module):
6
- def __init__(s, n_i, n_o):
7
- super().__init__()
8
- s.c = nn.Sequential(cv(n_i, n_o), nn.ReLU(), cv(n_o, n_o), nn.ReLU(), cv(n_o, n_o))
9
- s.s = nn.Conv2d(n_i, n_o, 1, bias=False) if n_i != n_o else nn.Identity()
10
- s.f = nn.ReLU()
11
- def forward(s, x): return s.f(s.c(x) + s.s(x))
12
-
13
- import torch
14
- import torch as th
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
-
18
- def conv(n_in, n_out, **kwargs):
19
  return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
20
-
21
  class Clamp(nn.Module):
22
  def forward(self, x):
23
  return torch.tanh(x / 3) * 3
24
-
25
- class Block(nn.Module):
26
  def __init__(self, n_in, n_out):
27
  super().__init__()
28
- self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
29
  self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
30
  self.fuse = nn.ReLU()
31
  def forward(self, x):
32
- return self.fuse(self.conv(x) + self.skip(x))
33
-
34
  def E(latent_channels=4):
35
  return nn.Sequential(
36
- conv(3, 64), Block(64, 64),
37
- conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
38
- conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
39
- conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
40
- conv(64, latent_channels),
41
  )
42
 
43
- def D(latent_channels=16): # Adjusted to match expected input channels
44
  return nn.Sequential(
45
  Clamp(),
46
- conv(latent_channels, 48), # Reduced from 64 to 48 channels
47
- nn.ReLU(),
48
- Block(48, 48), Block(48, 48), # Reduced number of blocks
49
- nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
50
- Block(48, 48), Block(48, 48), # Reduced number of blocks
51
- nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
52
- Block(48, 48), # Further reduction in blocks
53
- nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
54
- Block(48, 48),
55
- conv(48, 3), # Final convolution to output channels
56
- )
57
-
58
-
59
-
60
  class M(nn.Module):
61
  lm, ls = 3, 0.5
62
  def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
@@ -65,7 +38,6 @@ class M(nn.Module):
65
  s.e, s.d = E(lc), D(lc)
66
  def f(sd, mod, pfx):
67
  f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
68
- print(f"num keys: {len(f_sd)} of {len(mod.state_dict())}")
69
  mod.load_state_dict(f_sd, strict=False)
70
  if ep: f(t.load(ep, map_location="cpu", weights_only=True), s.e, "encoder.")
71
  if dp: f(t.load(dp, map_location="cpu", weights_only=True), s.d, "decoder.")
 
1
  import torch as t, torch.nn as nn, torch.nn.functional as F
2
def C(n_in, n_out, **kwargs):
    """3x3 convolution that preserves spatial size at stride 1.

    Any extra keyword arguments (e.g. ``stride``, ``bias``) are forwarded
    verbatim to ``nn.Conv2d``.
    """
    return nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, **kwargs)
 
4
class Clamp(nn.Module):
    """Soft clamp of activations into (-3, 3) via a scaled tanh: y = 3 * tanh(x / 3)."""

    def forward(self, x):
        # Bug fix: the module imports torch only as `t` (`import torch as t`);
        # the bare name `torch` is unbound here, so `torch.tanh` raised NameError.
        return t.tanh(x / 3) * 3
7
class B(nn.Module):
    """Residual block: three 3x3 convs with interleaved ReLUs, fused with a skip path.

    Attribute names (``C``, ``skip``, ``fuse``) are load-bearing: they are the
    state_dict keys that checkpoint loading matches against.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.C = nn.Sequential(
            C(n_in, n_out), nn.ReLU(),
            C(n_out, n_out), nn.ReLU(),
            C(n_out, n_out),
        )
        # 1x1 projection only when the channel count changes; identity otherwise.
        if n_in != n_out:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        else:
            self.skip = nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        shortcut = self.skip(x)
        return self.fuse(self.C(x) + shortcut)
 
15
def E(latent_channels=4):
    """Encoder: stem conv + three stride-2 stages (8x spatial downsample), RGB -> latent.

    Each downsampling stage is a bias-free stride-2 conv followed by three
    residual blocks at 64 channels; a final 3x3 conv maps to `latent_channels`.
    """
    layers = [C(3, 64), B(64, 64)]
    for _ in range(3):
        layers.append(C(64, 64, stride=2, bias=False))
        layers.extend(B(64, 64) for _ in range(3))
    layers.append(C(64, latent_channels))
    return nn.Sequential(*layers)
23
 
24
def D(latent_channels=16):
    """Decoder: soft-clamp, then three 2x upsampling stages (8x total), latent -> RGB.

    Stage layout mirrors the original: 2 residual blocks after the stem and the
    first upsample, then 1 block after each of the last two upsamples.
    """
    stages = [Clamp(), C(latent_channels, 48), nn.ReLU(), B(48, 48), B(48, 48)]
    for n_blocks in (2, 1, 1):
        stages.append(nn.Upsample(scale_factor=2))
        stages.append(C(48, 48, bias=False))
        stages.extend(B(48, 48) for _ in range(n_blocks))
    stages.append(C(48, 3))
    return nn.Sequential(*stages)
 
 
 
 
 
 
 
 
33
  class M(nn.Module):
34
  lm, ls = 3, 0.5
35
  def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
 
38
  s.e, s.d = E(lc), D(lc)
39
  def f(sd, mod, pfx):
40
  f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
 
41
  mod.load_state_dict(f_sd, strict=False)
42
  if ep: f(t.load(ep, map_location="cpu", weights_only=True), s.e, "encoder.")
43
  if dp: f(t.load(dp, map_location="cpu", weights_only=True), s.d, "decoder.")