Marek Bukowicki committed
Commit 64b4096 · 1 Parent(s): 12f8c5c

add shimnet code
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.10
+
+ WORKDIR /usr/src/app
+
+ COPY requirements-cpu.txt requirements-gui.txt ./
+ RUN pip install --no-cache-dir -r requirements-cpu.txt -r requirements-gui.txt --extra-index-url https://download.pytorch.org/whl/cpu
+
+ COPY . .
+
+ RUN python download_files.py --overwrite
+
+ CMD [ "python", "./predict-gui.py", "--server_name", "0.0.0.0" ]
+
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 center4ml
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -9,4 +9,224 @@ license: mit
  short_description: ShimNet Spectra Correction
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ShimNet
+ ShimNet is a data-driven AI solution to improve high-resolution nuclear magnetic resonance (NMR) spectra
+ distorted by an inhomogeneous magnetic field (less than optimal shimming). To use it, experimental training data has to be collected (see **Training data collection** below).
+ Example data can also be downloaded (see below).
+
+ Paper: [ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef866)
+
+ ## Installation
+
+ Python 3.9+ (3.10+ for GUI)
+
+ GPU version (for training and inference):
+ ```
+ pip install -r requirements-gpu.txt
+ ```
+
+ CPU version (for inference; not recommended for training):
+ ```
+ pip install -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+ ```
+
+ ## Usage
+ To correct the spectra presented in the paper:
+ 1. Download the weights (model parameters):
+ ```
+ python download_files.py
+ ```
+ or download them directly from [Google Drive 700MHz](https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke) and [Google Drive 600MHz](https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N) and place them in the `weights` directory.
+
+ 2. Run the correction (e.g. on `Azarone_20ul_700MHz.csv`):
+ ```
+ python predict.py sample_data/Azarone_20ul_700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
+ ```
+ The output will be the file `output/Azarone_20ul_700MHz_processed.csv`.
+
+ Multiple files may be processed using the "*" glob syntax:
+ ```
+ python predict.py sample_data/*700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
+ ```
+
+ For 600 MHz data use `--config configs/shimnet_600.yaml` and `--weights weights/shimnet_600MHz.pt`, e.g.:
+
+ ```
+ python predict.py sample_data/CresolRed_after_styrene_600MHz.csv -o output --config configs/shimnet_600.yaml --weights weights/shimnet_600MHz.pt
+ ```
+
+ ### Input format
+
+ The spectrum file for reconstruction should contain two columns separated by a space, with no newline character after the last line of the file (example below):
+ ```csv
+ -1.97134 0.0167137
+ -1.97085 -0.00778748
+ -1.97036 -0.0109595
+ -1.96988 0.00825978
+ -1.96939 0.0133886
+ ```
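Such a file can be read with NumPy's `np.loadtxt` (the same routine the GUI uses internally); a quick sketch with the example rows above inlined:

```python
import numpy as np
from io import StringIO

# The example rows above, inlined for illustration
text = """-1.97134 0.0167137
-1.97085 -0.00778748
-1.97036 -0.0109595
-1.96988 0.00825978
-1.96939 0.0133886"""

data = np.loadtxt(StringIO(text))
freqs, intensities = data[:, 0], data[:, 1]
print(data.shape)  # (5, 2)
```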
+
+ ## Train on your data
+
+ For the model to function properly, it should be trained on calibration data from the spectrometer used for the measurements. To train a model on data from your spectrometer, please follow the instructions below.
+
+ ### Training data collection
+
+ Below we describe training data collection for Agilent/Varian spectrometers. A similar procedure can be implemented for machines from other vendors.
+ To collect ShimNet training data, use the Python script `sweep_shims_lineshape_Z1Z2.py` from the `calibration_loop` folder to drive the spectrometer:
+ 1. Install the TReNDS package (trends.spektrino.com)
+ 2. Open VnmrJ and type: 'listenon'
+ 3. Put in the lineshape sample (1% CHCl3 in deuterated acetone), set standard PROTON parameters, and set nt=1 (do not modify sw and at!)
+ 4. Shim the sample and collect the data. Save the optimally shimmed dataset
+ 5. Edit the sweep_shims_lineshape_Z1Z2.py script
+ 6. Put the optimum z1 and z2 shim values in as optiz1 and optiz2
+ 7. Define the calibration range as range_z1 and range_z2 (the default is OK)
+ 8. Start the Python script:
+ ```
+ python3 ./sweep_shims_lineshape_Z1Z2.py
+ ```
+ The spectrometer will then start collecting spectra.
+
+ ### SCRF extraction
+ Shim Coil Response Functions (SCRFs) should be extracted from the spectra with the `extract_scrf_from_fids.py` script:
+ ```
+ python extract_scrf_from_fids.py
+ ```
+
+ The script uses hardcoded paths to the NMR signals (fids) in Agilent/Varian format: a directory with the optimal measurement (`opti_fid_path`) and a directory with the calibration loop measurements (`data_dir`):
+ ```python
+ # input
+ data_dir = "../../sample_run/loop"
+ opti_fid_path = "../../sample_run/opti.fid"
+ ```
+
+ The output files are also hardcoded:
+ ```python
+ # output
+ spectra_file = "../../sample_run/total.npy"
+ spectra_file_names = "../../sample_run/total.csv"
+ opi_spectrum_file = "../../sample_run/opti.npy"
+ responses_file = "../../sample_run/scrf_61.pt"
+ ```
+ where only the `responses_file` is used in ShimNet training.
+
+ If the measurements are stored in a format other than Varian, you may need to change this line:
+ ```python
+ dic, data = ng.varian.read(varian_fid_path)
+ ```
+ (see the nmrglue package documentation for details)
+
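The underlying model is simple: a distorted spectrum is approximately the well-shimmed spectrum convolved with the SCRF. A synthetic NumPy illustration of this forward model (all line widths and sizes made up, not taken from the scripts):

```python
import numpy as np

x = np.arange(-128, 129)
optimal = np.exp(-0.5 * (x / 1.5) ** 2)            # sharp, well-shimmed line
scrf = np.exp(-0.5 * (np.arange(-30, 31) / 8.0) ** 2)
scrf /= scrf.sum()                                  # normalized response function

distorted = np.convolve(optimal, scrf, mode="same")

# Convolution with the SCRF preserves the integral but broadens the peak
print(optimal.max() > distorted.max())  # True
```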
+ ### Training
+
+ 1. Download the multiplets database:
+ ```
+ python download_files.py --multiplets
+ ```
+ 2. Configure the run:
+ - create a run directory, e.g. `runs/my_lab_spectrometer_2025`
+ - create a configuration file:
+ 1. copy `configs/shimnet_template.yaml` to the run directory and rename it to `config.yaml`
+ ```bash
+ cp configs/shimnet_template.yaml runs/my_lab_spectrometer_2025/config.yaml
+ ```
+ 2. edit the SCRF path in the config file:
+ ```yaml
+ response_functions_files:
+ - path/to/srcf_file
+ ```
+ e.g.
+ ```yaml
+ response_functions_files:
+ - ../../sample_run/scrf_61.pt
+ ```
+ 3. adjust the spectrometer frequency step `frq_step` to match your data (spectrometer range in Hz divided by the number of points in the spectrum):
+ ```yaml
+ frq_step: 0.34059797
+ ```
+ 4. adjust the spectrometer frequency in the metadata:
+ ```yaml
+ metadata: # additional metadata, not used in the training process
+   spectrometer_frequency: 700.0 # MHz
+ ```
+ 3. Run training:
+ ```
+ python train.py runs/my_lab_spectrometer_2025
+ ```
+ Training results will appear in the `runs/my_lab_spectrometer_2025` directory.
+ Model parameters are stored in the `runs/my_lab_spectrometer_2025/model.pt` file.
+ 4. Use the trained model:
+
+ use the `--config runs/my_lab_spectrometer_2025/config.yaml` and `--weights runs/my_lab_spectrometer_2025/model.pt` flags, e.g.
+ ```
+ python predict.py my_sample1.csv -o my_output --config runs/my_lab_spectrometer_2025/config.yaml --weights runs/my_lab_spectrometer_2025/model.pt
+ ```
+
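The `frq_step` quotient can be checked in a couple of lines; the spectral width and point count below are hypothetical placeholders for your own acquisition parameters:

```python
# frq_step = spectral width (Hz) / number of points in the spectrum
spectral_width_hz = 11161.0   # hypothetical value from your acquisition
n_points = 32768              # hypothetical number of spectrum points

frq_step = spectral_width_hz / n_points
print(round(frq_step, 5))  # 0.34061
```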
+ ## Repeat training on our data
+
+ If you want to train the network using the calibration data from our paper, follow the procedure below.
+
+ 1. Download the multiplets database and our SCRF files:
+ ```
+ python download_files.py --multiplets --SCRF --no-weights
+ ```
+ or download them directly from Google Drive and store them in the `data/` directory: [Response Functions 600MHz](https://drive.google.com/file/d/1J-DsPtaITXU3TFrbxaZPH800U1uIiwje/view?usp=sharing), [Response Functions 700MHz](https://drive.google.com/file/d/113al7A__yYALx_2hkESuzFIDU3feVtNY/view?usp=sharing), [Multiplets data](https://drive.google.com/file/d/1QGvV-Au50ZxaP1vFsmR_auI299Dw-Wrt/view?usp=sharing)
+
+ 2. Configure the run
+ - For a 600 MHz spectrometer:
+ ```bash
+ mkdir -p runs/repeat_paper_training_600MHz
+ cp configs/shimnet_600.yaml runs/repeat_paper_training_600MHz/config.yaml
+ ```
+ - For a 700 MHz spectrometer:
+ ```bash
+ mkdir -p runs/repeat_paper_training_700MHz
+ cp configs/shimnet_700.yaml runs/repeat_paper_training_700MHz/config.yaml
+ ```
+ 3. Run training:
+ ```
+ python train.py runs/repeat_paper_training_600MHz
+ ```
+ or
+ ```
+ python train.py runs/repeat_paper_training_700MHz
+ ```
+ Training results will appear in the `runs/repeat_paper_training_600MHz` or `runs/repeat_paper_training_700MHz` directory.
+
+ ## GUI
+
+ ### Installation
+
+ To use the ShimNet GUI, ensure you have Python 3.10 installed (not tested with Python 3.11+). After installing the ShimNet requirements (CPU/GPU), install the additional dependencies for the GUI:
+
+ ```bash
+ pip install -r requirements-gui.txt
+ ```
+
+ ### Launching the GUI
+
+ The ShimNet GUI is built with Gradio. To start the application, run:
+
+ ```bash
+ python predict-gui.py
+ ```
+
+ Once the application starts, open your browser and navigate to:
+
+ ```
+ http://127.0.0.1:7860
+ ```
+
+ to access the GUI locally.
+
+ ### Sharing the GUI
+
+ To make the GUI accessible over the internet, use the `--share` flag:
+
+ ```bash
+ python predict-gui.py --share
+ ```
+
+ A public web address will be displayed in the terminal, which you can use to access the GUI remotely or share with others.
calibration_loop/maclib/shimator_loop ADDED
@@ -0,0 +1,33 @@
+ " macro shimnet loop "
+ " arguments: optimal z1, optimal z2, step in z1 loop, step in z2 loop, half number of steps in z1, half number of steps in z2"
+
+ $svfdir='/home/nmrbox/kkazimierczuk/shimnet/data/'
+
+ $opti_z1=$1
+ $opti_z2=$2
+ $step_z1=$3
+ $step_z2=$4
+ $steps_z1=$5
+ $steps_z2=$6
+ $nn=''
+ $bb=''
+ $msg=''
+ format((($steps_z1*2+1)*($steps_z2*2+1)*(d1+at)*nt)/3600,5,1):$msg
+ $msg='Time: '+$msg+' h'
+ banner($msg)
+ $j=0
+ repeat
+ $i=0
+ $z2=$opti_z2-$steps_z2*$step_z2+$j*$step_z2
+ format($z2,5,1):$bb
+ repeat
+ $z1=$opti_z1-$steps_z1*$step_z1+$i*$step_z1
+ "su"
+ "go"
+ format($z1,5,1):$nn
+ $filepath=$svfdir + 'z1_'+$nn+'z2_'+$bb+'.fid'
+ svf($filepath,'force')
+ $i=$i+1
+ until $z1>$opti_z1+$steps_z1*$step_z1-1
+ $j=$j+1
+ until $z2>$opti_z2+$steps_z2*$step_z2-1
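The macro's banner estimates the total acquisition time as `(2*steps_z1+1)*(2*steps_z2+1)*(d1+at)*nt/3600` hours; a quick check of that arithmetic with made-up acquisition parameters:

```python
# All parameter values below are hypothetical examples
steps_z1, steps_z2 = 50, 50   # half number of steps per shim
d1, at = 1.0, 2.0             # relaxation delay and acquisition time (s)
nt = 1                        # scans per spectrum

hours = (2 * steps_z1 + 1) * (2 * steps_z2 + 1) * (d1 + at) * nt / 3600
print(round(hours, 1))  # 8.5
```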
calibration_loop/sweep_shims_lineshape_Z1Z2.py ADDED
@@ -0,0 +1,39 @@
+ from TReND.send2vnmr import *
+ import shutil
+ import time
+ import os
+ import numpy as np
+ from subprocess import call
+
+ # This is the script to collect ShimNet training data on Agilent spectrometers. Execution:
+ # 1. Open VnmrJ and type: 'listenon'
+ # 2. Put the lineshape sample, set standard PROTON parameters and set one scan (do not modify sw and at!)
+ # 3. Shim the sample and collect the data. Save the optimally shimmed dataset
+ # 4. Put optimum z1 and z2 shim values as optiz1 and optiz2 below
+ # 5. Define the calibration range as range_z1 and range_z2 (default is ok)
+ # 6. Start the python script:
+ #    python3 ./sweep_shims_lineshape_Z1Z2.py
+ #    the spectrometer will start collecting spectra
+
+ Classic_path = '/home/nmr700/shimnet6/lshp'
+ optiz1 = 8868  # put optimum shim values here
+ optiz2 = -297
+
+ range_z1 = 100  # put shim calibration ranges here
+ range_z2 = 100
+
+ z1_sweep = np.arange(optiz1-range_z1, optiz1+range_z1+1, 2.0)
+ z2_sweep = np.arange(optiz2-range_z2, optiz2+range_z2+1, 2.0)
+
+ for i in range(1, np.shape(z1_sweep)[0]+1, 1):
+     for j in range(1, np.shape(z2_sweep)[0]+1, 1):
+         wait_until_idle()
+
+         Run_macro("sethw('z1'," + str(z1_sweep[i-1]) + ")")
+         Run_macro("sethw('z2'," + str(z2_sweep[j-1]) + ")")
+
+         go_if_idle()
+         wait_until_idle()
+         time.sleep(0.5)
+
+         Save_experiment(Classic_path + '_z1_' + str(int(z1_sweep[i-1])) + '_z2_' + str(int(z2_sweep[j-1])) + '.fid')
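With the defaults above (range ±100, step 2), the nested loop acquires a 101 × 101 grid of spectra; the sweep sizes can be verified directly:

```python
import numpy as np

optiz1, optiz2 = 8868, -297   # optimum shim values from the script
range_z1 = range_z2 = 100

z1_sweep = np.arange(optiz1 - range_z1, optiz1 + range_z1 + 1, 2.0)
z2_sweep = np.arange(optiz2 - range_z2, optiz2 + range_z2 + 1, 2.0)

print(len(z1_sweep), len(z2_sweep), len(z1_sweep) * len(z2_sweep))  # 101 101 10201
```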
configs/shimnet_600.yaml ADDED
@@ -0,0 +1,46 @@
+ model:
+   name: ShimNetWithSCRF
+   kwargs:
+     rensponse_length: 81
+     resnponse_head_dims:
+     - 128
+ training:
+ - batch_size: 64
+   learning_rate: 0.001
+   max_iters: 1600000
+ - batch_size: 512
+   learning_rate: 0.001
+   max_iters: 25600000
+ - batch_size: 512
+   learning_rate: 0.0005
+   max_iters: 12800000
+ losses_weights:
+   clean: 1.0
+   noised: 1.0
+   response: 1.0
+ data:
+   response_functions_files:
+   # Paste path to your SCRF file here
+   # - Can be absolute path
+   # - Can be relative to repository root
+   - data/scrf_81_600MHz.pt
+   atom_groups_data_file: data/multiplets_10000_parsed.txt
+   response_function_stretch_min: 1.0
+   response_function_stretch_max: 1.0
+   response_function_noise: 0.0
+   multiplicity_j1_min: 0.0
+   multiplicity_j1_max: 15
+   multiplicity_j2_min: 0.0
+   multiplicity_j2_max: 15
+   number_of_signals_min: 2
+   number_of_signals_max: 5
+   thf_min: 0.5
+   thf_max: 2
+   relative_height_min: 0.5
+   relative_height_max: 4
+   frq_step: 0.30048
+ logging:
+   step: 1000000
+   num_plots: 32
+ metadata:
+   spectrometer_frequency: 600.0
configs/shimnet_700.yaml ADDED
@@ -0,0 +1,46 @@
+ model:
+   name: ShimNetWithSCRF
+   kwargs:
+     rensponse_length: 61
+     resnponse_head_dims:
+     - 128
+ training:
+ - batch_size: 64
+   learning_rate: 0.001
+   max_iters: 1600000
+ - batch_size: 512
+   learning_rate: 0.001
+   max_iters: 6400000
+ - batch_size: 512
+   learning_rate: 0.0005
+   max_iters: 12800000
+ - batch_size: 800
+   learning_rate: 0.0005
+   max_iters: 12800000
+ losses_weights:
+   clean: 1.0
+   noised: 1.0
+   response: 1.0
+ data:
+   response_functions_files:
+   - data/scrf_61_700MHz.pt
+   atom_groups_data_file: data/multiplets_10000_parsed.txt
+   response_function_stretch_min: 1.0
+   response_function_stretch_max: 1.0
+   response_function_noise: 0.0
+   multiplicity_j1_min: 0.0
+   multiplicity_j1_max: 15
+   multiplicity_j2_min: 0.0
+   multiplicity_j2_max: 15
+   number_of_signals_min: 2
+   number_of_signals_max: 5
+   thf_min: 0.5
+   thf_max: 2
+   relative_height_min: 0.5
+   relative_height_max: 4
+   frq_step: 0.34059797
+ logging:
+   step: 1000000
+   num_plots: 32
+ metadata:
+   spectrometer_frequency: 700.0
configs/shimnet_template.yaml ADDED
@@ -0,0 +1,50 @@
+ model:
+   name: ShimNetWithSCRF
+   kwargs:
+     rensponse_length: 61
+     resnponse_head_dims:
+     - 128
+ training:
+ - batch_size: 64
+   learning_rate: 0.001
+   max_iters: 1600000
+ - batch_size: 512
+   learning_rate: 0.001
+   max_iters: 6400000
+ - batch_size: 512
+   learning_rate: 0.0005
+   max_iters: 12800000
+ - batch_size: 800
+   learning_rate: 0.0005
+   max_iters: 12800000
+ losses_weights:
+   clean: 1.0
+   noised: 1.0
+   response: 1.0
+ data:
+   response_functions_files:
+   # Specify the path to your SCRF file/files here.
+   # - It can be an absolute path.
+   # - It can be relative to the repository root.
+   # Multiple files can be listed in the following rows (YAML list format).
+   - path/to/srcf_file
+   atom_groups_data_file: data/multiplets_10000_parsed.txt
+   response_function_stretch_min: 1.0 # relative
+   response_function_stretch_max: 1.0 # relative
+   response_function_noise: 0.0 # arbitrary height units
+   multiplicity_j1_min: 0.0 # Hz
+   multiplicity_j1_max: 15 # Hz
+   multiplicity_j2_min: 0.0 # Hz
+   multiplicity_j2_max: 15 # Hz
+   number_of_signals_min: 2
+   number_of_signals_max: 5
+   thf_min: 0.5 # arbitrary height units
+   thf_max: 2 # arbitrary height units
+   relative_height_min: 0.5 # relative
+   relative_height_max: 4 # relative
+   frq_step: 0.34059797 # Hz per point
+ logging:
+   step: 1000000
+   num_plots: 32
+ metadata: # additional metadata, not used in the training process
+   spectrometer_frequency: 700.0 # MHz
download_files.py ADDED
@@ -0,0 +1,95 @@
+ import urllib.request
+ from pathlib import Path
+ import argparse
+
+ ALL_FILES_TO_DOWNLOAD = {
+     "weights": [{
+         "url": "https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke",
+         "destination": "weights/shimnet_700MHz.pt"
+     },
+     {
+         "url": "https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N",
+         "destination": "weights/shimnet_600MHz.pt"
+     }],
+     "SCRF": [{
+         "url": "https://drive.google.com/uc?export=download&id=113al7A__yYALx_2hkESuzFIDU3feVtNY",
+         "destination": "data/scrf_61_700MHz.pt"
+     },
+     {
+         "url": "https://drive.google.com/uc?export=download&id=1J-DsPtaITXU3TFrbxaZPH800U1uIiwje",
+         "destination": "data/scrf_81_600MHz.pt"
+     }],
+     "multiplets": [{
+         "url": "https://drive.google.com/uc?export=download&id=1QGvV-Au50ZxaP1vFsmR_auI299Dw-Wrt",
+         "destination": "data/multiplets_10000_parsed.txt"
+     }],
+     "development": []
+ }
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description='Download files: weights (default), SCRF (optional), multiplet data (optional)',
+     )
+     parser.add_argument('--overwrite', action='store_true', help='Overwrite existing files')
+     parser.add_argument(
+         '--weights',
+         action='store_true',
+         default=True,
+         help='Download weights file (default behavior). Use --no-weights to opt out.',
+     )
+     parser.add_argument(
+         '--no-weights',
+         action='store_false',
+         dest='weights',
+         help='Do not download weights file.',
+     )
+     parser.add_argument('--SCRF', action='store_true', help='Download SCRF files - Shim Coil Response Functions')
+     parser.add_argument('--multiplets', action='store_true', help='Download multiplets data file')
+     parser.add_argument('--development', action='store_true', help='Download development weights file')
+
+     parser.add_argument('--all', action='store_true', help='Download all available files')
+
+     args = parser.parse_args()
+     # Set all individual flags if --all is specified
+     if args.all:
+         args.weights = True
+         args.SCRF = True
+         args.multiplets = True
+         args.development = True
+
+     return args
+
+ def download_file(url, target, overwrite=False):
+     target = Path(target)
+     if target.exists() and not overwrite:
+         response = input(f"File {target} already exists. Overwrite? (y/n): ")
+         if response.lower() != 'y':
+             print(f"Download of {target} cancelled")
+             return
+     target.parent.mkdir(parents=True, exist_ok=True)
+     try:
+         urllib.request.urlretrieve(url, target)
+         print(f"Downloaded {target}")
+     except Exception as e:
+         print(f"Failed to download file from {url}:\n {e}")
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     main_dir = Path(__file__).parent
+     if args.weights:
+         for file_data in ALL_FILES_TO_DOWNLOAD["weights"]:
+             download_file(file_data["url"], main_dir / file_data["destination"], args.overwrite)
+
+     if args.SCRF:
+         for file_data in ALL_FILES_TO_DOWNLOAD["SCRF"]:
+             download_file(file_data["url"], main_dir / file_data["destination"], args.overwrite)
+
+     if args.multiplets:
+         for file_data in ALL_FILES_TO_DOWNLOAD["multiplets"]:
+             download_file(file_data["url"], main_dir / file_data["destination"], args.overwrite)
+
+     if args.development:
+         for file_data in ALL_FILES_TO_DOWNLOAD["development"]:
+             download_file(file_data["url"], main_dir / file_data["destination"], args.overwrite)
extract_scrf_from_fids.py ADDED
@@ -0,0 +1,132 @@
+ import os
+ import shutil
+ from pathlib import Path
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import torch
+ import nmrglue as ng
+ from tqdm import tqdm
+
+ # input
+ data_dir = "../../../dane/600Hz_20241227/loop"
+ opti_fid_path = "../../../dane/600Hz_20241227/opti"
+
+ # output
+ spectra_file = "data/600Hz_20241227-Lnative/total.npy"
+ spectra_file_names = "data/600Hz_20241227-Lnative/total.csv"
+ opi_spectrum_file = "data/600Hz_20241227-Lnative/opti.npy"
+ responses_file = "data/600Hz_20241227-Lnative/scrf_61.pt"
+ losses_file = "data/600Hz_20241227-Lnative/losses_scrf_61.pt"
+
+ # create directories for output files
+ for file_path in [spectra_file, spectra_file_names, opi_spectrum_file, responses_file]:
+     Path(file_path).parent.mkdir(parents=True, exist_ok=True)
+
+ # settings: fid to spectrum
+ ph0_correction = -190.49
+ ph1_correction = 0
+ autophase_fn = "acme"
+ target_length = None
+
+ # settings: SCRF extraction
+ calibration_peak_center = "auto"  # "auto" or an index (e.g. 12277): the center of the calibration peak
+ calibration_window_halfwidth = 128  # the half width of the calibration window
+ steps = 6000
+ kernel_size = 61
+ kernel_sqrt = False  # True to allow only positive values
+
+ # fid to spectra processing
+
+ def fid_to_spectrum(varian_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=None, sin_pod=False):
+     dic, data = ng.varian.read(varian_fid_path)
+     data[0] *= 0.5
+     if sin_pod:
+         data = ng.proc_base.sp(data, end=0.98)
+
+     if target_length is not None:
+         if (pad_length := target_length - len(data)) > 0:
+             data = ng.proc_base.zf(data, pad_length)
+         else:
+             data = data[:target_length]
+
+     spec = ng.proc_base.fft(data)
+     spec = ng.process.proc_autophase.autops(spec, autophase_fn, p0=ph0_correction, p1=ph1_correction, disp=False)
+
+     return spec
+
+ # process optimal measurement fid to spectrum
+ opti_spectrum_full = fid_to_spectrum(opti_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)
+
+ if calibration_peak_center == "auto":
+     calibration_peak_center = np.argmax(abs(opti_spectrum_full))
+ fitting_range = (calibration_peak_center - calibration_window_halfwidth, calibration_peak_center + calibration_window_halfwidth + 1)
+
+ opti_spectrum = opti_spectrum_full[fitting_range[0]:fitting_range[1]]
+ np.save(opi_spectrum_file, opti_spectrum)
+ print(f"Optimal spectrum extracted to {opi_spectrum_file}")
+
+ # process loop fids to spectra
+ spec_list = []
+ spec_names = []
+
+ print("Extracting spectra from fids...")
+ for fid_path in tqdm(list(Path(data_dir).rglob('*.fid'))):
+     spec = fid_to_spectrum(fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)[fitting_range[0]:fitting_range[1]]
+
+     spec_list.append(spec)
+     spec_names.append(fid_path.name)
+
+ total = np.array(spec_list)
+ np.save(spectra_file, total)
+ pd.DataFrame(spec_names).to_csv(spectra_file_names, header=False)
+ # total = np.load(spectra_file)
+ print(f"Spectra extracted to {spectra_file}")
+
+ # process SCRF extraction
+ def fit_kernel(base, target, kernel_size, kernel_sqrt=True, steps=20000, verbose=False):
+
+     kernel = torch.ones((1, 1, kernel_size), dtype=base.dtype)
+     if kernel_sqrt:
+         kernel /= torch.sqrt(torch.sum(kernel**2))
+     else:
+         kernel /= kernel_size
+     kernel.requires_grad = True
+
+     optimizer = torch.optim.Adam([kernel])
+
+     for epoch in range(steps):
+         if kernel_sqrt:
+             spe_est = torch.conv1d(base, kernel**2, padding='same')
+         else:
+             spe_est = torch.conv1d(base, kernel, padding='same')
+         loss = torch.mean(abs(target - spe_est)**2)  # torch.nn.functional.mse_loss(spe_est, target)
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad()
+
+         if verbose and (epoch+1) % 100 == 0:
+             print(epoch, loss.item())
+     if kernel_sqrt:
+         return kernel.detach()**2, loss.item()
+     else:
+         return kernel.detach(), loss.item()
+
+ responses = torch.empty(len(total), 1, 1, 1, 1, kernel_size)
+ losses = torch.empty(len(total))
+ base = torch.tensor(opti_spectrum.real).unsqueeze(0)
+ targets = torch.tensor(total.real)
+
+ # normalization
+ base /= base.sum()
+ targets /= targets.sum(dim=(-1,), keepdim=True)
+
+ print("\nExtracting SCRFs...")
+ for i, target in tqdm(enumerate(targets), total=len(targets)):
+     kernel, loss = fit_kernel(base, target.unsqueeze(0), kernel_size, kernel_sqrt=kernel_sqrt, steps=steps)
+     responses[i, 0, 0] = kernel
+     losses[i] = loss
+
+ torch.save(responses, responses_file)
+ torch.save(losses, losses_file)
+ print(f"SCRFs extracted to {responses_file}, losses saved to {losses_file}")
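The `fit_kernel` routine above fits each SCRF by gradient descent on a convolution model. As a sanity check of the idea, the same kind of kernel can be recovered in closed form by least squares on synthetic data (a NumPy-only sketch; all sizes and values are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.arange(257)
# Sharp "optimal" line plus a small random baseline (for well-conditioned columns)
base = np.exp(-0.5 * ((x - 128) / 2.0) ** 2) + 0.05 * rng.standard_normal(257)
true_kernel = np.array([0.1, 0.2, 0.4, 0.2, 0.1])     # a made-up 5-point "SCRF"
target = np.convolve(base, true_kernel, mode="same")   # simulated distorted spectrum

# Build A such that A @ kernel == np.convolve(base, kernel, mode="same"),
# then recover the kernel by ordinary least squares
k = len(true_kernel)
pad = np.pad(base, (k // 2, k // 2))
A = np.stack([pad[i:i + len(base)] for i in range(k)], axis=1)[:, ::-1]
kernel_est, *_ = np.linalg.lstsq(A, target, rcond=None)

print(np.round(kernel_est, 3))  # ≈ [0.1 0.2 0.4 0.2 0.1]
```

The repository's Adam-based loop plays the same role but scales to long kernels and supports the positivity constraint (`kernel_sqrt`).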
predict-gui.py ADDED
@@ -0,0 +1,197 @@
+ import torch
+ torch.set_grad_enabled(False)
+ import numpy as np
+ from pathlib import Path
+ from omegaconf import OmegaConf
+ import gradio as gr
+ import plotly.graph_objects as go
+
+ from src.models import ShimNetWithSCRF, Predictor
+ from predict import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
+
+ # silence deprecation warnings
+ import warnings
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+
+ import argparse
+
+ # Add argument parsing for server_name
+ parser = argparse.ArgumentParser(description="Launch ShimNet Spectra Correction App")
+ parser.add_argument(
+     "--server_name",
+     type=str,
+     default="127.0.0.1",
+     help="Server name to bind the app (default: 127.0.0.1). Use 0.0.0.0 for external access."
+ )
+ parser.add_argument(
+     "--share",
+     action="store_true",
+     help="If set, generates a public link to share the app."
+ )
+ args = parser.parse_args()
+
+ def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None, reference_spectrum=None):
+     if input_spectrometer_frequency == 0:
+         input_spectrometer_frequency = None
+     # Load configuration and initialize predictor
+     config = OmegaConf.load(config_file)
+     model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
+     predictor = initialize_predictor(config, weights_file)
+
+     # Load input data
+     input_data = np.loadtxt(input_file)
+     input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]
+
+     # Convert input frequencies to model's frequency
+     if input_spectrometer_frequency is not None:
+         input_freqs_model_ppm = input_freqs_input_ppm * input_spectrometer_frequency / config.metadata.spectrometer_frequency
+     else:
+         input_freqs_model_ppm = input_freqs_input_ppm
+
+     # Resample input spectrum
+     freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)
+
+     # Scale and process spectrum
+     spectrum_tensor = torch.tensor(spectrum).float()
+     scaling_factor = Defaults.SCALE / spectrum_tensor.max()
+     spectrum_tensor *= scaling_factor
+     prediction = predictor(spectrum_tensor).numpy()
+     prediction /= scaling_factor
+
+     # Resample output spectrum
+     output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)
+
+     # Prepare output data for download
+     output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
+     output_file = f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
+     np.savetxt(output_file, output_data)
+
+     # Create Plotly figure
+     fig = go.Figure()
+
+     # Add Input Spectrum and Corrected Spectrum (always visible)
+     normalization_value = input_spectrum.max()
+     fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum/normalization_value, mode='lines', name='Input Spectrum', visible=True, line=dict(color='#EF553B')))  # red
+     fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction/normalization_value, mode='lines', name='Corrected Spectrum', visible=True, line=dict(color='#00cc96')))  # green
+
+     if reference_spectrum is not None:
+         reference_spectrum_freqs, reference_spectrum_intensity = np.loadtxt(reference_spectrum).T
+         reference_spectrum_intensity /= reference_spectrum_intensity.max()
+         n_zooms = 50
+         zooms = np.geomspace(0.01, 100, 2 * n_zooms + 1)
+
+         # Add Reference Data traces (initially invisible)
+         for zoom in zooms:
+             fig.add_trace(
+                 go.Scatter(
+                     x=reference_spectrum_freqs,
+                     y=reference_spectrum_intensity * zoom,
+                     mode='lines',
+                     name=f'Reference Data (Zoom: {zoom:.2f})',
+                     visible=False,
+                     line=dict(color='#636efa')
+                 )
+             )
+         # Make the middle zoom level visible by default
+         fig.data[2 * n_zooms // 2 + 2].visible = True
+
+         # Create and add slider
+         steps = []
+         for i in range(2, len(fig.data)):  # Start from the reference data traces
+             step = dict(
+                 method="update",
+                 args=[{"visible": [True, True] + [False] * (len(fig.data) - 2)}],  # Keep first two traces visible
+             )
+             step["args"][0]["visible"][i] = True  # Toggle i'th reference trace to "visible"
+             steps.append(step)
+
+         sliders = [dict(
+             active=n_zooms,
+             currentvalue={"prefix": "Reference zoom: "},
+             pad={"t": 50},
+             steps=steps
+         )]
+
+         fig.update_layout(
+             sliders=sliders
+         )
+
+     fig.update_layout(
+         title="Spectrum Visualization",
+         xaxis_title="Frequency (ppm)",
+         yaxis_title="Intensity"
+     )
+
+     return fig, output_file
+
+ # Gradio app
+ with gr.Blocks() as app:
+     gr.Markdown("# ShimNet Spectra Correction")
+     gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef86686dde43c90860d315)")
+     gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
+
+     with gr.Row():
+         with gr.Column():
+             model_selection = gr.Radio(
+                 label="Select Model",
+                 choices=["600 MHz", "700 MHz", "Custom"],
+                 value="600 MHz"
+             )
+             config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
+             weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
+
+         with gr.Column():
+             input_file = gr.File(label="Input File (.txt | .csv)", height=120)
+             input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
+             gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
147
+ reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
148
+
149
+ process_button = gr.Button("Process File")
150
+ plot_output = gr.Plot(label="Spectrum Visualization")
151
+ download_button = gr.File(label="Download Processed File", interactive=False, height=120)
152
+
153
+ # Update visibility of config and weights fields based on model selection
154
+ def update_visibility(selected_model):
155
+ if selected_model == "Custom":
156
+ return gr.update(visible=True), gr.update(visible=True)
157
+ else:
158
+ return gr.update(visible=False), gr.update(visible=False)
159
+
160
+ model_selection.change(
161
+ update_visibility,
162
+ inputs=[model_selection],
163
+ outputs=[config_file, weights_file]
164
+ )
165
+
166
+ # Process button click logic
167
+ def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
168
+ if model_selection == "600 MHz":
169
+ config_file = "configs/shimnet_600.yaml"
170
+ weights_file = "weights/shimnet_600MHz.pt"
171
+ elif model_selection == "700 MHz":
172
+ config_file = "configs/shimnet_700.yaml"
173
+ weights_file = "weights/shimnet_700MHz.pt"
174
+ else:
175
+ config_file = config_file.name
176
+ weights_file = weights_file.name
177
+
178
+ return process_file(input_file.name, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file.name if reference_spectrum_file else None)
179
+
180
+ process_button.click(
181
+ process_file_with_model,
182
+ inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file],
183
+ outputs=[plot_output, download_button]
184
+ )
185
+
186
+ app.launch(share=args.share, server_name=args.server_name)
187
+
188
+ # '#636efa',
189
+ # '#EF553B',
190
+ # '#00cc96',
191
+ # '#ab63fa',
192
+ # '#FFA15A',
193
+ # '#19d3f3',
194
+ # '#FF6692',
195
+ # '#B6E880',
196
+ # '#FF97FF',
197
+ # '#FECB52'
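A side note on the slider logic above: the GUI builds `2 * n_zooms + 1` geometrically spaced zoom levels, so the middle level is the geometric mean of the endpoints (here exactly 1.0), and `2 * n_zooms // 2 + 2` is simply `n_zooms + 2` — the middle reference trace offset by the two always-visible traces. A minimal numpy sketch:

```python
import numpy as np

n_zooms = 50
# 2*n_zooms + 1 zoom levels, geometrically spaced between 0.01x and 100x
zooms = np.geomspace(0.01, 100, 2 * n_zooms + 1)

# the middle level is the geometric mean of the endpoints: sqrt(0.01 * 100) = 1.0
middle = float(zooms[n_zooms])

# index of the default-visible trace: the two always-visible traces come first
default_trace_index = 2 * n_zooms // 2 + 2
```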
predict.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ torch.set_grad_enabled(False)
+ import numpy as np
+ import argparse
+ from pathlib import Path
+ import sys, os
+ from omegaconf import OmegaConf
+
+ from src.models import ShimNetWithSCRF, Predictor
+
+ # silence deprecation warnings
+ # https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
+ import warnings
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+
+ class Defaults:
+     SCALE = 16.0
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("input_files", help="Input files", nargs="+")
+     parser.add_argument("--config", help="config file .yaml")
+     parser.add_argument("--weights", help="model weights")
+     parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
+     parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
+     args = parser.parse_args()
+     return args
+
+ # functions
+ def resample_input_spectrum(input_freqs, input_spectrum, ppm_per_point):
+     """Resample the input spectrum onto the model's uniform frequency grid."""
+     freqs = np.arange(input_freqs.min(), input_freqs.max(), ppm_per_point)
+     spectrum = np.interp(freqs, input_freqs, input_spectrum)
+     return freqs, spectrum
+
+ def resample_output_spectrum(input_freqs, freqs, prediction):
+     """Resample the prediction back onto the input spectrum's frequency grid."""
+     prediction = np.interp(input_freqs, freqs, prediction)
+     return prediction
+
+ def initialize_predictor(config, weights_file):
+     model = ShimNetWithSCRF(**config.model.kwargs)
+     predictor = Predictor(model, weights_file)
+     return predictor
+
+ # run
+ if __name__ == "__main__":
+     args = parse_args()
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(exist_ok=True, parents=True)
+
+     config = OmegaConf.load(args.config)
+     model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
+     predictor = initialize_predictor(config, args.weights)
+
+     for input_file in args.input_files:
+         print(f"processing {input_file} ...")
+
+         # load data
+         input_data = np.loadtxt(input_file)
+         input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]
+
+         # convert input frequencies to the model's frequency scale - correct for zero filling and spectrometer frequency
+         if args.input_spectrometer_frequency is not None:
+             input_freqs_model_ppm = input_freqs_input_ppm * args.input_spectrometer_frequency / config.metadata.spectrometer_frequency
+         else:
+             input_freqs_model_ppm = input_freqs_input_ppm
+
+         freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)
+
+         spectrum = torch.tensor(spectrum).float()
+         # scale the height of the spectrum
+         scaling_factor = Defaults.SCALE / spectrum.max()
+         spectrum *= scaling_factor
+
+         # correct spectrum
+         prediction = predictor(spectrum).numpy()
+
+         # rescale height
+         prediction /= scaling_factor
+
+         # resample the output to match the input spectrum
+         output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)
+
+         # save result
+         output_file = output_dir / f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
+
+         np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
+         print(f"saved to {output_file}")
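The two resampling helpers in `predict.py` are plain `np.interp` wrappers: the spectrum is interpolated onto the model's uniform ppm grid before inference and back onto the original grid afterwards. A self-contained sketch (helper functions copied from above, with a made-up Gaussian "spectrum") shows the round trip approximately recovers the input:

```python
import numpy as np

def resample_input_spectrum(input_freqs, input_spectrum, ppm_per_point):
    """Resample the input spectrum onto a uniform frequency grid."""
    freqs = np.arange(input_freqs.min(), input_freqs.max(), ppm_per_point)
    spectrum = np.interp(freqs, input_freqs, input_spectrum)
    return freqs, spectrum

def resample_output_spectrum(input_freqs, freqs, prediction):
    """Resample the prediction back onto the original frequency grid."""
    return np.interp(input_freqs, freqs, prediction)

# hypothetical smooth spectrum on a coarser grid
input_freqs = np.linspace(-5.0, 5.0, 513)
input_spectrum = np.exp(-input_freqs**2)

freqs, spectrum = resample_input_spectrum(input_freqs, input_spectrum, 0.001)
roundtrip = resample_output_spectrum(input_freqs, freqs, spectrum)
max_error = np.abs(roundtrip - input_spectrum).max()
```

Because `np.interp` is piecewise linear, the round-trip error is bounded by the curvature of the spectrum over one grid step; for a dense model grid it is negligible.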
requirements-cpu.txt ADDED
@@ -0,0 +1,9 @@
+ torch==2.4.1+cpu
+ torchaudio==2.4.1+cpu
+ nmrglue==0.11
+ torchdata==0.9.0
+ numpy==2.0.2
+ matplotlib==3.9.3
+ pandas==2.2.3
+ tqdm==4.67.1
+ hydra-core==1.3.2
requirements-gpu.txt ADDED
@@ -0,0 +1,9 @@
+ torch==2.4.1
+ torchaudio==2.4.1
+ nmrglue==0.11
+ torchdata==0.9.0
+ numpy==2.0.2
+ matplotlib==3.9.3
+ pandas==2.2.3
+ tqdm==4.67.1
+ hydra-core==1.3.2
requirements-gui.txt ADDED
@@ -0,0 +1,2 @@
+ gradio==5.23.2
+ plotly==6.0.1
sample_data/2ethylonaphthalene_bestshims_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/2ethylonaphthalene_up_1mm_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_20ul_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_X_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_Z1Z2Z3Z4_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_Z1Z2_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_besteshims_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_bestshims_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/CresolRed_after_styrene_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/CresolRed_after_styrene_cut_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/CresolRed_bestshims_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/CresolRed_cut_bestshims_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Geraniol_bestshims_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Geraniol_up_1mm_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/SodiumButyrate_after_glucose_cut_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/SodiumButyrate_cut_bestshims_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/generators.py ADDED
@@ -0,0 +1,280 @@
+ import numpy as np
+ import torch
+ import torchdata
+ # from itertools import islice
+
+ def random_value(min_value, max_value):
+     return (min_value + torch.rand(1) * (max_value - min_value)).item()
+
+ def random_loguniform(min_value, max_value):
+     return (min_value * torch.exp(torch.rand(1) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
+
+ def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq: torch.Tensor):
+     # extract parameters
+     tff_lin = peaks_parameters["tff_lin"]
+     twf_lin = peaks_parameters["twf_lin"]
+     thf_lin = peaks_parameters["thf_lin"]
+     trf_lin = peaks_parameters["trf_lin"]
+
+     lwf_lin = twf_lin
+     lhf_lin = thf_lin * (1. - trf_lin)
+     gwf_lin = twf_lin
+     gdf_lin = gwf_lin / torch.tensor(2.).log().mul(2.).sqrt()
+     ghf_lin = thf_lin * trf_lin
+     # calculate Lorentzian peak contributions
+     lsf_linfrq = lwf_lin[:, None] ** 2 / (lwf_lin[:, None] ** 2 + (frq_frq - tff_lin[:, None]) ** 2) * lhf_lin[:, None]
+     # calculate Gaussian peak contributions
+     gsf_linfrq = torch.exp(-(frq_frq - tff_lin[:, None]) ** 2 / gdf_lin[:, None] ** 2 / 2.) * ghf_lin[:, None]
+     tsf_linfrq = lsf_linfrq + gsf_linfrq
+     # sum peak contributions
+     tsf_frq = tsf_linfrq.sum(0, keepdim=True)
+     return tsf_frq
+
+
+ pascal_triangle = [(1,), (1,1), (1,2,1), (1,3,3,1), (1,4,6,4,1), (1,5,10,10,5,1), (1,6,15,20,15,6,1), (1,7,21,35,35,21,7,1)]
+ normalized_pascal_triangle = [torch.tensor(x)/sum(x) for x in pascal_triangle]
+
+ def pascal_multiplicity(multiplicity):
+     intensities = normalized_pascal_triangle[multiplicity-1]
+     n_peaks = len(intensities)
+     shifts = torch.arange(n_peaks) - ((n_peaks-1)/2)
+     return shifts, intensities
+
+ def double_multiplicity(multiplicity1, multiplicity2, j1=1, j2=1):
+     shifts1, intensities1 = pascal_multiplicity(multiplicity1)
+     shifts2, intensities2 = pascal_multiplicity(multiplicity2)
+
+     shifts = (j1*shifts1.reshape(-1, 1) + j2*shifts2.reshape(1, -1)).flatten()
+     intensities = (intensities1.reshape(-1, 1) * intensities2.reshape(1, -1)).flatten()
+     return shifts, intensities
+
+ def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_lin, j1, j2):
+     shifts, intensities = double_multiplicity(multiplicity[0], multiplicity[1], j1, j2)
+     n_peaks = len(shifts)
+
+     return {
+         "tff_lin": shifts + tff_lin,
+         "thf_lin": intensities * thf_lin,
+         "twf_lin": torch.full((n_peaks,), twf_lin),
+         "trf_lin": torch.full((n_peaks,), trf_lin),
+     }
+
+ def value_to_index(values, table):
+     span = table[-1] - table[0]
+     indices = ((values - table[0])/span * (len(table)-1))  # .round().type(torch.int64)
+     return indices
+
+ def generate_theoretical_spectrum(
+     number_of_signals_min, number_of_signals_max,
+     spectrum_width_min, spectrum_width_max,
+     relative_width_min, relative_width_max,
+     tff_min, tff_max,
+     thf_min, thf_max,
+     trf_min, trf_max,
+     relative_height_min, relative_height_max,
+     multiplicity_j1_min, multiplicity_j1_max,
+     multiplicity_j2_min, multiplicity_j2_max,
+     atom_groups_data,
+     frq_frq
+ ):
+     number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [])
+     atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals])
+     width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max)
+     height_spectrum = random_loguniform(thf_min, thf_max)
+
+     peak_parameters_data = []
+     theoretical_spectrum = None
+     for atom_group_index in atom_group_indices:
+         relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
+         position = random_value(tff_min, tff_max)
+         j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
+         j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
+         width = width_spectrum*random_loguniform(relative_width_min, relative_width_max)
+         height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max)
+         gaussian_contribution = random_value(trf_min, trf_max)
+
+         peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin=width, trf_lin=gaussian_contribution, j1=j1, j2=j2)
+         peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
+         peak_parameters_data.append(peaks_parameters)
+         spectrum_contribution = calculate_theoretical_spectrum(peaks_parameters, frq_frq)
+         if theoretical_spectrum is None:
+             theoretical_spectrum = spectrum_contribution
+         else:
+             theoretical_spectrum += spectrum_contribution
+     return theoretical_spectrum, peak_parameters_data
+
+
+ def theoretical_generator(
+     atom_groups_data,
+     pixels=2048, frq_step=11160.7142857 / 32768,
+     number_of_signals_min=1, number_of_signals_max=8,
+     spectrum_width_min=0.2, spectrum_width_max=1,
+     relative_width_min=1, relative_width_max=2,
+     relative_height_min=1, relative_height_max=1,
+     relative_frequency_min=-0.4, relative_frequency_max=0.4,
+     thf_min=1/16, thf_max=16,
+     trf_min=0, trf_max=1,
+     multiplicity_j1_min=0, multiplicity_j1_max=15,
+     multiplicity_j2_min=0, multiplicity_j2_max=15,
+ ):
+     tff_min = relative_frequency_min * pixels * frq_step
+     tff_max = relative_frequency_max * pixels * frq_step
+     frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
+
+     while True:
+         yield generate_theoretical_spectrum(
+             number_of_signals_min=number_of_signals_min,
+             number_of_signals_max=number_of_signals_max,
+             spectrum_width_min=spectrum_width_min,
+             spectrum_width_max=spectrum_width_max,
+             relative_width_min=relative_width_min,
+             relative_width_max=relative_width_max,
+             relative_height_min=relative_height_min,
+             relative_height_max=relative_height_max,
+             tff_min=tff_min, tff_max=tff_max,
+             thf_min=thf_min, thf_max=thf_max,
+             trf_min=trf_min, trf_max=trf_max,
+             multiplicity_j1_min=multiplicity_j1_min,
+             multiplicity_j1_max=multiplicity_j1_max,
+             multiplicity_j2_min=multiplicity_j2_min,
+             multiplicity_j2_max=multiplicity_j2_max,
+             atom_groups_data=atom_groups_data,
+             frq_frq=frq_frq
+         )
+
+ class ResponseLibrary:
+     def __init__(self, response_files, normalize=True):
+         self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0, -4) for f in response_files]
+         if normalize:
+             self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
+         lengths = [len(data) for data in self.data]
+         self.start_indices = torch.cumsum(torch.tensor([0] + lengths[:-1]), 0)
+         self.total_length = sum(lengths)
+
+     def __getitem__(self, idx):
+         if idx >= self.total_length:
+             raise ValueError(f'index {idx} out of range')
+         tensor_index = torch.searchsorted(self.start_indices, idx, right=True) - 1
+         return self.data[tensor_index][idx - self.start_indices[tensor_index]]
+
+     def __len__(self):
+         return self.total_length
+
+ def generator(
+     theoretical_generator_params,
+     response_function_library,
+     response_function_stretch_min=0.5,
+     response_function_stretch_max=2.0,
+     response_function_noise=0.,
+     spectrum_noise_min=0.,
+     spectrum_noise_max=1/64,
+     include_spectrum_data=False,
+     include_peak_mask=False,
+     include_response_function=False,
+     flip_response_function=False
+ ):
+     for theoretical_spectrum, theoretical_spectrum_data in theoretical_generator(**theoretical_generator_params):
+         # get response function
+         response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
+         # stretch response function
+         padding_size = (response_function.shape[-1] - 1)//2
+         padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size)
+         response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
+         response_function /= response_function.sum()  # normalize the sum of the response function to 1
+         # add noise to the response function
+         response_function += torch.randn(response_function.shape) * response_function_noise
+         response_function /= response_function.sum()  # renormalize after adding noise
+         if flip_response_function and (torch.rand(1).item() < 0.5):
+             response_function = response_function.flip(-1)
+         # disturbed spectrum
+         disturbed_spectrum = torch.nn.functional.conv1d(theoretical_spectrum, response_function, padding=padding_size)
+         # add noise
+         noised_spectrum = disturbed_spectrum + torch.randn(disturbed_spectrum.shape) * random_value(spectrum_noise_min, spectrum_noise_max)
+
+         out = {
+             'theoretical_spectrum': theoretical_spectrum,
+             'disturbed_spectrum': disturbed_spectrum,
+             'noised_spectrum': noised_spectrum,
+         }
+         if include_response_function:
+             out['response_function'] = response_function
+         if include_spectrum_data:
+             out["theoretical_spectrum_data"] = theoretical_spectrum_data
+         if include_peak_mask:
+             all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in theoretical_spectrum_data])
+             peaks_indices = all_peaks_rel.round().type(torch.int64)
+             out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0)
+
+         yield out
+
+
+ def collate_with_spectrum_data(batch):
+     tensor_keys = set(batch[0].keys())
+     tensor_keys.remove('theoretical_spectrum_data')
+     out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys}
+     out["theoretical_spectrum_data"] = [item["theoretical_spectrum_data"] for item in batch]
+     return out
+
+ def get_datapipe(
+     response_functions_files,
+     atom_groups_data_file=None,
+     batch_size=64,
+     pixels=2048, frq_step=11160.7142857 / 32768,
+     number_of_signals_min=1, number_of_signals_max=8,
+     spectrum_width_min=0.2, spectrum_width_max=1,
+     relative_width_min=1, relative_width_max=2,
+     relative_height_min=1, relative_height_max=1,
+     relative_frequency_min=-0.4, relative_frequency_max=0.4,
+     thf_min=1/16, thf_max=16,
+     trf_min=0, trf_max=1,
+     multiplicity_j1_min=0, multiplicity_j1_max=15,
+     multiplicity_j2_min=0, multiplicity_j2_max=15,
+     response_function_stretch_min=0.5,
+     response_function_stretch_max=2.0,
+     response_function_noise=0.,
+     spectrum_noise_min=0.,
+     spectrum_noise_max=1/64,
+     include_spectrum_data=False,
+     include_peak_mask=False,
+     include_response_function=False,
+     flip_response_function=False
+ ):
+     # default: singlets only
+     if atom_groups_data_file is None:
+         atom_groups_data = np.ones((1, 3), dtype=int)
+     else:
+         atom_groups_data = np.loadtxt(atom_groups_data_file, usecols=(1, 2, 3), dtype=int)
+     response_function_library = ResponseLibrary(response_functions_files)
+     g = generator(
+         theoretical_generator_params=dict(
+             atom_groups_data=atom_groups_data,
+             pixels=pixels, frq_step=frq_step,
+             number_of_signals_min=number_of_signals_min, number_of_signals_max=number_of_signals_max,
+             spectrum_width_min=spectrum_width_min, spectrum_width_max=spectrum_width_max,
+             relative_width_min=relative_width_min, relative_width_max=relative_width_max,
+             relative_height_min=relative_height_min, relative_height_max=relative_height_max,
+             relative_frequency_min=relative_frequency_min, relative_frequency_max=relative_frequency_max,
+             thf_min=thf_min, thf_max=thf_max,
+             trf_min=trf_min, trf_max=trf_max,
+             multiplicity_j1_min=multiplicity_j1_min, multiplicity_j1_max=multiplicity_j1_max,
+             multiplicity_j2_min=multiplicity_j2_min, multiplicity_j2_max=multiplicity_j2_max
+         ),
+         response_function_library=response_function_library,
+         response_function_stretch_min=response_function_stretch_min,
+         response_function_stretch_max=response_function_stretch_max,
+         response_function_noise=response_function_noise,
+         spectrum_noise_min=spectrum_noise_min,
+         spectrum_noise_max=spectrum_noise_max,
+         include_spectrum_data=include_spectrum_data,
+         include_peak_mask=include_peak_mask,
+         include_response_function=include_response_function,
+         flip_response_function=flip_response_function
+     )
+
+     pipe = torchdata.datapipes.iter.IterableWrapper(g, deepcopy=False)
+     pipe = pipe.batch(batch_size)
+     pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)
+
+     return pipe
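`double_multiplicity` in `src/generators.py` builds a multiplet as the outer product of two Pascal-triangle patterns: shifts add, intensities multiply. A torch-free sketch of the same combinatorics (made-up `J` values, e.g. a doublet of triplets):

```python
from itertools import product

# unnormalized Pascal-triangle intensities for 1..4 lines
pascal = {1: [1], 2: [1, 1], 3: [1, 2, 1], 4: [1, 3, 3, 1]}

def multiplet(multiplicity, j):
    """Line offsets (in units of j) and normalized intensities of one n-plet."""
    intensities = pascal[multiplicity]
    total = sum(intensities)
    n = len(intensities)
    shifts = [(i - (n - 1) / 2) * j for i in range(n)]
    return shifts, [x / total for x in intensities]

def double_multiplet(m1, m2, j1, j2):
    """Combine two splittings: shifts add, intensities multiply."""
    s1, i1 = multiplet(m1, j1)
    s2, i2 = multiplet(m2, j2)
    shifts = [a + b for a, b in product(s1, s2)]
    intensities = [a * b for a, b in product(i1, i2)]
    return shifts, intensities

# doublet (J1 = 8 Hz) of triplets (J2 = 2 Hz): 2 * 3 = 6 lines
shifts, intensities = double_multiplet(2, 3, 8.0, 2.0)
```

Because both factors are normalized, the combined intensities still sum to 1, which is why `generate_multiplet_parameters` can scale them by a single overall height.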
src/models.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+
+ class ConvEncoder(torch.nn.Module):
+     def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
+         super().__init__()
+         if output_dim is None:
+             output_dim = hidden_dim
+         self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
+         self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
+         self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
+         self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
+         self.dropout = torch.nn.Dropout(dropout)
+
+     def forward(self, feature):  # (samples, 1, 2048)
+         feature = self.dropout(self.conv4(feature))  # (samples, 64, 2042)
+         feature = feature.relu()
+         feature = self.dropout(self.conv3(feature))  # (samples, 64, 2036)
+         feature = feature.relu()
+         feature = self.dropout(self.conv2(feature))  # (samples, 64, 2030)
+         feature = feature.relu()
+         feature = self.dropout(self.conv1(feature))  # (samples, 64, 2024)
+         return feature
+
+ class ConvDecoder(torch.nn.Module):
+     def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
+         super().__init__()
+         if output_dim is None:
+             output_dim = hidden_dim
+         self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
+         self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
+         self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
+         self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)
+
+     def forward(self, feature):  # (samples, 128, 2024)
+         feature = self.convTranspose1(feature)  # (samples, 64, 2030)
+         feature = feature.relu()
+         feature = self.convTranspose2(feature)  # (samples, 64, 2036)
+         feature = feature.relu()
+         feature = self.convTranspose3(feature)  # (samples, 64, 2042)
+         feature = feature.relu()
+         feature = self.convTranspose4(feature)  # (samples, 1, 2048)
+         return feature
+
+ class ResponseHead(torch.nn.Module):
+     def __init__(self, input_dim, output_length, hidden_dims=[128]):
+         super().__init__()
+         response_head_dims = [input_dim] + hidden_dims + [output_length]
+         response_head_layers = [torch.nn.Linear(response_head_dims[0], response_head_dims[1])]
+         for dims_in, dims_out in zip(response_head_dims[1:-1], response_head_dims[2:]):
+             response_head_layers.extend([
+                 torch.nn.GELU(),
+                 torch.nn.Linear(dims_in, dims_out)
+             ])
+         self.response_head = torch.nn.Sequential(*response_head_layers)
+
+     def forward(self, feature):
+         return self.response_head(feature)
+
+ class ShimNetWithSCRF(torch.nn.Module):
+     def __init__(self,
+                  encoder_hidden_dims=64,
+                  encoder_dropout=0,
+                  bottleneck_dim=64,
+                  rensponse_length=61,
+                  resnponse_head_dims=[128],
+                  decoder_hidden_dims=64
+                  ):
+         super().__init__()
+         self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
+         self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
+         torch.nn.init.xavier_normal_(self.query)
+
+         self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims)
+
+         self.rensponse_length = rensponse_length
+         self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)
+
+     def forward(self, feature):  # (samples, 1, 2048)
+         feature = self.encoder(feature)  # (samples, 64, 2024)
+         energy = self.query @ feature  # (samples, 1, 2024)
+         weight = torch.nn.functional.softmax(energy, 2)  # (samples, 1, 2024)
+         global_features = feature @ weight.transpose(1, 2)  # (samples, 64, 1)
+
+         response = self.response_head(global_features.squeeze(-1))
+
+         feature, global_features = torch.broadcast_tensors(feature, global_features)  # (samples, 64, 2024)
+         feature = torch.cat([feature, global_features], 1)  # (samples, 128, 2024)
+         denoised_spectrum = self.decoder(feature)  # (samples, 1, 2048)
+
+         return {
+             'denoised': denoised_spectrum,
+             'response': response,
+             'attention': weight.squeeze(1)
+         }
+
+ class Predictor:
+     def __init__(self, model=None, weights_file=None):
+         self.model = model
+         if weights_file is not None:
+             self.model.load_state_dict(torch.load(weights_file, map_location='cpu', weights_only=True))
+
+     def __call__(self, nsf_frq):
+         with torch.no_grad():
+             msf_frq = self.model(nsf_frq[None, None])["denoised"]
+         return msf_frq[0, 0]
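`ShimNetWithSCRF` pools its convolutional features with a single learned query: a softmax over positions gives attention weights, and the weighted sum of features is the global feature vector fed to the response head. A numpy sketch of that pooling step (random stand-ins for the learned query and the encoded features):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, length = 64, 2024                  # bottleneck_dim and encoded sequence length
query = rng.normal(size=(1, dim))       # stands in for the learned query parameter
feature = rng.normal(size=(dim, length))

energy = query @ feature                         # (1, length) attention logits
energy = energy - energy.max()                   # shift for numerical stability
weight = np.exp(energy) / np.exp(energy).sum()   # softmax over positions
global_features = feature @ weight.T             # (dim, 1) attention-pooled features
```

The same attention weights are exposed as the `'attention'` output, which is what `train.py` plots during evaluation.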
train.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torchaudio
2
+ import numpy as np
3
+ from pathlib import Path
4
+ from omegaconf import OmegaConf
5
+ from hydra.utils import instantiate
6
+ import datetime
7
+ import sys
8
+ import matplotlib.pyplot as plt
9
+
10
+
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+
14
+ # silent deprecation_warning() from datapipes
15
+ import warnings
16
+ warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
17
+
18
+ from src import models
19
+ from src.generators import get_datapipe
20
+
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ if len(sys.argv) < 2:
23
+ print("Please provide the run directory as an argument.")
24
+ sys.exit(1)
25
+ run_dir = Path(sys.argv[1])
26
+
27
+ config = OmegaConf.load(run_dir / "config.yaml")
28
+
29
+ if (run_dir / "train.txt").is_file():
30
+ minimum = np.min(np.loadtxt(run_dir / "train.txt")[:,2])
31
+ else:
32
+ minimum = float("inf")
33
+
34
+ # initialization
35
+ model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
36
+ model_weights_file = run_dir / f'model.pt'
37
+ optimizer = torch.optim.Adam(model.parameters())
38
+ optimizer_weights_file = run_dir / f'optimizer.pt'
39
+
40
+ def evaluate_model(stage=0, epoch=0):
41
+ plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
42
+ plot_dir.mkdir(exist_ok=True, parents=True)
43
+
44
+ torch.save(model.state_dict(), plot_dir / "model.pt")
45
+ torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
46
+
47
+ num_plots = config.logging.num_plots
48
+ pipe = get_datapipe(
49
+ **config.data,
50
+ include_response_function=True,
51
+ batch_size=num_plots
52
+ )
53
+ batch = next(iter(pipe))
54
+
55
+ with torch.no_grad():
56
+ out = model(batch['noised_spectrum'].to(device))
57
+ noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu()
58
+
59
+ for i in range(num_plots):
60
+ plt.figure(figsize=(30,6))
61
+ plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i,0])
62
+ plt.plot(out['denoised'].cpu().numpy()[i,0])
63
+ plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")
64
+
65
+ plt.figure(figsize=(30,6))
66
+ plt.plot(batch['noised_spectrum'].cpu().numpy()[i,0])
67
+ plt.plot(noised_est.cpu().numpy()[i,0])
68
+ plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")
69
+
70
+ plt.figure(figsize=(10,6))
71
+ plt.plot(batch['response_function'].cpu().numpy()[i,0,0])
72
+ plt.plot(out['response'].cpu().numpy()[i])
73
+ plt.savefig(plot_dir / f"{i:03d}_response.png")
74
+
75
+ if "attention" in out:
76
+ plt.figure(figsize=(10, 6))
77
+ plt.plot(out['attention'].cpu().numpy()[i])
78
+ plt.savefig(plot_dir / f"{i:03d}_attention.png")
79
+
80
+ plt.close("all")
81
+
82
+ for i_stage, training_stage in enumerate(config.training):
+     if model_weights_file.is_file():
+         model.load_state_dict(torch.load(model_weights_file, weights_only=True))
+ 
+     if optimizer_weights_file.is_file():
+         optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
+     optimizer.param_groups[0]['lr'] = training_stage.learning_rate
+ 
+     pipe = get_datapipe(
+         **config.data,
+         include_response_function=True,
+         batch_size=training_stage.batch_size
+     )
+ 
+     losses_history = []
+     losses_history_limit = 64 * 100 // training_stage.batch_size
+ 
+     last_evaluation = 0
+     for epoch, batch in pipe.enumerate():
+ 
+         # logging
+         iters_done = epoch * training_stage.batch_size
+         if (iters_done - last_evaluation) > config.logging.step:
+             evaluate_model(i_stage, epoch)
+             last_evaluation = iters_done
+ 
+         if iters_done > training_stage.max_iters:
+             evaluate_model(i_stage, epoch)
+             break
+ 
+         # run model
+         out = model(batch['noised_spectrum'].to(device))
+         # calculate losses
+         loss_response = torch.nn.functional.mse_loss(out['response'], batch['response_function'].squeeze(dim=(1, 2)).to(device))
+         loss_clean = torch.nn.functional.mse_loss(out['denoised'], batch['theoretical_spectrum'].to(device))
+         noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same")
+         loss_noised = torch.nn.functional.mse_loss(noised_est, batch['noised_spectrum'].to(device))
+         loss = config.losses_weights.response * loss_response + config.losses_weights.clean * loss_clean + config.losses_weights.noised * loss_noised
+ 
+         # logging
+         losses_history.append(loss_clean.item())
+         losses_history = losses_history[-losses_history_limit:]
+         loss_avg = sum(losses_history) / len(losses_history)
+         message = f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} {loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}"
+         with open(run_dir / 'train.txt', 'a') as f:
+             f.write(message + '\n')
+         print(message, flush=True)
+ 
+         # save best (running average of the clean-spectrum loss)
+         if loss_avg < minimum:
+             minimum = loss_avg
+             torch.save(model.state_dict(), model_weights_file)
+             torch.save(optimizer.state_dict(), optimizer_weights_file)
+ 
+         # update weights
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
+         optimizer.step()
+         optimizer.zero_grad()
weights/shimnet_600MHz.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8bac4bbff78b4d898c07ac0daab30203ead33a896d5803060593746bbb15dd10
+ size 889765
weights/shimnet_700MHz.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b42c9a2dfd4cb3ed0ce3aa328ea09cce03b9b3713e92b728229f24fff3781835
+ size 880746