Code changes
Browse files- README.md +6 -0
- inference_brain2vec.py +1 -1
README.md
CHANGED
|
@@ -44,6 +44,12 @@ python create_csv.py
|
|
| 44 |
mkdir ae_cache
|
| 45 |
mkdir ae_output
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
# train the model
|
| 48 |
nohup python brain2vec.py train \
|
| 49 |
--dataset_csv /home/ubuntu/brain2vec/inputs.csv \
|
|
|
|
| 44 |
mkdir ae_cache
|
| 45 |
mkdir ae_output
|
| 46 |
|
| 47 |
+
# install git lfs to pull large model weights
|
| 48 |
+
sudo apt-get update
|
| 49 |
+
sudo apt install git-lfs
|
| 50 |
+
git lfs install
|
| 51 |
+
git lfs pull
|
| 52 |
+
|
| 53 |
# train the model
|
| 54 |
nohup python brain2vec.py train \
|
| 55 |
--dataset_csv /home/ubuntu/brain2vec/inputs.csv \
|
inference_brain2vec.py
CHANGED
|
@@ -119,7 +119,7 @@ class Brain2vec(AutoencoderKL):
|
|
| 119 |
if checkpoint_path is not None:
|
| 120 |
if not os.path.exists(checkpoint_path):
|
| 121 |
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
| 122 |
-
state_dict = torch.load(checkpoint_path, map_location=device
|
| 123 |
model.load_state_dict(state_dict)
|
| 124 |
|
| 125 |
model.to(device)
|
|
|
|
| 119 |
if checkpoint_path is not None:
|
| 120 |
if not os.path.exists(checkpoint_path):
|
| 121 |
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
| 122 |
+
state_dict = torch.load(checkpoint_path, map_location=device)
|
| 123 |
model.load_state_dict(state_dict)
|
| 124 |
|
| 125 |
model.to(device)
|