Marek Bukowicki commited on
Commit
1c6540f
·
1 Parent(s): 42a027b

make shimnet a package

Browse files
.gitignore CHANGED
@@ -11,3 +11,7 @@ data/
11
 
12
  # gradio
13
  .gradio/
 
 
 
 
 
11
 
12
  # gradio
13
  .gradio/
14
+
15
+ # package building
16
+ build/
17
+ *shimnet.egg-info/
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.10 as build
2
 
3
  RUN useradd -m -u 1000 user
4
  USER user
@@ -10,16 +10,18 @@ ENV HOME=/home/user \
10
  # Set the working directory to the user's home directory
11
  WORKDIR $HOME/app
12
 
 
 
 
 
 
13
 
14
- COPY --chown=user requirements-cpu.txt requirements-gui.txt ./
15
- RUN pip install --no-cache-dir -r requirements-cpu.txt -r requirements-gui.txt --extra-index-url https://download.pytorch.org/whl/cpu
16
-
17
- FROM build as final
18
 
19
  COPY --chown=user . .
20
 
21
  # download weights
22
  RUN python download_files.py --overwrite
23
 
24
- CMD [ "python", "./predict-gui.py"]
25
 
 
1
+ FROM python:3.10 AS build
2
 
3
  RUN useradd -m -u 1000 user
4
  USER user
 
10
  # Set the working directory to the user's home directory
11
  WORKDIR $HOME/app
12
 
13
+ # copy installation files
14
+ COPY --chown=user shimnet shimnet/
15
+ COPY --chown=user pyproject.toml ./
16
+ # install shimnet (cpu version + GUI)
17
+ RUN pip install --no-cache-dir .[cpu,gui] --extra-index-url https://download.pytorch.org/whl/cpu
18
 
19
+ FROM build AS final
 
 
 
20
 
21
  COPY --chown=user . .
22
 
23
  # download weights
24
  RUN python download_files.py --overwrite
25
 
26
+ CMD [ "python", "./predict-gui.py", "--server_name", "0.0.0.0" ]
27
 
