Code fixes
Browse files- brlp_lite.py +8 -5
brlp_lite.py
CHANGED
|
@@ -56,6 +56,9 @@ from monai.transforms.transform import Transform
|
|
| 56 |
from monai import transforms
|
| 57 |
from monai.utils import set_determinism
|
| 58 |
from monai.data.meta_tensor import MetaTensor
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
from tqdm import tqdm
|
| 61 |
import matplotlib.pyplot as plt
|
|
@@ -93,9 +96,9 @@ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
|
|
| 93 |
if checkpoints_path is not None:
|
| 94 |
assert os.path.exists(checkpoints_path), 'Invalid path'
|
| 95 |
# Using context manager to allow MetaTensor
|
| 96 |
-
with torch.serialization.safe_globals([MetaTensor]):
|
| 97 |
-
|
| 98 |
-
|
| 99 |
return network
|
| 100 |
|
| 101 |
|
|
@@ -490,8 +493,8 @@ def train(
|
|
| 490 |
|
| 491 |
# Save the model after each epoch.
|
| 492 |
os.makedirs(output_dir, exist_ok=True)
|
| 493 |
-
torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch
|
| 494 |
-
torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch
|
| 495 |
|
| 496 |
writer.close()
|
| 497 |
print("Training completed and models saved.")
|
|
|
|
| 56 |
from monai import transforms
|
| 57 |
from monai.utils import set_determinism
|
| 58 |
from monai.data.meta_tensor import MetaTensor
|
| 59 |
+
import torch.serialization
|
| 60 |
+
|
| 61 |
+
torch.serialization.add_safe_globals([MetaTensor])
|
| 62 |
|
| 63 |
from tqdm import tqdm
|
| 64 |
import matplotlib.pyplot as plt
|
|
|
|
| 96 |
if checkpoints_path is not None:
|
| 97 |
assert os.path.exists(checkpoints_path), 'Invalid path'
|
| 98 |
# Using context manager to allow MetaTensor
|
| 99 |
+
#with torch.serialization.safe_globals([MetaTensor]):
|
| 100 |
+
network.load_state_dict(torch.load(checkpoints_path))
|
| 101 |
+
#network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
|
| 102 |
return network
|
| 103 |
|
| 104 |
|
|
|
|
| 493 |
|
| 494 |
# Save the model after each epoch.
|
| 495 |
os.makedirs(output_dir, exist_ok=True)
|
| 496 |
+
torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
|
| 497 |
+
torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))
|
| 498 |
|
| 499 |
writer.close()
|
| 500 |
print("Training completed and models saved.")
|