Spaces:
Paused
Paused
Fix device GPU
Browse files- models/base_model.py +3 -2
models/base_model.py
CHANGED
|
@@ -32,7 +32,8 @@ class BaseModel(ABC):
|
|
| 32 |
self.opt = opt
|
| 33 |
self.gpu_ids = opt.gpu_ids
|
| 34 |
self.isTrain = opt.isTrain
|
| 35 |
-
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
|
|
|
|
| 36 |
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
| 37 |
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
| 38 |
torch.backends.cudnn.benchmark = True
|
|
@@ -214,7 +215,7 @@ class BaseModel(ABC):
|
|
| 214 |
print('loading the model from %s' % load_path)
|
| 215 |
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 216 |
# GitHub source), you can remove str() on self.device
|
| 217 |
-
state_dict = torch.load(load_path, map_location=str(self.device))
|
| 218 |
if hasattr(state_dict, '_metadata'):
|
| 219 |
del state_dict._metadata
|
| 220 |
|
|
|
|
| 32 |
self.opt = opt
|
| 33 |
self.gpu_ids = opt.gpu_ids
|
| 34 |
self.isTrain = opt.isTrain
|
| 35 |
+
self.device = 'cuda:{}'.format(self.gpu_ids[0]) if self.gpu_ids else 'cpu'
|
| 36 |
+
# self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
| 37 |
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
| 38 |
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
| 39 |
torch.backends.cudnn.benchmark = True
|
|
|
|
| 215 |
print('loading the model from %s' % load_path)
|
| 216 |
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 217 |
# GitHub source), you can remove str() on self.device
|
| 218 |
+
state_dict = torch.load(load_path, map_location=self.device, weights_only=True)
|
| 219 |
if hasattr(state_dict, '_metadata'):
|
| 220 |
del state_dict._metadata
|
| 221 |
|