Spaces:
Runtime error
Runtime error
Update BidirectionalTranslation/models/networks.py
Browse files
BidirectionalTranslation/models/networks.py
CHANGED
|
@@ -57,8 +57,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
|
|
| 57 |
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
| 58 |
Return an initialized network.
|
| 59 |
"""
|
| 60 |
-
if len(gpu_ids) > 0:
|
| 61 |
-
assert(torch.cuda.is_available())
|
| 62 |
net.to(gpu_ids[0])
|
| 63 |
if init:
|
| 64 |
init_weights(net, init_type, init_gain=init_gain)
|
|
@@ -97,7 +96,7 @@ class LayerNormWarpper(nn.Module):
|
|
| 97 |
self.num_features = int(num_features)
|
| 98 |
|
| 99 |
def forward(self, x):
|
| 100 |
-
x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).cuda()(x)
|
| 101 |
return x
|
| 102 |
|
| 103 |
def get_norm_layer(norm_type='instance'):
|
|
@@ -904,7 +903,7 @@ class G_Unet_add_all(nn.Module):
|
|
| 904 |
for layer_idx in range(num_layers):
|
| 905 |
res = layer_idx // 2 + 2
|
| 906 |
shape = [1, 1, 2 ** res, 2 ** res]
|
| 907 |
-
self.noise_inputs.append(torch.randn(*shape).to("cuda"))
|
| 908 |
|
| 909 |
# construct unet structure
|
| 910 |
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
|
|
@@ -1338,7 +1337,7 @@ class ScreenVAE(nn.Module):
|
|
| 1338 |
net = net.module
|
| 1339 |
print('loading the model from %s' % load_path)
|
| 1340 |
state_dict = torch.load(
|
| 1341 |
-
load_path, map_location=str(self.device))
|
| 1342 |
if hasattr(state_dict, '_metadata'):
|
| 1343 |
del state_dict._metadata
|
| 1344 |
|
|
@@ -1372,4 +1371,4 @@ class ScreenVAE(nn.Module):
|
|
| 1372 |
x = self.npad(x)
|
| 1373 |
recons = self.dec(x)[:,:,:h,:w]
|
| 1374 |
recons = (recons+1)*(line+1)/2-1
|
| 1375 |
-
return torch.clamp(recons,-1,1)
|
|
|
|
| 57 |
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
| 58 |
Return an initialized network.
|
| 59 |
"""
|
| 60 |
+
if len(gpu_ids) > 0 and torch.cuda.is_available():
|
|
|
|
| 61 |
net.to(gpu_ids[0])
|
| 62 |
if init:
|
| 63 |
init_weights(net, init_type, init_gain=init_gain)
|
|
|
|
| 96 |
self.num_features = int(num_features)
|
| 97 |
|
| 98 |
def forward(self, x):
|
| 99 |
+
x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).to(x.device)(x)
|
| 100 |
return x
|
| 101 |
|
| 102 |
def get_norm_layer(norm_type='instance'):
|
|
|
|
| 903 |
for layer_idx in range(num_layers):
|
| 904 |
res = layer_idx // 2 + 2
|
| 905 |
shape = [1, 1, 2 ** res, 2 ** res]
|
| 906 |
+
self.noise_inputs.append(torch.randn(*shape).to("cuda" if torch.cuda.is_available() else "cpu"))
|
| 907 |
|
| 908 |
# construct unet structure
|
| 909 |
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
|
|
|
|
| 1337 |
net = net.module
|
| 1338 |
print('loading the model from %s' % load_path)
|
| 1339 |
state_dict = torch.load(
|
| 1340 |
+
load_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
| 1341 |
if hasattr(state_dict, '_metadata'):
|
| 1342 |
del state_dict._metadata
|
| 1343 |
|
|
|
|
| 1371 |
x = self.npad(x)
|
| 1372 |
recons = self.dec(x)[:,:,:h,:w]
|
| 1373 |
recons = (recons+1)*(line+1)/2-1
|
| 1374 |
+
return torch.clamp(recons,-1,1)
|