Update src/model.py
Browse files- src/model.py +2 -2
src/model.py
CHANGED
|
@@ -7,11 +7,11 @@ class Clamp(nn.Module):
|
|
| 7 |
class B(nn.Module):
|
| 8 |
def __init__(self, n_in, n_out):
|
| 9 |
super().__init__()
|
| 10 |
-
self.
|
| 11 |
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
| 12 |
self.fuse = nn.ReLU()
|
| 13 |
def forward(self, x):
|
| 14 |
-
return self.fuse(self.
|
| 15 |
def E(latent_channels=4):
|
| 16 |
return nn.Sequential(
|
| 17 |
C(3, 64), B(64, 64),
|
|
|
|
| 7 |
class B(nn.Module):
|
| 8 |
def __init__(self, n_in, n_out):
|
| 9 |
super().__init__()
|
| 10 |
+
self.conv = nn.Sequential(C(n_in, n_out), nn.ReLU(), C(n_out, n_out), nn.ReLU(), C(n_out, n_out))
|
| 11 |
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
| 12 |
self.fuse = nn.ReLU()
|
| 13 |
def forward(self, x):
|
| 14 |
+
return self.fuse(self.conv(x) + self.skip(x))
|
| 15 |
def E(latent_channels=4):
|
| 16 |
return nn.Sequential(
|
| 17 |
C(3, 64), B(64, 64),
|