Readme.md CHANGED
@@ -9,17 +9,18 @@ Web service: [![Open in Hugging Face Spaces](https://huggingface.co/datasets/hug
9
 
10
  ## Installation
11
 
12
- Python 3.9+ (3.10+ for GUI)
13
 
14
- GPU version (for training and inference)
15
- ```
16
- pip install -r requirements-gpu.txt
17
- ```
18
 
19
- CPU version (for inference, not recommended for training)
20
- ```
21
- pip install -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
22
- ```
 
 
 
 
23
 
24
  ## Usage
25
  To correct spectra presented in the paper:
@@ -50,7 +51,7 @@ python predict.py sample_data/CresolRed_after_styrene_600MHz.csv -o output --con
50
 
51
  ### input format
52
 
53
- The spectrum file for reconstruction should be in the format of two columns separated by a space and without the sign at the end of the line at the end of the file(example below):
54
  ```csv
55
  -1.97134 0.0167137
56
  -1.97085 -0.00778748
@@ -188,14 +189,6 @@ If you want to train the network using the calibration data from our paper, foll
188
 
189
  ## GUI
190
 
191
- ### Installation
192
-
193
- To use the ShimNet GUI, ensure you have Python 3.10 installed (not tested with Python 3.11+). After installing the ShimNet requirements (CPU/GPU), install the additional dependencies for the GUI:
194
-
195
- ```bash
196
- pip install -r requirements-gui.txt
197
- ```
198
-
199
  ### Launching the GUI
200
 
201
  The ShimNet GUI is built using Gradio. To start the application, run:
@@ -221,3 +214,17 @@ python predict-gui.py --share
221
  ```
222
 
223
  A public web address will be displayed in the terminal, which you can use to access the GUI remotely or share with others.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  ## Installation
11
 
12
+ Python 3.10+ is required
13
 
14
+ You may install CPU-only version for inference only. If you need both training and inference, GPU version is strongly recommended.
 
 
 
15
 
16
+ In both CPU-only and GPU versions you may also install GUI (graphical user interface for inference)
17
+
18
+ - CPU-only version
19
+ - with GUI (recommended): `pip install .[cpu,gui] --extra-index-url https://download.pytorch.org/whl/cpu`
20
+ - without GUI: `pip install .[cpu] --extra-index-url https://download.pytorch.org/whl/cpu`
21
+ - GPU version (strongly recommended for training)
22
+ - with GUI: `pip install .[gpu,gui]`
23
+ - without GUI: `pip install .[gpu]`
24
 
25
  ## Usage
26
  To correct spectra presented in the paper:
 
51
 
52
  ### input format
53
 
54
+ The spectrum file for reconstruction should be in the format of two columns separated by a space and without the sign at the end of the line at the end of the file. The first column is frequency in ppm, the second is the intensity. The frequency values ​​should be in ascending order (example below):
55
  ```csv
56
  -1.97134 0.0167137
57
  -1.97085 -0.00778748
 
189
 
190
  ## GUI
191
 
 
 
 
 
 
 
 
 
192
  ### Launching the GUI
193
 
194
  The ShimNet GUI is built using Gradio. To start the application, run:
 
214
  ```
215
 
216
  A public web address will be displayed in the terminal, which you can use to access the GUI remotely or share with others.
217
+
218
+ ### GUI inference with Docker
219
+
220
+ Create docker image:
221
+ ```bash
222
+ docker build -t shimnetgui .
223
+ ```
224
+
225
+ Run the container:
226
+ ```bash
227
+ docker run -it -p 7860:7860 shimnetgui
228
+ ```
229
+
230
+ The GUI should be working at `http://127.0.0.1:7860`
predict-gui.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  torch.set_grad_enabled(False)
3
  import numpy as np
@@ -6,8 +7,7 @@ from omegaconf import OmegaConf
6
  import gradio as gr
7
  import plotly.graph_objects as go
8
 
9
- from src.models import ShimNetWithSCRF, Predictor
10
- from predict import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
11
 
12
  # silent deprecation warnings
13
  import warnings
@@ -168,11 +168,11 @@ with gr.Blocks() as app:
168
  # Process button click logic
169
  def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
170
  if model_selection == "600 MHz":
171
- config_file = "configs/shimnet_600.yaml"
172
- weights_file = "weights/shimnet_600MHz.pt"
173
  elif model_selection == "700 MHz":
174
- config_file = "configs/shimnet_700.yaml"
175
- weights_file = "weights/shimnet_700MHz.pt"
176
  else:
177
  config_file = config_file.name
178
  weights_file = weights_file.name
@@ -186,14 +186,3 @@ with gr.Blocks() as app:
186
  )
187
 
188
  app.launch(share=args.share, server_name=args.server_name)
189
-
190
- # '#636efa',
191
- # '#EF553B',
192
- # '#00cc96',
193
- # '#ab63fa',
194
- # '#FFA15A',
195
- # '#19d3f3',
196
- # '#FF6692',
197
- # '#B6E880',
198
- # '#FF97FF',
199
- # '#FECB52'
 
1
+ import os
2
  import torch
3
  torch.set_grad_enabled(False)
4
  import numpy as np
 
7
  import gradio as gr
8
  import plotly.graph_objects as go
9
 
10
+ from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
 
11
 
12
  # silent deprecation warnings
13
  import warnings
 
168
  # Process button click logic
169
  def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
170
  if model_selection == "600 MHz":
171
+ config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_600.yaml")
172
+ weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_600MHz.pt")
173
  elif model_selection == "700 MHz":
174
+ config_file = os.path.join(os.path.dirname(__file__), "configs/shimnet_700.yaml")
175
+ weights_file = os.path.join(os.path.dirname(__file__), "weights/shimnet_700MHz.pt")
176
  else:
177
  config_file = config_file.name
178
  weights_file = weights_file.name
 
186
  )
