facehuggingjay commited on
Commit
85c91be
·
verified ·
1 Parent(s): 89bab58

Update BidirectionalTranslation/models/networks.py

Browse files
BidirectionalTranslation/models/networks.py CHANGED
@@ -57,8 +57,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
57
  gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
58
  Return an initialized network.
59
  """
60
- if len(gpu_ids) > 0:
61
- assert(torch.cuda.is_available())
62
  net.to(gpu_ids[0])
63
  if init:
64
  init_weights(net, init_type, init_gain=init_gain)
@@ -97,7 +96,7 @@ class LayerNormWarpper(nn.Module):
97
  self.num_features = int(num_features)
98
 
99
  def forward(self, x):
100
- x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).cuda()(x)
101
  return x
102
 
103
  def get_norm_layer(norm_type='instance'):
@@ -904,7 +903,7 @@ class G_Unet_add_all(nn.Module):
904
  for layer_idx in range(num_layers):
905
  res = layer_idx // 2 + 2
906
  shape = [1, 1, 2 ** res, 2 ** res]
907
- self.noise_inputs.append(torch.randn(*shape).to("cuda"))
908
 
909
  # construct unet structure
910
  unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
@@ -1338,7 +1337,7 @@ class ScreenVAE(nn.Module):
1338
  net = net.module
1339
  print('loading the model from %s' % load_path)
1340
  state_dict = torch.load(
1341
- load_path, map_location=lambda storage, loc: storage.cuda())
1342
  if hasattr(state_dict, '_metadata'):
1343
  del state_dict._metadata
1344
 
@@ -1372,4 +1371,4 @@ class ScreenVAE(nn.Module):
1372
  x = self.npad(x)
1373
  recons = self.dec(x)[:,:,:h,:w]
1374
  recons = (recons+1)*(line+1)/2-1
1375
- return torch.clamp(recons,-1,1)
 
57
  gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
58
  Return an initialized network.
59
  """
60
+ if len(gpu_ids) > 0 and torch.cuda.is_available():
 
61
  net.to(gpu_ids[0])
62
  if init:
63
  init_weights(net, init_type, init_gain=init_gain)
 
96
  self.num_features = int(num_features)
97
 
98
  def forward(self, x):
99
+ x = nn.LayerNorm([self.num_features, x.size()[2], x.size()[3]], elementwise_affine=False).to(x.device)(x)
100
  return x
101
 
102
  def get_norm_layer(norm_type='instance'):
 
903
  for layer_idx in range(num_layers):
904
  res = layer_idx // 2 + 2
905
  shape = [1, 1, 2 ** res, 2 ** res]
906
+ self.noise_inputs.append(torch.randn(*shape).to("cuda" if torch.cuda.is_available() else "cpu"))
907
 
908
  # construct unet structure
909
  unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
 
1337
  net = net.module
1338
  print('loading the model from %s' % load_path)
1339
  state_dict = torch.load(
1340
+ load_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
1341
  if hasattr(state_dict, '_metadata'):
1342
  del state_dict._metadata
1343
 
 
1371
  x = self.npad(x)
1372
  recons = self.dec(x)[:,:,:h,:w]
1373
  recons = (recons+1)*(line+1)/2-1
1374
+ return torch.clamp(recons,-1,1)