Code changes
Browse files- inference_brain2vec.py +1 -1
inference_brain2vec.py
CHANGED
|
@@ -119,7 +119,7 @@ class Brain2vec(AutoencoderKL):
|
|
| 119 |
if checkpoint_path is not None:
|
| 120 |
if not os.path.exists(checkpoint_path):
|
| 121 |
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
| 122 |
-
state_dict = torch.load(checkpoint_path, map_location=device)
|
| 123 |
model.load_state_dict(state_dict)
|
| 124 |
|
| 125 |
model.to(device)
|
|
|
|
| 119 |
if checkpoint_path is not None:
|
| 120 |
if not os.path.exists(checkpoint_path):
|
| 121 |
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
| 122 |
+
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 123 |
model.load_state_dict(state_dict)
|
| 124 |
|
| 125 |
model.to(device)
|