Spaces:
Sleeping
Sleeping
Fix #9 app.py
Browse files
app.py
CHANGED
|
@@ -107,7 +107,7 @@ class MappingNetwork(nn.Module):
|
|
| 107 |
return s
|
| 108 |
|
| 109 |
class StyleEncoder(nn.Module):
|
| 110 |
-
def __init__(self, img_size=256, style_dim=64, num_domains=
|
| 111 |
super().__init__()
|
| 112 |
dim_in = 64
|
| 113 |
blocks = []
|
|
@@ -117,20 +117,23 @@ class StyleEncoder(nn.Module):
|
|
| 117 |
dim_out = min(dim_in*2, max_conv_dim)
|
| 118 |
blocks += [ResBlk(dim_in, dim_out, downsample=True)]
|
| 119 |
dim_in = dim_out
|
|
|
|
| 120 |
self.shared = nn.Sequential(*blocks)
|
|
|
|
| 121 |
self.unshared = nn.ModuleList()
|
| 122 |
for _ in range(num_domains):
|
| 123 |
-
self.unshared += [nn.Linear(dim_in
|
| 124 |
|
| 125 |
def forward(self, x, y):
|
| 126 |
h = self.shared(x)
|
|
|
|
| 127 |
h = h.view(h.size(0), -1)
|
| 128 |
out = []
|
| 129 |
for layer in self.unshared:
|
| 130 |
out += [layer(h)]
|
| 131 |
-
out = torch.stack(out, dim=1)
|
| 132 |
-
idx = torch.
|
| 133 |
-
s =
|
| 134 |
return s
|
| 135 |
|
| 136 |
# DEFINICIÓN DEL GENERADOR
|
|
|
|
| 107 |
return s
|
| 108 |
|
| 109 |
class StyleEncoder(nn.Module):
|
| 110 |
+
def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
|
| 111 |
super().__init__()
|
| 112 |
dim_in = 64
|
| 113 |
blocks = []
|
|
|
|
| 117 |
dim_out = min(dim_in*2, max_conv_dim)
|
| 118 |
blocks += [ResBlk(dim_in, dim_out, downsample=True)]
|
| 119 |
dim_in = dim_out
|
| 120 |
+
blocks += [nn.LeakyReLU(0.2)]
|
| 121 |
self.shared = nn.Sequential(*blocks)
|
| 122 |
+
|
| 123 |
self.unshared = nn.ModuleList()
|
| 124 |
for _ in range(num_domains):
|
| 125 |
+
self.unshared += [nn.Linear(dim_in, style_dim)]
|
| 126 |
|
| 127 |
def forward(self, x, y):
|
| 128 |
h = self.shared(x)
|
| 129 |
+
h = F.adaptive_avg_pool2d(h, (1,1))
|
| 130 |
h = h.view(h.size(0), -1)
|
| 131 |
out = []
|
| 132 |
for layer in self.unshared:
|
| 133 |
out += [layer(h)]
|
| 134 |
+
out = torch.stack(out, dim=1)
|
| 135 |
+
idx = torch.arange(y.size(0)).to(y.device)
|
| 136 |
+
s = out[idx, y]
|
| 137 |
return s
|
| 138 |
|
| 139 |
# DEFINICIÓN DEL GENERADOR
|