manbeast3b committed on
Commit
af5f8fc
·
verified ·
1 Parent(s): eaae6fc

Update src/model.py

Browse files
Files changed (1) hide show
  1. src/model.py +44 -10
src/model.py CHANGED
@@ -9,20 +9,54 @@ class B(nn.Module):
9
  s.s = nn.Conv2d(n_i, n_o, 1, bias=False) if n_i != n_o else nn.Identity()
10
  s.f = nn.ReLU()
11
  def forward(s, x): return s.f(s.c(x) + s.s(x))
12
- def E(lc=4):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  return nn.Sequential(
14
- cv(3, 64), B(64, 64), cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
15
- cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
16
- cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
17
- cv(64, lc),
 
18
  )
19
- def D(lc=16):
 
20
  return nn.Sequential(
21
- C(), cv(lc, 48), nn.ReLU(), B(48, 48), B(48, 48), nn.Upsample(scale_factor=2),
22
- cv(48, 48, bias=False), B(48, 48), B(48, 48), nn.Upsample(scale_factor=2),
23
- cv(48, 48, bias=False), B(48, 48), nn.Upsample(scale_factor=2),
24
- cv(48, 48, bias=False), B(48, 48), cv(48, 3),
 
 
 
 
 
 
 
25
  )
 
 
 
26
  class M(nn.Module):
27
  lm, ls = 3, 0.5
28
  def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
 
9
  s.s = nn.Conv2d(n_i, n_o, 1, bias=False) if n_i != n_o else nn.Identity()
10
  s.f = nn.ReLU()
11
  def forward(s, x): return s.f(s.c(x) + s.s(x))
12
+
13
+ import torch
14
+ import torch as th
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
def conv(n_in, n_out, **kwargs):
    """Build a 3x3 ``nn.Conv2d`` with a padding of 1 by default.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv2d``
            (for example ``stride`` or ``bias``). ``padding`` may be
            supplied here to override the default of 1.

    Returns:
        The constructed ``nn.Conv2d`` layer.
    """
    # Use setdefault rather than a hard-wired padding=1 argument so that a
    # caller passing padding=... overrides the default instead of triggering
    # a "multiple values for keyword argument" TypeError.
    kwargs.setdefault("padding", 1)
    return nn.Conv2d(n_in, n_out, 3, **kwargs)
20
+
21
class Clamp(nn.Module):
    """Softly bound activations to ``(-limit, limit)`` using a scaled tanh.

    Computes ``tanh(x / limit) * limit``. With the default ``limit`` of 3
    this is the same ``tanh(x / 3) * 3`` mapping as before.

    Args:
        limit: Magnitude of the output bound. Defaults to 3.0.
    """

    # Class-level fallback so an instance created without running __init__
    # still resolves ``self.limit`` to the historical value.
    limit = 3.0

    def __init__(self, limit=3.0):
        super().__init__()
        self.limit = limit

    def forward(self, x):
        """Return ``x`` squashed into ``(-limit, limit)``."""
        return torch.tanh(x / self.limit) * self.limit
24
+
25
class Block(nn.Module):
    """Residual block: three 3x3 convolutions plus a shortcut, then ReLU.

    The main path is ``conv -> ReLU -> conv -> ReLU -> conv``. The shortcut
    is the identity when ``n_in == n_out`` and a bias-free 1x1 convolution
    otherwise. The two paths are summed and passed through a final ReLU.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        )
        # A projection is only needed when the channel count changes.
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.fuse = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x`` and return the fused activation."""
        main_path = self.conv(x)
        shortcut = self.skip(x)
        return self.fuse(main_path + shortcut)
33
+
34
def E(latent_channels=4):
    """Assemble the encoder as an ``nn.Sequential``.

    Layout: a 3 -> 64 channel stem convolution and one ``Block``, followed by
    three stages that each consist of a stride-2, bias-free convolution and
    three ``Block`` modules, and a final convolution that maps 64 channels to
    ``latent_channels``.

    Args:
        latent_channels: Number of channels produced by the last convolution.

    Returns:
        The encoder as an ``nn.Sequential``.
    """
    width = 64
    layers = [conv(3, width), Block(width, width)]
    for _stage in range(3):
        # Each stage starts with a stride-2 convolution.
        layers.append(conv(width, width, stride=2, bias=False))
        for _rep in range(3):
            layers.append(Block(width, width))
    layers.append(conv(width, latent_channels))
    return nn.Sequential(*layers)
42
+
43
def D(latent_channels=16):
    """Assemble the decoder as an ``nn.Sequential``.

    Layout: ``Clamp`` on the input, a ``latent_channels`` -> 48 convolution
    with ReLU and two ``Block`` modules, then three stages of
    ``nn.Upsample(scale_factor=2)`` followed by a bias-free 48 -> 48
    convolution and two, one and one ``Block`` modules respectively, and a
    final 48 -> 3 convolution.

    Args:
        latent_channels: Number of channels of the tensor fed to the decoder.

    Returns:
        The decoder as an ``nn.Sequential``.
    """
    width = 48
    layers = [
        Clamp(),
        conv(latent_channels, width),
        nn.ReLU(),
        Block(width, width),
        Block(width, width),
    ]
    # Number of Blocks that follow the upsample + convolution of each stage.
    for blocks_after in (2, 1, 1):
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(conv(width, width, bias=False))
        for _rep in range(blocks_after):
            layers.append(Block(width, width))
    layers.append(conv(width, 3))
    return nn.Sequential(*layers)
57
+
58
+
59
+
60
  class M(nn.Module):
61
  lm, ls = 3, 0.5
62
  def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):