Spaces:
Runtime error
Runtime error
Update BidirectionalTranslation/models/base_model.py
Browse files
BidirectionalTranslation/models/base_model.py
CHANGED
|
@@ -36,7 +36,7 @@ class BaseModel(ABC):
|
|
| 36 |
self.iter = 0
|
| 37 |
self.last_iter = 0
|
| 38 |
self.device = torch.device('cuda:{}'.format(
|
| 39 |
-
self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
| 40 |
# save all the checkpoints to save_dir
|
| 41 |
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
| 42 |
try:
|
|
@@ -231,7 +231,7 @@ class BaseModel(ABC):
|
|
| 231 |
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 232 |
# GitHub source), you can remove str() on self.device
|
| 233 |
state_dict = torch.load(
|
| 234 |
-
load_path, map_location=
|
| 235 |
if hasattr(state_dict, '_metadata'):
|
| 236 |
del state_dict._metadata
|
| 237 |
|
|
@@ -274,4 +274,4 @@ class BaseModel(ABC):
|
|
| 274 |
for net in nets:
|
| 275 |
if net is not None:
|
| 276 |
for param in net.parameters():
|
| 277 |
-
param.requires_grad = requires_grad
|
|
|
|
| 36 |
self.iter = 0
|
| 37 |
self.last_iter = 0
|
| 38 |
self.device = torch.device('cuda:{}'.format(
|
| 39 |
+
self.gpu_ids[0])) if self.gpu_ids and torch.cuda.is_available() else torch.device('cpu') # get device name: CPU or GPU
|
| 40 |
# save all the checkpoints to save_dir
|
| 41 |
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
| 42 |
try:
|
|
|
|
| 231 |
# if you are using PyTorch newer than 0.4 (e.g., built from
|
| 232 |
# GitHub source), you can remove str() on self.device
|
| 233 |
state_dict = torch.load(
|
| 234 |
+
load_path, map_location=self.device)
|
| 235 |
if hasattr(state_dict, '_metadata'):
|
| 236 |
del state_dict._metadata
|
| 237 |
|
|
|
|
| 274 |
for net in nets:
|
| 275 |
if net is not None:
|
| 276 |
for param in net.parameters():
|
| 277 |
+
param.requires_grad = requires_grad
|