Working on setting up automatic weight downloading
Browse files- README.md +39 -0
- poetry.lock +8 -1
- pyment/models/sfcn/sfcn.py +3 -0
- pyment/models/utils/ensure_weights.py +56 -7
- pyment/utils/download_file.py +30 -0
- scripts/predict_from_fastsurfer_folder.py +10 -11
- scripts/utils/upload_weights_to_github.py +68 -0
- tutorials/download_ixi.py +3 -27
README.md
CHANGED
|
@@ -1,5 +1,34 @@
|
|
| 1 |
# Installation
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
## Install pyenv and Python
|
| 5 |
|
|
@@ -100,4 +129,14 @@ Finally, we can run the preprocessing script, pointing towards the python from t
|
|
| 100 |
```
|
| 101 |
sh scripts/preprocess.sh --license <path-to-license> --python ~/venvs/fastsurfer/bin/python ~/data/ixi/images ~/data/ixi/preprocessed
|
| 102 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
</details>
|
|
|
|
| 1 |
# Installation
|
| 2 |
|
| 3 |
+
## Configure system
|
| 4 |
+
<details>
|
| 5 |
+
<summary>Ubuntu</summary>
|
| 6 |
+
|
| 7 |
+
First we need to download and install CUDA 11.2:
|
| 8 |
+
```
|
| 9 |
+
wget https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
|
| 10 |
+
sudo sh cuda_11.2.2_460.32.03_linux.run --silent --toolkit --installpath=/usr/local/cuda-11.2
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
Next, cudnn must be installed. Download a suitable deb-file from
|
| 14 |
+
https://developer.nvidia.com/rdp/cudnn-archive. Then install the file:
|
| 15 |
+
```
|
| 16 |
+
sudo dpkg -i ~/Downloads/cudnn-local-repo-ubuntu2204-8.9.7.29_1.0-1_amd64.deb
|
| 17 |
+
sudo cp /var/cudnn-local-repo-ubuntu2204-8.9.7.29/cudnn-local-*-keyring.gpg /usr/share/keyrings/
|
| 18 |
+
sudo apt update
|
| 19 |
+
sudo apt install libcudnn8 libcudnn8-dev
|
| 20 |
+
sudo cp /usr/include/cudnn*.h /usr/local/cuda-11.2/include/
|
| 21 |
+
sudo cp -P /usr/lib/x86_64-linux-gnu/libcudnn*.so* /usr/local/cuda-11.2/lib64/
|
| 22 |
+
sudo ldconfig
|
| 23 |
+
```
|
| 24 |
+
Finally, we must configure the system paths (or add these lines to ~/.bashrc:
|
| 25 |
+
```
|
| 26 |
+
echo 'export CUDA_HOME=/usr/local/cuda-11.2' >> ~/.bashrc
|
| 27 |
+
echo 'export PATH=$CUDA_HOME/bin:$PATH' >> ~/.bashrc
|
| 28 |
+
echo 'export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/extras/CUPTI/lib64' >> ~/.bashrc
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
</details>
|
| 32 |
|
| 33 |
## Install pyenv and Python
|
| 34 |
|
|
|
|
| 129 |
```
|
| 130 |
sh scripts/preprocess.sh --license <path-to-license> --python ~/venvs/fastsurfer/bin/python ~/data/ixi/images ~/data/ixi/preprocessed
|
| 131 |
```
|
| 132 |
+
|
| 133 |
+
### Generate predictions
|
| 134 |
+
After preprocessing, we can generate predictions for the IXI dataset using the scripts in the repository. First, ensure the virtual environment is loaded:
|
| 135 |
+
```
|
| 136 |
+
eval $(poetry env activate)
|
| 137 |
+
```
|
| 138 |
+
Next, run the prediction-script:
|
| 139 |
+
```
|
| 140 |
+
python scripts/predict_from_fastsurfer_folder.py
|
| 141 |
+
```
|
| 142 |
</details>
|
poetry.lock
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# This file is automatically @generated by Poetry 2.
|
| 2 |
|
| 3 |
[[package]]
|
| 4 |
name = "absl-py"
|
|
@@ -2807,6 +2807,13 @@ optional = false
|
|
| 2807 |
python-versions = ">=3.8"
|
| 2808 |
groups = ["main"]
|
| 2809 |
files = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2810 |
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
| 2811 |
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
| 2812 |
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
|
|
|
| 1 |
+
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
| 2 |
|
| 3 |
[[package]]
|
| 4 |
name = "absl-py"
|
|
|
|
| 2807 |
python-versions = ">=3.8"
|
| 2808 |
groups = ["main"]
|
| 2809 |
files = [
|
| 2810 |
+
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
|
| 2811 |
+
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
|
| 2812 |
+
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
|
| 2813 |
+
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
|
| 2814 |
+
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
|
| 2815 |
+
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
|
| 2816 |
+
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
|
| 2817 |
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
| 2818 |
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
| 2819 |
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
pyment/models/sfcn/sfcn.py
CHANGED
|
@@ -79,7 +79,10 @@ class SFCN(Model):
|
|
| 79 |
super().__init__(self.inputs, x)
|
| 80 |
|
| 81 |
if weights:
|
|
|
|
| 82 |
weights = ensure_weights(weights)
|
|
|
|
|
|
|
| 83 |
status = self.load_weights(weights)
|
| 84 |
|
| 85 |
print(weights)
|
|
|
|
| 79 |
super().__init__(self.inputs, x)
|
| 80 |
|
| 81 |
if weights:
|
| 82 |
+
print(weights)
|
| 83 |
weights = ensure_weights(weights)
|
| 84 |
+
print(weights)
|
| 85 |
+
weights = 'checkpoints/pyment/sfcn-multi'
|
| 86 |
status = self.load_weights(weights)
|
| 87 |
|
| 88 |
print(weights)
|
pyment/models/utils/ensure_weights.py
CHANGED
|
@@ -1,15 +1,60 @@
|
|
| 1 |
import os
|
| 2 |
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
configuration as an argument, and returns a path-prefix to files
|
| 7 |
containing the weights. If necessary, the weights are downloaded.
|
| 8 |
|
| 9 |
Parameters
|
| 10 |
----------
|
| 11 |
identifier : str
|
| 12 |
-
Points to either a filename or a valid keyword identifiying a
|
| 13 |
weight file.
|
| 14 |
|
| 15 |
Returns
|
|
@@ -21,20 +66,24 @@ def ensure_weights(identifier: str) -> str:
|
|
| 21 |
------
|
| 22 |
KeyError
|
| 23 |
If the identifier is not a valid identifier and there does not
|
| 24 |
-
exist either a single file <identifier> or files
|
| 25 |
<identifier>.index and <identifier>.data-00000-of-00001 on the
|
| 26 |
local file system.
|
| 27 |
"""
|
| 28 |
-
if
|
| 29 |
(
|
| 30 |
-
os.path.isfile(f'{identifier}.index') and
|
| 31 |
os.path.isfile(f'{identifier}.data-00000-of-00001')
|
| 32 |
) or (
|
| 33 |
os.path.isfile(identifier)
|
| 34 |
)
|
| 35 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
raise NotImplementedError(
|
| 37 |
f'Identifier-based lookups are not supported'
|
| 38 |
)
|
| 39 |
|
| 40 |
-
return identifier
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
+
from pyment.utils.download_file import download_file
|
| 4 |
|
| 5 |
+
|
| 6 |
+
IDENTIFIERS = {
|
| 7 |
+
'multi-2025': {
|
| 8 |
+
'data': 'f4054d701fa59971fb7000d38cf9e63a202bd66a',
|
| 9 |
+
'index': '9c208ca0bcc3969ceb281ba63a8cee4944a63157'
|
| 10 |
+
}
|
| 11 |
+
}
|
| 12 |
+
BASE_URL = 'https://api.github.com/repos/estenhl/pyment-public/git/blobs'
|
| 13 |
+
|
| 14 |
+
def _lookup_identifier(identifier: str, local_cache: str) -> str:
|
| 15 |
+
if not (
|
| 16 |
+
os.path.isfile(
|
| 17 |
+
os.path.join(local_cache, f'{identifier}.index')
|
| 18 |
+
) and os.path.isfile(
|
| 19 |
+
os.path.join(local_cache, f'{identifier}.data-00000-of-00001')
|
| 20 |
+
)
|
| 21 |
+
):
|
| 22 |
+
if not os.path.isdir(local_cache):
|
| 23 |
+
os.makedirs(local_cache, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
download_file(
|
| 26 |
+
url=BASE_URL + '/' + IDENTIFIERS[identifier]['data'],
|
| 27 |
+
destination=os.path.join(
|
| 28 |
+
local_cache,
|
| 29 |
+
f'{identifier}.data-00000-of-00001'
|
| 30 |
+
),
|
| 31 |
+
description=f'Downloading {identifier} data'
|
| 32 |
+
)
|
| 33 |
+
download_file(
|
| 34 |
+
url=BASE_URL + '/' + IDENTIFIERS[identifier]['index'],
|
| 35 |
+
destination=os.path.join(
|
| 36 |
+
local_cache,
|
| 37 |
+
f'{identifier}.index'
|
| 38 |
+
),
|
| 39 |
+
description=f'Downloading {identifier} index'
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
return os.path.join(local_cache, identifier)
|
| 43 |
+
|
| 44 |
+
def ensure_weights(
|
| 45 |
+
identifier: str,
|
| 46 |
+
local_cache: str = os.path.join(
|
| 47 |
+
os.path.expanduser('~'), '.pyment', 'weights'
|
| 48 |
+
)
|
| 49 |
+
) -> str:
|
| 50 |
+
"""Takes either a path or an identifier for a valid weight
|
| 51 |
configuration as an argument, and returns a path-prefix to files
|
| 52 |
containing the weights. If necessary, the weights are downloaded.
|
| 53 |
|
| 54 |
Parameters
|
| 55 |
----------
|
| 56 |
identifier : str
|
| 57 |
+
Points to either a filename or a valid keyword identifiying a
|
| 58 |
weight file.
|
| 59 |
|
| 60 |
Returns
|
|
|
|
| 66 |
------
|
| 67 |
KeyError
|
| 68 |
If the identifier is not a valid identifier and there does not
|
| 69 |
+
exist either a single file <identifier> or files
|
| 70 |
<identifier>.index and <identifier>.data-00000-of-00001 on the
|
| 71 |
local file system.
|
| 72 |
"""
|
| 73 |
+
if (
|
| 74 |
(
|
| 75 |
+
os.path.isfile(f'{identifier}.index') and
|
| 76 |
os.path.isfile(f'{identifier}.data-00000-of-00001')
|
| 77 |
) or (
|
| 78 |
os.path.isfile(identifier)
|
| 79 |
)
|
| 80 |
):
|
| 81 |
+
return identifier
|
| 82 |
+
elif identifier in IDENTIFIERS:
|
| 83 |
+
return _lookup_identifier(identifier, local_cache)
|
| 84 |
+
else:
|
| 85 |
raise NotImplementedError(
|
| 86 |
f'Identifier-based lookups are not supported'
|
| 87 |
)
|
| 88 |
|
| 89 |
+
return identifier
|
pyment/utils/download_file.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import requests
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def download_file(
|
| 7 |
+
url: str,
|
| 8 |
+
destination: str,
|
| 9 |
+
description: str = None
|
| 10 |
+
) -> str:
|
| 11 |
+
with requests.get(url, stream=True) as response:
|
| 12 |
+
response.raise_for_status()
|
| 13 |
+
total_size = int(response.headers.get('content-length', 0))
|
| 14 |
+
|
| 15 |
+
# 1 MB chunks
|
| 16 |
+
chunk_size = 1<<20
|
| 17 |
+
|
| 18 |
+
progress_bar = tqdm(
|
| 19 |
+
response.iter_content(chunk_size=chunk_size),
|
| 20 |
+
total=int(math.ceil(total_size / chunk_size)),
|
| 21 |
+
unit='mb',
|
| 22 |
+
unit_scale=True,
|
| 23 |
+
unit_divisor=1024,
|
| 24 |
+
desc=description
|
| 25 |
+
)
|
| 26 |
+
progress_bar.format_dict['rate'] = f'mb/s'
|
| 27 |
+
|
| 28 |
+
with open(destination, 'wb') as f:
|
| 29 |
+
for chunk in progress_bar:
|
| 30 |
+
f.write(chunk)
|
scripts/predict_from_fastsurfer_folder.py
CHANGED
|
@@ -28,7 +28,7 @@ def _parse_folder_name(name: str) -> Tuple[str, str, str]:
|
|
| 28 |
return match.groups()
|
| 29 |
|
| 30 |
def predict_from_fastsurfer_folder(
|
| 31 |
-
source: str,
|
| 32 |
weights: str,
|
| 33 |
model_name: str = 'sfcn-multi',
|
| 34 |
targets: List[str] = [
|
|
@@ -38,7 +38,7 @@ def predict_from_fastsurfer_folder(
|
|
| 38 |
) -> pd.DataFrame:
|
| 39 |
if destination is not None and os.path.isfile(destination):
|
| 40 |
raise ValueError(f'Destination {destination} already exists')
|
| 41 |
-
|
| 42 |
logger.info('Loading multi-task model with weights %s', weights)
|
| 43 |
|
| 44 |
model_class = sfcn_factory(model_name)
|
|
@@ -62,7 +62,7 @@ def predict_from_fastsurfer_folder(
|
|
| 62 |
if not os.path.isfile(brainmask):
|
| 63 |
logger.warning('No mask.mgz file for folder %s', folder)
|
| 64 |
continue
|
| 65 |
-
|
| 66 |
brainmask = nib.load(brainmask)
|
| 67 |
brainmask = brainmask.get_fdata()
|
| 68 |
|
|
@@ -72,10 +72,8 @@ def predict_from_fastsurfer_folder(
|
|
| 72 |
image = conform(image)
|
| 73 |
|
| 74 |
predictions = model.predict(np.expand_dims(image, axis=0))[0]
|
| 75 |
-
print(predictions.shape)
|
| 76 |
-
print(predictions)
|
| 77 |
logger.debug('Predictions for %s: %s', folder, str(predictions))
|
| 78 |
-
|
| 79 |
results.append({
|
| 80 |
**{
|
| 81 |
'source': os.path.join(source, folder),
|
|
@@ -100,7 +98,7 @@ if __name__ == '__main__':
|
|
| 100 |
)
|
| 101 |
|
| 102 |
parser.add_argument(
|
| 103 |
-
'root',
|
| 104 |
help=(
|
| 105 |
'Path to FastSurfer folder. Should contain subfolders that have '
|
| 106 |
'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
|
|
@@ -108,15 +106,16 @@ if __name__ == '__main__':
|
|
| 108 |
)
|
| 109 |
parser.add_argument(
|
| 110 |
'-w', '--weights',
|
| 111 |
-
required=
|
|
|
|
| 112 |
help=(
|
| 113 |
'Weights to use. Should either point to a local file path, or a '
|
| 114 |
-
'known
|
| 115 |
'exist files named <path>.index and <path>.data-00000-of-00001'
|
| 116 |
)
|
| 117 |
)
|
| 118 |
parser.add_argument(
|
| 119 |
-
'-m', '--model',
|
| 120 |
required=False,
|
| 121 |
default='sfcn-multi',
|
| 122 |
help=(
|
|
@@ -128,7 +127,7 @@ if __name__ == '__main__':
|
|
| 128 |
required=False,
|
| 129 |
nargs='+',
|
| 130 |
default=[
|
| 131 |
-
'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
|
| 132 |
'neuroticism'
|
| 133 |
],
|
| 134 |
help='Name to use for each of the prediction heads in the output CSV'
|
|
|
|
| 28 |
return match.groups()
|
| 29 |
|
| 30 |
def predict_from_fastsurfer_folder(
|
| 31 |
+
source: str,
|
| 32 |
weights: str,
|
| 33 |
model_name: str = 'sfcn-multi',
|
| 34 |
targets: List[str] = [
|
|
|
|
| 38 |
) -> pd.DataFrame:
|
| 39 |
if destination is not None and os.path.isfile(destination):
|
| 40 |
raise ValueError(f'Destination {destination} already exists')
|
| 41 |
+
|
| 42 |
logger.info('Loading multi-task model with weights %s', weights)
|
| 43 |
|
| 44 |
model_class = sfcn_factory(model_name)
|
|
|
|
| 62 |
if not os.path.isfile(brainmask):
|
| 63 |
logger.warning('No mask.mgz file for folder %s', folder)
|
| 64 |
continue
|
| 65 |
+
|
| 66 |
brainmask = nib.load(brainmask)
|
| 67 |
brainmask = brainmask.get_fdata()
|
| 68 |
|
|
|
|
| 72 |
image = conform(image)
|
| 73 |
|
| 74 |
predictions = model.predict(np.expand_dims(image, axis=0))[0]
|
|
|
|
|
|
|
| 75 |
logger.debug('Predictions for %s: %s', folder, str(predictions))
|
| 76 |
+
|
| 77 |
results.append({
|
| 78 |
**{
|
| 79 |
'source': os.path.join(source, folder),
|
|
|
|
| 98 |
)
|
| 99 |
|
| 100 |
parser.add_argument(
|
| 101 |
+
'root',
|
| 102 |
help=(
|
| 103 |
'Path to FastSurfer folder. Should contain subfolders that have '
|
| 104 |
'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
|
|
|
|
| 106 |
)
|
| 107 |
parser.add_argument(
|
| 108 |
'-w', '--weights',
|
| 109 |
+
required=False,
|
| 110 |
+
default='multi-2025',
|
| 111 |
help=(
|
| 112 |
'Weights to use. Should either point to a local file path, or a '
|
| 113 |
+
'known identifier. If a local file path <path> is used, there should '
|
| 114 |
'exist files named <path>.index and <path>.data-00000-of-00001'
|
| 115 |
)
|
| 116 |
)
|
| 117 |
parser.add_argument(
|
| 118 |
+
'-m', '--model',
|
| 119 |
required=False,
|
| 120 |
default='sfcn-multi',
|
| 121 |
help=(
|
|
|
|
| 127 |
required=False,
|
| 128 |
nargs='+',
|
| 129 |
default=[
|
| 130 |
+
'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
|
| 131 |
'neuroticism'
|
| 132 |
],
|
| 133 |
help='Name to use for each of the prediction heads in the output CSV'
|
scripts/utils/upload_weights_to_github.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import base64
|
| 3 |
+
import os
|
| 4 |
+
import requests
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def upload_weights_to_github(filename: str, token: str, user: str, repo: str):
|
| 8 |
+
with open(filename, 'rb') as f:
|
| 9 |
+
bytes = f.read()
|
| 10 |
+
|
| 11 |
+
bytes = base64.b64encode(bytes).decode()
|
| 12 |
+
|
| 13 |
+
if os.path.isfile(token):
|
| 14 |
+
with open(token, 'r') as f:
|
| 15 |
+
token = f.read().strip()
|
| 16 |
+
|
| 17 |
+
headers = {
|
| 18 |
+
'Accept': 'application/vnd.github+json',
|
| 19 |
+
'Authorization': f'Bearer {token}',
|
| 20 |
+
'X-GitHub-Api-Version': '2022-11-28'
|
| 21 |
+
}
|
| 22 |
+
content = {
|
| 23 |
+
'content': bytes,
|
| 24 |
+
'encoding': 'base64'
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
url = f'https://api.github.com/repos/{user}/{repo}/git/blobs'
|
| 28 |
+
|
| 29 |
+
response = requests.post(url, json=content, headers=headers)
|
| 30 |
+
|
| 31 |
+
return response
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
parser = argparse.ArgumentParser('Uploads a weights-file to github')
|
| 35 |
+
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
'-f', '--filename',
|
| 38 |
+
required=True,
|
| 39 |
+
help='Path to file containing weights'
|
| 40 |
+
)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
'-t', '--token',
|
| 43 |
+
required=True,
|
| 44 |
+
help='Token for the GitHub API'
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
'-u', '--user',
|
| 48 |
+
required=False,
|
| 49 |
+
default='estenhl',
|
| 50 |
+
help='Owner of the github repo'
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
'-r', '--repo',
|
| 54 |
+
required=False,
|
| 55 |
+
default='pyment-public',
|
| 56 |
+
help='Name of the github repo'
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
args = parser.parse_args()
|
| 60 |
+
|
| 61 |
+
response = upload_weights_to_github(
|
| 62 |
+
filename=args.filename,
|
| 63 |
+
token=args.token,
|
| 64 |
+
user=args.user,
|
| 65 |
+
repo=args.repo
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
print(response.text)
|
tutorials/download_ixi.py
CHANGED
|
@@ -5,34 +5,10 @@ import requests
|
|
| 5 |
import tarfile
|
| 6 |
from tqdm import tqdm
|
| 7 |
|
|
|
|
| 8 |
|
| 9 |
-
DEFAULT_DESTINATION = os.path.join(os.path.expanduser('~'), 'data', 'ixi')
|
| 10 |
-
|
| 11 |
-
def download_file(
|
| 12 |
-
url: str,
|
| 13 |
-
destination: str,
|
| 14 |
-
description: str = None
|
| 15 |
-
) -> str:
|
| 16 |
-
with requests.get(url, stream=True) as response:
|
| 17 |
-
response.raise_for_status()
|
| 18 |
-
total_size = int(response.headers.get('content-length', 0))
|
| 19 |
-
|
| 20 |
-
# 1 MB chunks
|
| 21 |
-
chunk_size = 1<<20
|
| 22 |
|
| 23 |
-
|
| 24 |
-
response.iter_content(chunk_size=chunk_size),
|
| 25 |
-
total=int(math.ceil(total_size / chunk_size)),
|
| 26 |
-
unit='mb',
|
| 27 |
-
unit_scale=True,
|
| 28 |
-
unit_divisor=1024,
|
| 29 |
-
desc=description
|
| 30 |
-
)
|
| 31 |
-
progress_bar.format_dict['rate'] = f'mb/s'
|
| 32 |
-
|
| 33 |
-
with open(destination, 'wb') as f:
|
| 34 |
-
for chunk in progress_bar:
|
| 35 |
-
f.write(chunk)
|
| 36 |
|
| 37 |
def download_tar(tar_path: str) -> str:
|
| 38 |
url = (
|
|
@@ -82,4 +58,4 @@ if __name__ == '__main__':
|
|
| 82 |
|
| 83 |
args = parser.parse_args()
|
| 84 |
|
| 85 |
-
download_ixi(args.destination)
|
|
|
|
| 5 |
import tarfile
|
| 6 |
from tqdm import tqdm
|
| 7 |
|
| 8 |
+
from pyment.utils.download_file import download_file
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
DEFAULT_DESTINATION = os.path.join(os.path.expanduser('~'), 'data', 'ixi')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def download_tar(tar_path: str) -> str:
|
| 14 |
url = (
|
|
|
|
| 58 |
|
| 59 |
args = parser.parse_args()
|
| 60 |
|
| 61 |
+
download_ixi(args.destination)
|