facehuggingjay committed on
Commit
89bab58
·
verified ·
1 Parent(s): e856c2f

Update BidirectionalTranslation/models/base_model.py

Browse files
BidirectionalTranslation/models/base_model.py CHANGED
@@ -36,7 +36,7 @@ class BaseModel(ABC):
36
  self.iter = 0
37
  self.last_iter = 0
38
  self.device = torch.device('cuda:{}'.format(
39
- self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
40
  # save all the checkpoints to save_dir
41
  self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
42
  try:
@@ -231,7 +231,7 @@ class BaseModel(ABC):
231
  # if you are using PyTorch newer than 0.4 (e.g., built from
232
  # GitHub source), you can remove str() on self.device
233
  state_dict = torch.load(
234
- load_path, map_location=lambda storage, loc: storage.cuda())
235
  if hasattr(state_dict, '_metadata'):
236
  del state_dict._metadata
237
 
@@ -274,4 +274,4 @@ class BaseModel(ABC):
274
  for net in nets:
275
  if net is not None:
276
  for param in net.parameters():
277
- param.requires_grad = requires_grad
 
36
  self.iter = 0
37
  self.last_iter = 0
38
  self.device = torch.device('cuda:{}'.format(
39
+ self.gpu_ids[0])) if self.gpu_ids and torch.cuda.is_available() else torch.device('cpu') # get device name: CPU or GPU
40
  # save all the checkpoints to save_dir
41
  self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
42
  try:
 
231
  # if you are using PyTorch newer than 0.4 (e.g., built from
232
  # GitHub source), you can remove str() on self.device
233
  state_dict = torch.load(
234
+ load_path, map_location=self.device)
235
  if hasattr(state_dict, '_metadata'):
236
  del state_dict._metadata
237
 
 
274
  for net in nets:
275
  if net is not None:
276
  for param in net.parameters():
277
+ param.requires_grad = requires_grad