187
 
188
  app.launch(share=args.share, server_name=args.server_name)
 
 
 
 
 
 
 
 
 
 
 
predict.py CHANGED
@@ -6,15 +6,14 @@ from pathlib import Path
6
  import sys, os
7
  from omegaconf import OmegaConf
8
 
9
- from src.models import ShimNetWithSCRF, Predictor
 
10
 
11
  # silent deprecation warnings
12
  # https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
13
  import warnings
14
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
15
 
16
- class Defaults:
17
- SCALE = 16.0
18
 
19
  def parse_args():
20
  parser = argparse.ArgumentParser()
@@ -26,23 +25,6 @@ def parse_args():
26
  args = parser.parse_args()
27
  return args
28
 
29
- # functions
30
- def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
31
- """resample input spectrum to match the model's frequency range"""
32
- freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
33
- spectrum = np.interp(freqs, input_freqs, input_spectrum)
34
- return freqs, spectrum
35
-
36
- def resample_output_spectrum(input_freqs, freqs, prediction):
37
- """resample prediction to match the input spectrum's frequency range"""
38
- prediction = np.interp(input_freqs, freqs, prediction)
39
- return prediction
40
-
41
- def initialize_predictor(config, weights_file):
42
- model = ShimNetWithSCRF(**config.model.kwargs)
43
- predictor = Predictor(model, weights_file)
44
- return predictor
45
-
46
  # run
47
  if __name__ == "__main__":
48
  args = parse_args()
 
6
  import sys, os
7
  from omegaconf import OmegaConf
8
 
9
+ from shimnet.predict_utils import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
10
+
11
 
12
  # silent deprecation warnings
13
  # https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
14
  import warnings
15
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
16
 
 
 
17
 
18
  def parse_args():
19
  parser = argparse.ArgumentParser()
 
25
  args = parser.parse_args()
26
  return args
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # run
29
  if __name__ == "__main__":
30
  args = parse_args()
pyproject.toml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.setuptools]
6
+ packages = ["shimnet"] # only include this package
7
+
8
+ [project]
9
+ name = "shimnet"
10
+ version = "0.1.0"
11
+ description = "Package for ShimNet: A Neural Network for Postacquisition Improvement of NMR Spectra Distorted by Magnetic-Field Inhomogeneity https://pubs.acs.org/doi/full/10.1021/acs.jpcb.5c02632"
12
+ authors = [
13
+ {name = "Marek Bukowicki", email = "m.bukowicki@uw.edu.pl"},
14
+ ]
15
+ readme = "Readme.md"
16
+ requires-python = ">=3.10"
17
+ dependencies = [
18
+ "nmrglue==0.11",
19
+ "torchdata==0.9.0",
20
+ "numpy==2.0.2",
21
+ "matplotlib==3.9.3",
22
+ "pandas==2.2.3",
23
+ "tqdm==4.67.1",
24
+ "hydra-core==1.3.2"
25
+ ]
26
+
27
+ [project.optional-dependencies]
28
+ # GPU installs the default torch build from PyPI (CUDA-enabled by default)
29
+ gpu = [
30
+ "torch==2.5.1",
31
+ "torchaudio==2.5.1"
32
+ ]
33
+
34
+ # CPU needs extra index URL: pip install shimnet[cpu] --extra-index-url https://download.pytorch.org/whl/cpu
35
+ cpu = [
36
+ "torch==2.5.1+cpu",
37
+ "torchaudio==2.5.1+cpu"
38
+ ]
39
+
40
+ # predictions with GUI
41
+ gui = [
42
+ "gradio==5.23.2",
43
+ "plotly==6.0.1",
44
+ "ipykernel",
45
+ "psutil",
46
+ "pexpect"
47
+ ]
requirements-cpu.txt DELETED
@@ -1,9 +0,0 @@
1
- torch==2.4.1+cpu
2
- torchaudio==2.4.1+cpu
3
- nmrglue==0.11
4
- torchdata==0.9.0
5
- numpy==2.0.2
6
- matplotlib==3.9.3
7
- pandas==2.2.3
8
- tqdm==4.67.1
9
- hydra-core==1.3.2
 
 
 
 
 
 
 
 
 
 
requirements-gpu.txt DELETED
@@ -1,9 +0,0 @@
1
- torch==2.4.1
2
- torchaudio==2.4.1
3
- nmrglue==0.11
4
- torchdata==0.9.0
5
- numpy==2.0.2
6
- matplotlib==3.9.3
7
- pandas==2.2.3
8
- tqdm==4.67.1
9
- hydra-core==1.3.2
 
 
 
 
 
 
 
 
 
 
requirements-gui.txt DELETED
@@ -1,2 +0,0 @@
1
- gradio==5.23.2
2
- plotly==6.0.1
 
 
 
