Yarflam commited on
Commit
a2ed7af
·
1 Parent(s): f084b99

Fix device GPU

Browse files
Files changed (1) hide show
  1. models/base_model.py +3 -2
models/base_model.py CHANGED
@@ -32,7 +32,8 @@ class BaseModel(ABC):
32
  self.opt = opt
33
  self.gpu_ids = opt.gpu_ids
34
  self.isTrain = opt.isTrain
35
- self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
 
36
  self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
  if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
  torch.backends.cudnn.benchmark = True
@@ -214,7 +215,7 @@ class BaseModel(ABC):
214
  print('loading the model from %s' % load_path)
215
  # if you are using PyTorch newer than 0.4 (e.g., built from
216
  # GitHub source), you can remove str() on self.device
217
- state_dict = torch.load(load_path, map_location=str(self.device), weights_only=True)
218
  if hasattr(state_dict, '_metadata'):
219
  del state_dict._metadata
220
 
 
32
  self.opt = opt
33
  self.gpu_ids = opt.gpu_ids
34
  self.isTrain = opt.isTrain
35
+ self.device = 'cuda:{}'.format(self.gpu_ids[0]) if self.gpu_ids else 'cpu'
36
+ # self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
37
  self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
38
  if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
39
  torch.backends.cudnn.benchmark = True
 
215
  print('loading the model from %s' % load_path)
216
  # if you are using PyTorch newer than 0.4 (e.g., built from
217
  # GitHub source), you can remove str() on self.device
218
+ state_dict = torch.load(load_path, map_location=self.device, weights_only=True)
219
  if hasattr(state_dict, '_metadata'):
220
  del state_dict._metadata
221