Spaces:
Sleeping
Sleeping
Marek Bukowicki
commited on
Commit
·
1c6540f
1
Parent(s):
42a027b
make shimnet a package
Browse files- .gitignore +4 -0
- Dockerfile +8 -6
- Readme.md +25 -18
- predict-gui.py +6 -17
- predict.py +2 -20
- pyproject.toml +47 -0
- requirements-cpu.txt +0 -9
- requirements-gpu.txt +0 -9
- requirements-gui.txt +0 -2
- shimnet/__init__.py +1 -0
- {src → shimnet}/generators.py +0 -0
- {src → shimnet}/models.py +0 -0
- shimnet/nmr_utils.py +18 -0
- shimnet/predict_utils.py +23 -0
- train.py +3 -3
.gitignore
CHANGED
|
@@ -11,3 +11,7 @@ data/
|
|
| 11 |
|
| 12 |
# gradio
|
| 13 |
.gradio/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# gradio
|
| 13 |
.gradio/
|
| 14 |
+
|
| 15 |
+
# package building
|
| 16 |
+
build/
|
| 17 |
+
*shimnet.egg-info/
|
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM python:3.10
|
| 2 |
|
| 3 |
RUN useradd -m -u 1000 user
|
| 4 |
USER user
|
|
@@ -10,16 +10,18 @@ ENV HOME=/home/user \
|
|
| 10 |
# Set the working directory to the user's home directory
|
| 11 |
WORKDIR $HOME/app
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
RUN pip install --no-cache-dir -r requirements-cpu.txt -r requirements-gui.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
| 16 |
-
|
| 17 |
-
FROM build as final
|
| 18 |
|
| 19 |
COPY --chown=user . .
|
| 20 |
|
| 21 |
# download weights
|
| 22 |
RUN python download_files.py --overwrite
|
| 23 |
|
| 24 |
-
CMD [ "python", "./predict-gui.py"]
|
| 25 |
|
|
|
|
| 1 |
+
FROM python:3.10 AS build
|
| 2 |
|
| 3 |
RUN useradd -m -u 1000 user
|
| 4 |
USER user
|
|
|
|
| 10 |
# Set the working directory to the user's home directory
|
| 11 |
WORKDIR $HOME/app
|
| 12 |
|
| 13 |
+
# copy installation files
|
| 14 |
+
COPY --chown=user shimnet shimnet/
|
| 15 |
+
COPY --chown=user pyproject.toml ./
|
| 16 |
+
# install shimnet (cpu version + GUI)
|
| 17 |
+
RUN pip install --no-cache-dir .[cpu,gui] --extra-index-url https://download.pytorch.org/whl/cpu
|
| 18 |
|
| 19 |
+
FROM build AS final
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
COPY --chown=user . .
|
| 22 |
|
| 23 |
# download weights
|
| 24 |
RUN python download_files.py --overwrite
|
| 25 |
|
| 26 |
+
CMD [ "python", "./predict-gui.py", "--server_name", "0.0.0.0" ]
|
| 27 |
|
Readme.md
CHANGED
|
@@ -9,17 +9,18 @@ Web service: [:
|
| 54 |
```csv
|
| 55 |
-1.97134 0.0167137
|
| 56 |
-1.97085 -0.00778748
|
|
@@ -188,14 +189,6 @@ If you want to train the network using the calibration data from our paper, foll
|
|
| 188 |
|
| 189 |
## GUI
|
| 190 |
|
| 191 |
-
### Installation
|
| 192 |
-
|
| 193 |
-
To use the ShimNet GUI, ensure you have Python 3.10 installed (not tested with Python 3.11+). After installing the ShimNet requirements (CPU/GPU), install the additional dependencies for the GUI:
|
| 194 |
-
|
| 195 |
-
```bash
|
| 196 |
-
pip install -r requirements-gui.txt
|
| 197 |
-
```
|
| 198 |
-
|
| 199 |
### Launching the GUI
|
| 200 |
|
| 201 |
The ShimNet GUI is built using Gradio. To start the application, run:
|
|
@@ -221,3 +214,17 @@ python predict-gui.py --share
|
|
| 221 |
```
|
| 222 |
|
| 223 |
A public web address will be displayed in the terminal, which you can use to access the GUI remotely or share with others.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
## Installation
|
| 11 |
|
| 12 |
+
Python 3.10+ is required
|
| 13 |
|
| 14 |
+
You may install CPU-only version for inference only. If you need both training and inference, GPU version is strongly recommended.
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
In both CPU-only and GPU versions you may also install GUI (graphical user interface for inference)
|
| 17 |
+
|
| 18 |
+
- CPU-only version
|
| 19 |
+
- with GUI (recommended): `pip install .[cpu,gui] --extra-index-url https://download.pytorch.org/whl/cpu`
|
| 20 |
+
- without GUI: `pip install .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu`
|
| 21 |
+
- GPU version (strongly recommended for training)
|
| 22 |
+
- with GUI: `pip install .[gpu,gui]`
|
| 23 |
+
- without GUI: `pip install .[gpu]`
|
| 24 |
|
| 25 |
## Usage
|
| 26 |
To correct spectra presented in the paper:
|
|
|
|
| 51 |
|
| 52 |
### input format
|
| 53 |
|
| 54 |
+
The spectrum file for reconstruction should be in the format of two columns separated by a space and without the sign at the end of the line at the end of the file. The first column is frequency in ppm, the second is the intensity. The frequency values should be in ascending order (example below):
|
| 55 |
```csv
|
| 56 |
-1.97134 0.0167137
|
| 57 |
-1.97085 -0.00778748
|
|
|
|
| 189 |
|
| 190 |
## GUI
|
| 191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
### Launching the GUI
|
| 193 |
|
| 194 |
The ShimNet GUI is built using Gradio. To start the application, run:
|
|
|
|
| 214 |
```
|
| 215 |
|
| 216 |
A public web address will be displayed in the terminal, which you can use to access the GUI remotely or share with others.
|
| 217 |
+
|
| 218 |
+
### GUI inference with Docker
|
| 219 |
+
|
| 220 |
+
Create docker image:
|
| 221 |
+
```bash
|
| 222 |
+
docker build -t shimnetgui .
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
Run the container:
|
| 226 |
+
```bash
|
| 227 |
+
docker run -it -p 7860:7860 shimnetgui
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
The GUI should be working at `http://127.0.0.1:7860`
|
predict-gui.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
torch.set_grad_enabled(False)
|
| 3 |
import numpy as np
|
|
@@ -6,8 +7,7 @@ from omegaconf import OmegaConf
|
|
| 6 |
import gradio as gr
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
from predict import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
|
| 11 |
|
| 12 |
# silent deprecation warnings
|
| 13 |
import warnings
|
|
@@ -168,11 +168,11 @@ with gr.Blocks() as app:
|
|
| 168 |
# Process button click logic
|
| 169 |
def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
|
| 170 |
if model_selection == "600 MHz":
|
| 171 |
-
config_file = "configs/shimnet_600.yaml"
|
| 172 |
-
weights_file = "weights/shimnet_600MHz.pt"
|
| 173 |
elif model_selection == "700 MHz":
|
| 174 |
-
config_file = "configs/shimnet_700.yaml"
|
| 175 |
-
weights_file = "weights/shimnet_700MHz.pt"
|
| 176 |
else:
|
| 177 |
config_file = config_file.name
|
| 178 |
weights_file = weights_file.name
|
|
@@ -186,14 +186,3 @@ with gr.Blocks() as app:
|
|
| 186 |
)
|
| 187 |
|
| 188 |
app.launch(share=args.share, server_name=args.server_name)
|
| 189 |
-
|
| 190 |
-
# '#636efa',
|
| 191 |
-
# '#EF553B',
|
| 192 |
-
# '#00cc96',
|
| 193 |
-
# '#ab63fa',
|
| 194 |
-
# '#FFA15A',
|
| 195 |
-
# '#19d3f3',
|
| 196 |
-
# '#FF6692',
|
| 197 |
-
# '#B6E880',
|
| 198 |
-
# '#FF97FF',
|
| 199 |
-
# '#FECB52'
|
|
|
|
| 1 |
+
import os
|
| 2 |
import torch
|
| 3 |
torch.set_grad_enabled(False)
|
| 4 |
import numpy as np
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
import plotly.graph_objects as go
|
| 9 |
|
| 10 |
+
from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
|
|
|
|
| 11 |
|
| 12 |
# silent deprecation warnings
|
| 13 |
import warnings
|
|
|
|
| 168 |
# Process button click logic
|
| 169 |
def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
|
| 170 |
if model_selection == "600 MHz":
|
| 171 |
+
config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600.yaml")
|
| 172 |
+
weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz.pt")
|
| 173 |
elif model_selection == "700 MHz":
|
| 174 |
+
config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_700.yaml")
|
| 175 |
+
weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_700MHz.pt")
|
| 176 |
else:
|
| 177 |
config_file = config_file.name
|
| 178 |
weights_file = weights_file.name
|
|
|
|
| 186 |
)
|
| 187 |
|
| 188 |
app.launch(share=args.share, server_name=args.server_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
predict.py
CHANGED
|
@@ -6,15 +6,14 @@ from pathlib import Path
|
|
| 6 |
import sys, os
|
| 7 |
from omegaconf import OmegaConf
|
| 8 |
|
| 9 |
-
from
|
|
|
|
| 10 |
|
| 11 |
# silent deprecation warnings
|
| 12 |
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 15 |
|
| 16 |
-
class Defaults:
|
| 17 |
-
SCALE = 16.0
|
| 18 |
|
| 19 |
def parse_args():
|
| 20 |
parser = argparse.ArgumentParser()
|
|
@@ -26,23 +25,6 @@ def parse_args():
|
|
| 26 |
args = parser.parse_args()
|
| 27 |
return args
|
| 28 |
|
| 29 |
-
# functions
|
| 30 |
-
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
|
| 31 |
-
"""resample input spectrum to match the model's frequency range"""
|
| 32 |
-
freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
|
| 33 |
-
spectrum = np.interp(freqs, input_freqs, input_spectrum)
|
| 34 |
-
return freqs, spectrum
|
| 35 |
-
|
| 36 |
-
def resample_output_spectrum(input_freqs, freqs, prediction):
|
| 37 |
-
"""resample prediction to match the input spectrum's frequency range"""
|
| 38 |
-
prediction = np.interp(input_freqs, freqs, prediction)
|
| 39 |
-
return prediction
|
| 40 |
-
|
| 41 |
-
def initialize_predictor(config, weights_file):
|
| 42 |
-
model = ShimNetWithSCRF(**config.model.kwargs)
|
| 43 |
-
predictor = Predictor(model, weights_file)
|
| 44 |
-
return predictor
|
| 45 |
-
|
| 46 |
# run
|
| 47 |
if __name__ == "__main__":
|
| 48 |
args = parse_args()
|
|
|
|
| 6 |
import sys, os
|
| 7 |
from omegaconf import OmegaConf
|
| 8 |
|
| 9 |
+
from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
|
| 10 |
+
|
| 11 |
|
| 12 |
# silent deprecation warnings
|
| 13 |
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
|
| 14 |
import warnings
|
| 15 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 16 |
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def parse_args():
|
| 19 |
parser = argparse.ArgumentParser()
|
|
|
|
| 25 |
args = parser.parse_args()
|
| 26 |
return args
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# run
|
| 29 |
if __name__ == "__main__":
|
| 30 |
args = parse_args()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[tool.setuptools]
|
| 6 |
+
packages = ["shimnet"] # only include this package
|
| 7 |
+
|
| 8 |
+
[project]
|
| 9 |
+
name = "shimnet"
|
| 10 |
+
version = "0.1.0"
|
| 11 |
+
description = "Package for ShimNet: A Neural Network for Postacquisition Improvement of NMR Spectra Distorted by Magnetic-Field Inhomogeneity https://pubs.acs.org/doi/full/10.1021/acs.jpcb.5c02632"
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Marek Bukowicki", email = "m.bukowicki@uw.edu.pl"},
|
| 14 |
+
]
|
| 15 |
+
readme = "Readme.md"
|
| 16 |
+
requires-python = ">=3.10"
|
| 17 |
+
dependencies = [
|
| 18 |
+
"nmrglue==0.11",
|
| 19 |
+
"torchdata==0.9.0",
|
| 20 |
+
"numpy==2.0.2",
|
| 21 |
+
"matplotlib==3.9.3",
|
| 22 |
+
"pandas==2.2.3",
|
| 23 |
+
"tqdm==4.67.1",
|
| 24 |
+
"hydra-core==1.3.2"
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[project.optional-dependencies]
|
| 28 |
+
# GPU installs the default torch build from PyPI (CUDA-enabled by default)
|
| 29 |
+
gpu = [
|
| 30 |
+
"torch==2.5.1",
|
| 31 |
+
"torchaudio==2.5.1"
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# CPU needs extra index URL: pip install shimnet[cpu] --extra-index-url https://download.pytorch.org/whl/cpu
|
| 35 |
+
cpu = [
|
| 36 |
+
"torch==2.5.1+cpu",
|
| 37 |
+
"torchaudio==2.5.1+cpu"
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
# predictions with GUI
|
| 41 |
+
gui = [
|
| 42 |
+
"gradio==5.23.2",
|
| 43 |
+
"plotly==6.0.1",
|
| 44 |
+
"ipykernel",
|
| 45 |
+
"psutil",
|
| 46 |
+
"pexpect"
|
| 47 |
+
]
|
requirements-cpu.txt
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
torch==2.4.1+cpu
|
| 2 |
-
torchaudio==2.4.1+cpu
|
| 3 |
-
nmrglue==0.11
|
| 4 |
-
torchdata==0.9.0
|
| 5 |
-
numpy==2.0.2
|
| 6 |
-
matplotlib==3.9.3
|
| 7 |
-
pandas==2.2.3
|
| 8 |
-
tqdm==4.67.1
|
| 9 |
-
hydra-core==1.3.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements-gpu.txt
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
torch==2.4.1
|
| 2 |
-
torchaudio==2.4.1
|
| 3 |
-
nmrglue==0.11
|
| 4 |
-
torchdata==0.9.0
|
| 5 |
-
numpy==2.0.2
|
| 6 |
-
matplotlib==3.9.3
|
| 7 |
-
pandas==2.2.3
|
| 8 |
-
tqdm==4.67.1
|
| 9 |
-
hydra-core==1.3.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements-gui.txt
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
gradio==5.23.2
|
| 2 |
-
plotly==6.0.1
|
|
|
|
|
|
|
|
|
shimnet/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import generators, models, nmr_utils, predict_utils
|
{src → shimnet}/generators.py
RENAMED
|
File without changes
|
{src → shimnet}/models.py
RENAMED
|
File without changes
|
shimnet/nmr_utils.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nmrglue as ng
|
| 2 |
+
|
| 3 |
+
def fid_to_spectrum(varian_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=None, sin_pod=False):
|
| 4 |
+
dic, data = ng.varian.read(varian_fid_path)
|
| 5 |
+
data[0] *= 0.5
|
| 6 |
+
if sin_pod:
|
| 7 |
+
data = ng.proc_base.sp(data, end=0.98)
|
| 8 |
+
|
| 9 |
+
if target_length is not None:
|
| 10 |
+
if (pad_length := target_length - len(data)) > 0:
|
| 11 |
+
data = ng.proc_base.zf(data, pad_length)
|
| 12 |
+
else:
|
| 13 |
+
data = data[:target_length]
|
| 14 |
+
|
| 15 |
+
spec=ng.proc_base.fft(data)
|
| 16 |
+
spec = ng.process.proc_autophase.autops(spec, autophase_fn, p0=ph0_correction, p1=ph1_correction, disp=False)
|
| 17 |
+
|
| 18 |
+
return spec
|
shimnet/predict_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from .models import ShimNetWithSCRF, Predictor
|
| 4 |
+
|
| 5 |
+
class Defaults:
|
| 6 |
+
SCALE = 16.0
|
| 7 |
+
|
| 8 |
+
# functions
|
| 9 |
+
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
|
| 10 |
+
"""resample input spectrum to match the model's frequency range"""
|
| 11 |
+
freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
|
| 12 |
+
spectrum = np.interp(freqs, input_freqs, input_spectrum)
|
| 13 |
+
return freqs, spectrum
|
| 14 |
+
|
| 15 |
+
def resample_output_spectrum(input_freqs, freqs, prediction):
|
| 16 |
+
"""resample prediction to match the input spectrum's frequency range"""
|
| 17 |
+
prediction = np.interp(input_freqs, freqs, prediction)
|
| 18 |
+
return prediction
|
| 19 |
+
|
| 20 |
+
def initialize_predictor(config, weights_file):
|
| 21 |
+
model = ShimNetWithSCRF(**config.model.kwargs)
|
| 22 |
+
predictor = Predictor(model, weights_file)
|
| 23 |
+
return predictor
|
train.py
CHANGED
|
@@ -15,8 +15,8 @@ matplotlib.use('Agg')
|
|
| 15 |
import warnings
|
| 16 |
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
|
| 17 |
|
| 18 |
-
from
|
| 19 |
-
from
|
| 20 |
|
| 21 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 22 |
if len(sys.argv) < 2:
|
|
@@ -32,7 +32,7 @@ else:
|
|
| 32 |
minimum = float("inf")
|
| 33 |
|
| 34 |
# initialization
|
| 35 |
-
model = instantiate({"_target_": f"
|
| 36 |
model_weights_file = run_dir / f'model.pt'
|
| 37 |
optimizer = torch.optim.Adam(model.parameters())
|
| 38 |
optimizer_weights_file = run_dir / f'optimizer.pt'
|
|
|
|
| 15 |
import warnings
|
| 16 |
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
|
| 17 |
|
| 18 |
+
# from shiment import models
|
| 19 |
+
from shiment.generators import get_datapipe
|
| 20 |
|
| 21 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 22 |
if len(sys.argv) < 2:
|
|
|
|
| 32 |
minimum = float("inf")
|
| 33 |
|
| 34 |
# initialization
|
| 35 |
+
model = instantiate({"_target_": f"shimnet.models.{config.model.name}", **config.model.kwargs}).to(device)
|
| 36 |
model_weights_file = run_dir / f'model.pt'
|
| 37 |
optimizer = torch.optim.Adam(model.parameters())
|
| 38 |
optimizer_weights_file = run_dir / f'optimizer.pt'
|