shimnet/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import generators, models, nmr_utils, predict_utils
{src → shimnet}/generators.py RENAMED
File without changes
{src → shimnet}/models.py RENAMED
File without changes
shimnet/nmr_utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nmrglue as ng
2
+
3
+ def fid_to_spectrum(varian_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=None, sin_pod=False):
4
+ dic, data = ng.varian.read(varian_fid_path)
5
+ data[0] *= 0.5
6
+ if sin_pod:
7
+ data = ng.proc_base.sp(data, end=0.98)
8
+
9
+ if target_length is not None:
10
+ if (pad_length := target_length - len(data)) > 0:
11
+ data = ng.proc_base.zf(data, pad_length)
12
+ else:
13
+ data = data[:target_length]
14
+
15
+ spec=ng.proc_base.fft(data)
16
+ spec = ng.process.proc_autophase.autops(spec, autophase_fn, p0=ph0_correction, p1=ph1_correction, disp=False)
17
+
18
+ return spec
shimnet/predict_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from .models import ShimNetWithSCRF, Predictor
4
+
5
+ class Defaults:
6
+ SCALE = 16.0
7
+
8
+ # functions
9
+ def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
10
+ """resample input spectrum to match the model's frequency range"""
11
+ freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
12
+ spectrum = np.interp(freqs, input_freqs, input_spectrum)
13
+ return freqs, spectrum
14
+
15
+ def resample_output_spectrum(input_freqs, freqs, prediction):
16
+ """resample prediction to match the input spectrum's frequency range"""
17
+ prediction = np.interp(input_freqs, freqs, prediction)
18
+ return prediction
19
+
20
+ def initialize_predictor(config, weights_file):
21
+ model = ShimNetWithSCRF(**config.model.kwargs)
22
+ predictor = Predictor(model, weights_file)
23
+ return predictor
train.py CHANGED
@@ -15,8 +15,8 @@ matplotlib.use('Agg')
15
  import warnings
16
  warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
17
 
18
- from src import models
19
- from src.generators import get_datapipe
20
 
21
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  if len(sys.argv) < 2:
@@ -32,7 +32,7 @@ else:
32
  minimum = float("inf")
33
 
34
  # initialization
35
- model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
36
  model_weights_file = run_dir / f'model.pt'
37
  optimizer = torch.optim.Adam(model.parameters())
38
  optimizer_weights_file = run_dir / f'optimizer.pt'
 
15
  import warnings
16
  warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
17
 
18
+ # from shiment import models
19
+ from shiment.generators import get_datapipe
20
 
21
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
  if len(sys.argv) < 2:
 
32
  minimum = float("inf")
33
 
34
  # initialization
35
+ model = instantiate({"_target_": f"shimnet.models.{config.model.name}", **config.model.kwargs}).to(device)
36
  model_weights_file = run_dir / f'model.pt'
37
  optimizer = torch.optim.Adam(model.parameters())
38
  optimizer_weights_file = run_dir / f'optimizer.pt'