Marek Bukowicki committed
Commit 5b32793 · 1 Parent(s): 0fef3e5

1st commit
.gitignore ADDED
@@ -0,0 +1,11 @@
+# python's temps
+**/.ipynb_checkpoints/
+**/__pycache__/
+
+# training outputs
+runs/
+# data files
+data/
+# typically weights and data
+*.pt
+
Readme.md ADDED
@@ -0,0 +1,176 @@
+# ShimNet
+ShimNet is a data-driven AI solution that improves high-resolution nuclear magnetic resonance (NMR) spectra
+distorted by an inhomogeneous magnetic field (less than optimal shimming). To use it, experimental training data has to be collected (see **Training data collection** below).
+Example data can also be downloaded (see below).
+
+Paper: ...
+
+## Installation
+
+Python 3.9+
+
+GPU version (for training and inference)
+```
+pip install -r requirements.txt
+```
+
+CPU version (for inference, not recommended for training)
+```
+pip install -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+```
+
+## Usage
+To correct the spectra presented in the paper:
+1. download the weights (model parameters):
+```
+python download_files.py
+```
+or directly from [Google Drive 700MHz](https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke) and [Google Drive 600MHz](https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N), and place them in the `weights` directory
+
+
+2. run the correction (e.g. on `Azarone_20ul_700MHz.csv`):
+```
+python predict.py sample_data/Azarone_20ul_700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
+```
+The output will be the `output/Azarone_20ul_700MHz_processed.csv` file.
+
+Multiple files may be processed using the "*" glob syntax:
+```
+python predict.py sample_data/*700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
+```
+
+For 600 MHz data use `--config configs/shimnet_600.yaml` and `--weights weights/shimnet_600MHz.pt`, e.g.:
+
+```
+python predict.py sample_data/CresolRed_after_styrene_600MHz.csv -o output --config configs/shimnet_600.yaml --weights weights/shimnet_600MHz.pt
+```
+
+### Input format
+
+The spectrum file for reconstruction should contain two columns separated by a space, with no newline character after the last line of the file (example below):
+```csv
+-1.97134 0.0167137
+-1.97085 -0.00778748
+-1.97036 -0.0109595
+-1.96988 0.00825978
+-1.96939 0.0133886
+```
+
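A file in this format can be produced from NumPy arrays with a plain text write. A minimal sketch (the axis and intensity values are synthetic, and `example_spectrum.csv` is a hypothetical file name):

```python
import numpy as np

# synthetic frequency axis (ppm) and intensities -- illustration only
ppm = np.linspace(-2.0, 10.0, 8)
intensity = np.random.default_rng(0).normal(size=8)

# two space-separated columns, with no newline after the last line
lines = [f"{p:.5f} {y:.6g}" for p, y in zip(ppm, intensity)]
with open("example_spectrum.csv", "w") as fh:
    fh.write("\n".join(lines))

# the file loads back as an (N, 2) array, as predict.py expects
loaded = np.loadtxt("example_spectrum.csv")
```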
+## Train on your data
+
+For the model to function properly, it should be trained on calibration data from the spectrometer used for the measurements. To train a model on data from your spectrometer, please follow the instructions below.
+
+### Training data collection
+
+Below we describe training data collection for Agilent/Varian spectrometers. A similar procedure can be implemented for machines from other vendors.
+To collect ShimNet training data, use the Python script (`sweep_shims_lineshape_Z1Z2.py`) from the `calibration_loop` folder to drive the spectrometer:
+1. Install the TReNDS package ( trends.spektrino.com )
+2. Open VnmrJ and type: 'listenon'
+3. Put in the lineshape sample (1% CHCl3 in deuterated acetone), set standard PROTON parameters, and set nt=1 (do not modify sw and at!)
+4. Shim the sample and collect the data. Save the optimally shimmed dataset
+5. Edit the sweep_shims_lineshape_Z1Z2.py script
+6. Put the optimum z1 and z2 shim values into optiz1 and optiz2
+7. Define the calibration range as range_z1 and range_z2 (the default is OK)
+8. Start the Python script:
+```
+python3 ./sweep_shims_lineshape_Z1Z2.py
+```
+The spectrometer will start collecting spectra.
+
+### SCRF extraction
+Shim Coil Response Functions (SCRFs) should be extracted from the spectra with the `extract_scrf_from_fids.py` script:
+```
+python extract_scrf_from_fids.py
+```
+
+The script uses hardcoded paths to the NMR signals (FIDs) in Agilent/Varian format: a directory with the optimal measurement (`opti_fid_path`) and a directory with the calibration loop measurements (`data_dir`):
+```python
+# input
+data_dir = "../../sample_run/loop"
+opti_fid_path = "../../sample_run/opti.fid"
+```
+
+The output files are also hardcoded:
+```python
+# output
+spectra_file = "../../sample_run/total.npy"
+spectra_file_names = "../../sample_run/total.csv"
+opi_spectrum_file = "../../sample_run/opti.npy"
+responses_file = "../../sample_run/scrf_61.pt"
+```
+where only the `responses_file` is used in ShimNet training.
+
+If the measurements are stored in a format other than Varian, you may need to change this line:
+```python
+dic, data = ng.varian.read(varian_fid_path)
+```
+(see the nmrglue package documentation for details)
+
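Conceptually, each SCRF is a convolution kernel that maps the optimally shimmed spectrum onto a distorted one; the script fits it with gradient descent in PyTorch. As a minimal NumPy illustration of the same linear problem (synthetic data, and least squares in place of the script's Adam optimizer):

```python
import numpy as np

def fit_kernel_lstsq(base, target, kernel_size):
    """Recover k such that convolve(base, k, 'same') ~= target."""
    # convolution is linear in the kernel: build the operator column by column
    A = np.column_stack([
        np.convolve(base, np.eye(kernel_size)[j], mode="same")
        for j in range(kernel_size)
    ])
    kernel, *_ = np.linalg.lstsq(A, target, rcond=None)
    return kernel

rng = np.random.default_rng(0)
base = rng.normal(size=256)                           # "optimal" spectrum (synthetic)
true_kernel = np.array([0.1, 0.2, 0.4, 0.2, 0.1])     # hypothetical response function
target = np.convolve(base, true_kernel, mode="same")  # "distorted" spectrum
estimated = fit_kernel_lstsq(base, target, kernel_size=5)
```

Because convolution is linear in the kernel, the noiseless synthetic case is solved exactly; real calibration spectra are noisy, which is why the repository fits iteratively instead.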
112
+
113
+ 1. Download multiplets database:
114
+ ```
115
+ python download_files.py --multiplets
116
+ ```
117
+ 2. Configure run:
118
+ - create a run directory, e.g. `runs/my_lab_spectrometer_2025`
119
+ - create a configuration file:
120
+ 1. copy `configs/shimnet_template.py` to the run directory and rename it to `config.yaml`
121
+ ```bash
122
+ cp configs/shimnet_template.py runs/my_lab_spectrometer_2025/config.yaml
123
+ ```
124
+ 2. edit the SCRF in path in the config file:
125
+ ```yaml
126
+ response_functions_files:
127
+ - path/to/srcf_file
128
+ ```
129
+ e.g.
130
+ ```yaml
131
+ response_functions_files:
132
+ - ../../sample_run/scrf_61.pt
133
+ ```
134
+ 3. adjust spectrometer frequency step `frq_step` to match your data (spectrometer range in Hz divided by number of points in spectrum):
135
+ ```yaml
136
+ frq_step: 0.34059797
137
+ ```
138
+ 4. adjust spectromer frequency in the metadata
139
+ ```yaml
140
+ metadata: # additional metadata, not used in the training process
141
+ spectrometer_frequency: 700.0 # MHz
142
+ ```
143
+ 3. Run training:
144
+ ```
145
+ python train.py runs/my_lab_spectrometer_2025
146
+ ```
147
+ Training results will appear in `runs/my_lab_spectrometer_2025` directory.
148
+ Model parameters are stored in `runs/my_lab_spectrometer_2025/model.pt` file
149
+ 4. Use trained model:
150
+
151
+ use `--config runs/my_lab_spectrometer_2025/config.yaml` and `--weights runs/my_lab_spectrometer_2025/model.pt` flags, e.g.
152
+ ```
153
+ python predict.py my_sample1.csv -o my_output --config runs/my_lab_spectrometer_2025/config.yaml --weights runs/my_lab_spectrometer_2025/model.pt
154
+ ```
155
+
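The `frq_step` computation is a single division. For a hypothetical acquisition with an 11160.71 Hz spectral width and 32768 points (both values are assumptions for illustration):

```python
spectral_width_hz = 11160.71   # sw of the acquisition, assumed for illustration
n_points = 32768               # number of points in the spectrum, assumed

frq_step = spectral_width_hz / n_points
print(f"frq_step: {frq_step:.8f} Hz per point")
```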
+
+
+## Repeat training on our data
+
+If you want to train the network using the calibration data from our paper, follow the procedure below.
+
+1. Download the multiplets database and our SCRF file:
+```
+python download_files.py --multiplets --SCRF --no-weights
+```
+2. Configure the run:
+```bash
+mkdir -p runs/repeat_paper_training
+cp configs/shimnet_700.yaml runs/repeat_paper_training/config.yaml
+```
+3. Run training:
+```
+python train.py runs/repeat_paper_training
+```
+Training results will appear in the `runs/repeat_paper_training` directory.
+
calibration_loop/maclib/shimator_loop ADDED
@@ -0,0 +1,33 @@
+" macro shimnet loop "
+" arguments: optimal z1, optimal z2, step in z1 loop, step in z2 loop, half number of steps in z1, half number of steps in z2 "
+
+$svfdir='/home/nmrbox/kkazimierczuk/shimnet/data/'
+
+$opti_z1=$1
+$opti_z2=$2
+$step_z1=$3
+$step_z2=$4
+$steps_z1=$5
+$steps_z2=$6
+$nn=''
+$bb=''
+$msg=''
+format((($steps_z1*2+1)*($steps_z2*2+1)*(d1+at)*nt)/3600,5,1):$msg
+$msg='Time: '+$msg+' h'
+banner($msg)
+$j=0
+repeat
+  $i=0
+  $z2=$opti_z2-$steps_z2*$step_z2+$j*$step_z2
+  format($z2,5,1):$bb
+  repeat
+    $z1=$opti_z1-$steps_z1*$step_z1+$i*$step_z1
+    "su"
+    "go"
+    format($z1,5,1):$nn
+    $filepath=$svfdir + 'z1_'+$nn+'z2_'+$bb+'.fid'
+    svf($filepath,'force')
+    $i=$i+1
+  until $z1>$opti_z1+$steps_z1*$step_z1-1
+  $j=$j+1
+until $z2>$opti_z2+$steps_z2*$step_z2-1
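The macro's banner estimates the total loop duration as (number of spectra) × (d1 + at) × nt / 3600 hours. A Python sketch of the same arithmetic (all parameter values below are illustrative assumptions, not defaults from the repository):

```python
steps_z1 = 50   # half number of steps in z1 (assumed)
steps_z2 = 50   # half number of steps in z2 (assumed)
d1 = 1.0        # relaxation delay in seconds (assumed)
at = 2.0        # acquisition time in seconds (assumed)
nt = 1          # number of scans

# the loop visits every (z1, z2) pair on a (2*steps+1)-point grid per axis
n_spectra = (2 * steps_z1 + 1) * (2 * steps_z2 + 1)
hours = n_spectra * (d1 + at) * nt / 3600
print(f"{n_spectra} spectra, about {hours:.1f} h")
```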
calibration_loop/sweep_shims_lineshape_Z1Z2.py ADDED
@@ -0,0 +1,39 @@
+from TReND.send2vnmr import *
+import shutil
+import time
+import os
+import numpy as np
+from subprocess import call
+
+# This is the script to collect ShimNet training data on Agilent spectrometers. Execution:
+# 1. Open VnmrJ and type: 'listenon'
+# 2. Put in the lineshape sample, set standard PROTON parameters and set one scan (do not modify sw and at!)
+# 3. Shim the sample and collect the data. Save the optimally shimmed dataset
+# 4. Put optimum z1 and z2 shim values as optiz1 and optiz2 below
+# 5. Define the calibration range as range_z1 and range_z2 (default is ok)
+# 6. Start the python script:
+#    python3 ./sweep_shims_lineshape_Z1Z2.py
+#    the spectrometer will start collecting spectra
+
+Classic_path = '/home/nmr700/shimnet6/lshp'
+optiz1 = 8868   # put optimum shim values here
+optiz2 = -297
+
+range_z1 = 100  # put shim sweep ranges here
+range_z2 = 100
+
+z1_sweep = np.arange(optiz1 - range_z1, optiz1 + range_z1 + 1, 2.0)
+z2_sweep = np.arange(optiz2 - range_z2, optiz2 + range_z2 + 1, 2.0)
+
+for i in range(1, np.shape(z1_sweep)[0] + 1, 1):
+    for j in range(1, np.shape(z2_sweep)[0] + 1, 1):
+        wait_until_idle()
+
+        Run_macro("sethw('z1'," + str(z1_sweep[i-1]) + ")")
+        Run_macro("sethw('z2'," + str(z2_sweep[j-1]) + ")")
+
+        go_if_idle()
+        wait_until_idle()
+        time.sleep(0.5)
+
+        Save_experiment(Classic_path + '_z1_' + str(int(z1_sweep[i-1])) + '_z2_' + str(int(z2_sweep[j-1])) + '.fid')
configs/shimnet_600.yaml ADDED
@@ -0,0 +1,46 @@
+model:
+  name: ShimNetWithSCRF
+  kwargs:
+    rensponse_length: 81
+    resnponse_head_dims:
+    - 128
+training:
+- batch_size: 64
+  learning_rate: 0.001
+  max_iters: 1600000
+- batch_size: 512
+  learning_rate: 0.001
+  max_iters: 25600000
+- batch_size: 512
+  learning_rate: 0.0005
+  max_iters: 12800000
+losses_weights:
+  clean: 1.0
+  noised: 1.0
+  response: 1.0
+data:
+  response_functions_files:
+  # Paste the path to your SCRF file here:
+  # - it can be an absolute path
+  # - it can be relative to the repository root
+  - data/600hz-2-Lnative/scrf_81.pt
+  atom_groups_data_file: data/multiplets_10000_parsed.txt
+  response_function_stretch_min: 1.0
+  response_function_stretch_max: 1.0
+  response_function_noise: 0.0
+  multiplicity_j1_min: 0.0
+  multiplicity_j1_max: 15
+  multiplicity_j2_min: 0.0
+  multiplicity_j2_max: 15
+  number_of_signals_min: 2
+  number_of_signals_max: 5
+  thf_min: 0.5
+  thf_max: 2
+  relative_height_min: 0.5
+  relative_height_max: 4
+  frq_step: 0.30048
+logging:
+  step: 1000000
+  num_plots: 32
+metadata:
+  spectrometer_frequency: 600.0
configs/shimnet_700.yaml ADDED
@@ -0,0 +1,46 @@
+model:
+  name: ShimNetWithSCRF
+  kwargs:
+    rensponse_length: 61
+    resnponse_head_dims:
+    - 128
+training:
+- batch_size: 64
+  learning_rate: 0.001
+  max_iters: 1600000
+- batch_size: 512
+  learning_rate: 0.001
+  max_iters: 6400000
+- batch_size: 512
+  learning_rate: 0.0005
+  max_iters: 12800000
+- batch_size: 800
+  learning_rate: 0.0005
+  max_iters: 12800000
+losses_weights:
+  clean: 1.0
+  noised: 1.0
+  response: 1.0
+data:
+  response_functions_files:
+  - data/scrf_61_700MHz.pt
+  atom_groups_data_file: data/multiplets_10000_parsed.txt
+  response_function_stretch_min: 1.0
+  response_function_stretch_max: 1.0
+  response_function_noise: 0.0
+  multiplicity_j1_min: 0.0
+  multiplicity_j1_max: 15
+  multiplicity_j2_min: 0.0
+  multiplicity_j2_max: 15
+  number_of_signals_min: 2
+  number_of_signals_max: 5
+  thf_min: 0.5
+  thf_max: 2
+  relative_height_min: 0.5
+  relative_height_max: 4
+  frq_step: 0.34059797
+logging:
+  step: 1000000
+  num_plots: 32
+metadata:
+  spectrometer_frequency: 700.0
configs/shimnet_template.yaml ADDED
@@ -0,0 +1,50 @@
+model:
+  name: ShimNetWithSCRF
+  kwargs:
+    rensponse_length: 61
+    resnponse_head_dims:
+    - 128
+training:
+- batch_size: 64
+  learning_rate: 0.001
+  max_iters: 1600000
+- batch_size: 512
+  learning_rate: 0.001
+  max_iters: 6400000
+- batch_size: 512
+  learning_rate: 0.0005
+  max_iters: 12800000
+- batch_size: 800
+  learning_rate: 0.0005
+  max_iters: 12800000
+losses_weights:
+  clean: 1.0
+  noised: 1.0
+  response: 1.0
+data:
+  response_functions_files:
+  # Specify the path to your SCRF file/files here.
+  # - It can be an absolute path.
+  # - It can be relative to the repository root.
+  # Multiple files can be listed in the following rows (YAML list format).
+  - path/to/scrf_file
+  atom_groups_data_file: data/multiplets_10000_parsed.txt
+  response_function_stretch_min: 1.0 # relative
+  response_function_stretch_max: 1.0 # relative
+  response_function_noise: 0.0 # arbitrary height units
+  multiplicity_j1_min: 0.0 # Hz
+  multiplicity_j1_max: 15 # Hz
+  multiplicity_j2_min: 0.0 # Hz
+  multiplicity_j2_max: 15 # Hz
+  number_of_signals_min: 2
+  number_of_signals_max: 5
+  thf_min: 0.5 # arbitrary height units
+  thf_max: 2 # arbitrary height units
+  relative_height_min: 0.5 # relative
+  relative_height_max: 4 # relative
+  frq_step: 0.34059797 # Hz per point
+logging:
+  step: 1000000
+  num_plots: 32
+metadata: # additional metadata, not used in the training process
+  spectrometer_frequency: 700.0 # MHz
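The repository loads these configs with OmegaConf (see `predict.py`). As a plain illustration of how two of the fields combine, the same lookup with standard YAML parsing (the snippet reproduces only a fragment of the template):

```python
import yaml

snippet = """
data:
  frq_step: 0.34059797   # Hz per point
metadata:
  spectrometer_frequency: 700.0   # MHz
"""
cfg = yaml.safe_load(snippet)

# predict.py converts the frequency step to ppm per point this way
ppm_per_point = cfg["data"]["frq_step"] / cfg["metadata"]["spectrometer_frequency"]
```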
download_files.py ADDED
@@ -0,0 +1,90 @@
+import urllib.request
+from pathlib import Path
+import argparse
+
+ALL_FILES_TO_DOWNLOAD = {
+    "weights": [{
+        "url": "https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke",
+        "destination": "weights/shimnet_700MHz.pt"
+    },
+    {
+        "url": "https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N",
+        "destination": "weights/shimnet_600MHz.pt"
+    }],
+    "SCRF": [{
+        "url": "https://drive.google.com/uc?export=download&id=113al7A__yYALx_2hkESuzFIDU3feVtNY",
+        "destination": "data/scrf_61_700MHz.pt"
+    }],
+    "multiplets": [{
+        "url": "https://drive.google.com/uc?export=download&id=1QGvV-Au50ZxaP1vFsmR_auI299Dw-Wrt",
+        "destination": "data/multiplets_10000_parsed.txt"
+    }],
+    "development": []
+}
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Download files: weights (default), SCRF (optional), multiplet data (optional)',
+    )
+    parser.add_argument(
+        '--weights',
+        action='store_true',
+        default=True,
+        help='Download weights file (default behavior). Use --no-weights to opt out.',
+    )
+    parser.add_argument(
+        '--no-weights',
+        action='store_false',
+        dest='weights',
+        help='Do not download weights file.',
+    )
+    parser.add_argument('--SCRF', action='store_true', help='Download SCRF file')
+    parser.add_argument('--multiplets', action='store_true', help='Download multiplets data file')
+    parser.add_argument('--development', action='store_true', help='Download development weights file')
+
+    parser.add_argument('--all', action='store_true', help='Download all available files')
+
+    args = parser.parse_args()
+    # Set all individual flags if --all is specified
+    if args.all:
+        args.weights = True
+        args.SCRF = True
+        args.multiplets = True
+        args.development = True
+
+    return args
+
+def download_file(url, target):
+    target = Path(target)
+    if target.exists():
+        response = input(f"File {target} already exists. Overwrite? (y/n): ")
+        if response.lower() != 'y':
+            print(f"Download of {target} cancelled")
+            return
+    target.parent.mkdir(parents=True, exist_ok=True)
+    try:
+        urllib.request.urlretrieve(url, target)
+        print(f"Downloaded {target}")
+    except Exception as e:
+        print(f"Failed to download file from {url}:\n {e}")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    main_dir = Path(__file__).parent
+    if args.weights:
+        for file_data in ALL_FILES_TO_DOWNLOAD["weights"]:
+            download_file(file_data["url"], main_dir / file_data["destination"])
+
+    if args.SCRF:
+        for file_data in ALL_FILES_TO_DOWNLOAD["SCRF"]:
+            download_file(file_data["url"], main_dir / file_data["destination"])
+
+    if args.multiplets:
+        for file_data in ALL_FILES_TO_DOWNLOAD["multiplets"]:
+            download_file(file_data["url"], main_dir / file_data["destination"])
+
+    if args.development:
+        for file_data in ALL_FILES_TO_DOWNLOAD["development"]:
+            download_file(file_data["url"], main_dir / file_data["destination"])
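The `--weights`/`--no-weights` pair above is built from two mirrored arguments sharing `dest='weights'`. On Python 3.9+ (the version the README requires), the same behavior could also be expressed with a single `argparse.BooleanOptionalAction` declaration; this is a possible simplification, not what the repository does:

```python
import argparse

parser = argparse.ArgumentParser()
# one declaration generates both --weights and --no-weights
parser.add_argument("--weights", action=argparse.BooleanOptionalAction, default=True)

print(parser.parse_args([]).weights)
print(parser.parse_args(["--no-weights"]).weights)
```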
extract_scrf_from_fids.py ADDED
@@ -0,0 +1,132 @@
+import os
+import shutil
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import nmrglue as ng
+from tqdm import tqdm
+
+# input
+data_dir = "../../../dane/600Hz_20241227/loop"
+opti_fid_path = "../../../dane/600Hz_20241227/opti"
+
+# output
+spectra_file = "data/600Hz_20241227-Lnative/total.npy"
+spectra_file_names = "data/600Hz_20241227-Lnative/total.csv"
+opi_spectrum_file = "data/600Hz_20241227-Lnative/opti.npy"
+responses_file = "data/600Hz_20241227-Lnative/scrf_61.pt"
+losses_file = "data/600Hz_20241227-Lnative/losses_scrf_61.pt"
+
+# create directories for output files
+for file_path in [spectra_file, spectra_file_names, opi_spectrum_file, responses_file]:
+    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
+
+# settings: fid to spectrum
+ph0_correction = -190.49
+ph1_correction = 0
+autophase_fn = "acme"
+target_length = None
+
+# settings: SCRF extraction
+calibration_peak_center = "auto"  # or an index, e.g. 12277; the center of the calibration peak
+calibration_window_halfwidth = 128  # the half width of the calibration window
+steps = 6000
+kernel_size = 61
+kernel_sqrt = False  # True to allow only positive values
+
+# fid to spectra processing
+
+def fid_to_spectrum(varian_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=None, sin_pod=False):
+    dic, data = ng.varian.read(varian_fid_path)
+    data[0] *= 0.5
+    if sin_pod:
+        data = ng.proc_base.sp(data, end=0.98)
+
+    if target_length is not None:
+        if (pad_length := target_length - len(data)) > 0:
+            data = ng.proc_base.zf(data, pad_length)
+        else:
+            data = data[:target_length]
+
+    spec = ng.proc_base.fft(data)
+    spec = ng.process.proc_autophase.autops(spec, autophase_fn, p0=ph0_correction, p1=ph1_correction, disp=False)
+
+    return spec
+
+# process optimal measurement fid to spectrum
+opti_spectrum_full = fid_to_spectrum(opti_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)
+
+if calibration_peak_center == "auto":
+    calibration_peak_center = np.argmax(abs(opti_spectrum_full))
+fitting_range = (calibration_peak_center - calibration_window_halfwidth, calibration_peak_center + calibration_window_halfwidth + 1)
+
+opti_spectrum = opti_spectrum_full[fitting_range[0]:fitting_range[1]]
+np.save(opi_spectrum_file, opti_spectrum)
+print(f"Optimal spectrum extracted to {opi_spectrum_file}")
+
+# process loop fids to spectra
+spec_list = []
+spec_names = []
+
+print("Extracting spectra from fids...")
+for fid_path in tqdm(list(Path(data_dir).rglob('*.fid'))):
+    spec = fid_to_spectrum(fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)[fitting_range[0]:fitting_range[1]]
+
+    spec_list.append(spec)
+    spec_names.append(fid_path.name)
+
+total = np.array(spec_list)
+np.save(spectra_file, total)
+pd.DataFrame(spec_names).to_csv(spectra_file_names, header=False)
+# total = np.load(spectra_file)
+print(f"Spectra extracted to {spectra_file}")
+
+# process SCRF extraction
+def fit_kernel(base, target, kernel_size, kernel_sqrt=True, steps=20000, verbose=False):
+
+    kernel = torch.ones((1, 1, kernel_size), dtype=base.dtype)
+    if kernel_sqrt:
+        kernel /= torch.sqrt(torch.sum(kernel**2))
+    else:
+        kernel /= kernel_size
+    kernel.requires_grad = True
+
+    optimizer = torch.optim.Adam([kernel])
+
+    for epoch in range(steps):
+        if kernel_sqrt:
+            spe_est = torch.conv1d(base, kernel**2, padding='same')
+        else:
+            spe_est = torch.conv1d(base, kernel, padding='same')
+        loss = torch.mean(abs(target - spe_est)**2)  # torch.nn.functional.mse_loss(spe_est, target)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
+        if verbose and (epoch + 1) % 100 == 0:
+            print(epoch, loss.item())
+    if kernel_sqrt:
+        return kernel.detach()**2, loss.item()
+    else:
+        return kernel.detach(), loss.item()
+
+responses = torch.empty(len(total), 1, 1, 1, 1, kernel_size)
+losses = torch.empty(len(total))
+base = torch.tensor(opti_spectrum.real).unsqueeze(0)
+targets = torch.tensor(total.real)
+
+# normalization
+base /= base.sum()
+targets /= targets.sum(dim=(-1,), keepdim=True)
+
+print("\nExtracting SCRFs...")
+for i, target in tqdm(enumerate(targets), total=len(targets)):
+    kernel, loss = fit_kernel(base, target.unsqueeze(0), kernel_size, kernel_sqrt=kernel_sqrt, steps=steps)
+    responses[i, 0, 0] = kernel
+    losses[i] = loss
+
+torch.save(responses, responses_file)
+torch.save(losses, losses_file)
+print(f"SCRFs extracted to {responses_file}, losses saved to {losses_file}")
predict.py ADDED
@@ -0,0 +1,89 @@
+import torch
+torch.set_grad_enabled(False)
+import numpy as np
+import argparse
+from pathlib import Path
+import sys, os
+from omegaconf import OmegaConf
+
+from src.models import ShimNetWithSCRF, Predictor
+
+# silence deprecation warnings
+# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
+import warnings
+warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+
+class Defaults:
+    SCALE = 16.0
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_files", help="Input files", nargs="+")
+    parser.add_argument("--config", help="config file .yaml")
+    parser.add_argument("--weights", help="model weights")
+    parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
+    parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
+    args = parser.parse_args()
+    return args
+
+# functions
+def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
+    """resample the input spectrum to match the model's frequency grid"""
+    freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
+    spectrum = np.interp(freqs, input_freqs, input_spectrum)
+    return freqs, spectrum
+
+def resample_output_spectrum(input_freqs, freqs, prediction):
+    """resample the prediction to match the input spectrum's frequency grid"""
+    prediction = np.interp(input_freqs, freqs, prediction)
+    return prediction
+
+def initialize_predictor(config, weights_file):
+    model = ShimNetWithSCRF(**config.model.kwargs)
+    predictor = Predictor(model, weights_file)
+    return predictor
+
+# run
+if __name__ == "__main__":
+    args = parse_args()
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    config = OmegaConf.load(args.config)
+    model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
+    predictor = initialize_predictor(config, args.weights)
+
+    for input_file in args.input_files:
+        print(f"processing {input_file} ...")
+
+        # load data
+        input_data = np.loadtxt(input_file)
+        input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]
+
+        # convert input frequencies to the model's scale - correct for zero filling and spectrometer frequency
+        if args.input_spectrometer_frequency is not None:
+            input_freqs_model_ppm = input_freqs_input_ppm * args.input_spectrometer_frequency / config.metadata.spectrometer_frequency
+        else:
+            input_freqs_model_ppm = input_freqs_input_ppm
+
+        freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)
+
+        spectrum = torch.tensor(spectrum).float()
+        # scale height of the spectrum
+        scaling_factor = Defaults.SCALE / spectrum.max()
+        spectrum *= scaling_factor
+
+        # correct spectrum
+        prediction = predictor(spectrum).numpy()
+
+        # rescale height
+        prediction /= scaling_factor
+
+        # resample the output to match the input spectrum
+        output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)
+
+        # save result
+        output_file = output_dir / f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
+
+        np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
+        print(f"saved to {output_file}")
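The resampling pair in `predict.py` is plain linear interpolation onto the model grid and back. A self-contained sketch with synthetic data (the grid spacing and spectrum values are assumptions for illustration):

```python
import numpy as np

# synthetic input spectrum on its own frequency grid
input_freqs = np.linspace(0.0, 1.0, 11)
input_spectrum = np.sin(2 * np.pi * input_freqs)

# interpolate onto a finer, evenly spaced model grid (as resample_input_spectrum does)
step = 0.025
freqs = np.arange(input_freqs.min(), input_freqs.max(), step)
spectrum = np.interp(freqs, input_freqs, input_spectrum)

# and back onto the original grid (as resample_output_spectrum does)
restored = np.interp(input_freqs, freqs, spectrum)
```

Note that `np.arange` excludes the endpoint, so the last input point is clamped to the nearest model-grid value on the way back; interior points survive the round trip essentially unchanged.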
requirements-cpu.txt ADDED
@@ -0,0 +1,9 @@
+torch==2.4.1+cpu
+torchaudio==2.4.1+cpu
+nmrglue==0.11
+torchdata==0.9.0
+numpy==2.0.2
+matplotlib==3.9.3
+pandas==2.2.3
+tqdm==4.67.1
+hydra-core==1.3.2
requirements.txt ADDED
@@ -0,0 +1,9 @@
+torch==2.4.1
+torchaudio==2.4.1
+nmrglue==0.11
+torchdata==0.9.0
+numpy==2.0.2
+matplotlib==3.9.3
+pandas==2.2.3
+tqdm==4.67.1
+hydra-core==1.3.2
sample_data/2ethylonaphthalene_bestshims_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/2ethylonaphthalene_up_1mm_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_20ul_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_X_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_Z1Z2Z3Z4_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_Z1Z2_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_besteshims_supressed_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Azarone_bestshims_700MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/CresolRed_after_styrene_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/CresolRed_bestshims_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Geraniol_bestshims_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
sample_data/Geraniol_up_1mm_600MHz.csv ADDED
The diff for this file is too large to render. See raw diff
 
src/generators.py ADDED
@@ -0,0 +1,280 @@
1
+ import numpy as np
2
+ import torch
3
+ import torchdata
4
+ # from itertools import islice
5
+
6
+ def random_value(min_value, max_value):
7
+ return (min_value + torch.rand(1) * (max_value - min_value)).item()
8
+
9
+ def random_loguniform(min_value, max_value):
10
+ return (min_value * torch.exp(torch.rand(1) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
11
+
12
+ def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq:torch.Tensor):
13
+ # extract parameters
14
+ tff_lin = peaks_parameters["tff_lin"]
15
+ twf_lin = peaks_parameters["twf_lin"]
16
+ thf_lin = peaks_parameters["thf_lin"]
17
+ trf_lin = peaks_parameters["trf_lin"]
18
+
19
+ lwf_lin = twf_lin
20
+ lhf_lin = thf_lin * (1. - trf_lin)
21
+ gwf_lin = twf_lin
22
+ gdf_lin = gwf_lin / torch.tensor(2.).log().mul(2.).sqrt()
23
+ ghf_lin = thf_lin * trf_lin
24
+ # calculate Lorenz peaks contriubutions
25
+ lsf_linfrq = lwf_lin[:, None] ** 2 / (lwf_lin[:, None] ** 2 + (frq_frq - tff_lin[:, None]) ** 2) * lhf_lin[:, None]
26
+ # calculate Gaussian peaks contriubutions
27
+ gsf_linfrq = torch.exp(-(frq_frq - tff_lin[:, None]) ** 2 / gdf_lin[:, None] ** 2 / 2.) * ghf_lin[:, None]
28
+ tsf_linfrq = lsf_linfrq + gsf_linfrq
29
+ # sum peaks contriubutions
30
+ tsf_frq = tsf_linfrq.sum(0, keepdim = True)
31
+ return tsf_frq
+ 
+ 
+ pascal_triangle = [(1,), (1, 1), (1, 2, 1), (1, 3, 3, 1), (1, 4, 6, 4, 1), (1, 5, 10, 10, 5, 1), (1, 6, 15, 20, 15, 6, 1), (1, 7, 21, 35, 35, 21, 7, 1)]
+ normalized_pascal_triangle = [torch.tensor(x) / sum(x) for x in pascal_triangle]
+ 
+ 
+ def pascal_multiplicity(multiplicity):
+     intensities = normalized_pascal_triangle[multiplicity - 1]
+     n_peaks = len(intensities)
+     shifts = torch.arange(n_peaks) - ((n_peaks - 1) / 2)
+     return shifts, intensities
+ 
+ 
+ def double_multiplicity(multiplicity1, multiplicity2, j1=1, j2=1):
+     shifts1, intensities1 = pascal_multiplicity(multiplicity1)
+     shifts2, intensities2 = pascal_multiplicity(multiplicity2)
+ 
+     shifts = (j1 * shifts1.reshape(-1, 1) + j2 * shifts2.reshape(1, -1)).flatten()
+     intensities = (intensities1.reshape(-1, 1) * intensities2.reshape(1, -1)).flatten()
+     return shifts, intensities
+ 
+ 
+ def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_lin, j1, j2):
+     shifts, intensities = double_multiplicity(multiplicity[0], multiplicity[1], j1, j2)
+     n_peaks = len(shifts)
+ 
+     return {
+         "tff_lin": shifts + tff_lin,
+         "thf_lin": intensities * thf_lin,
+         "twf_lin": torch.full((n_peaks,), twf_lin),
+         "trf_lin": torch.full((n_peaks,), trf_lin),
+     }
+ 
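The multiplet construction above follows the first-order rule: an n-line multiplet has binomial intensities (row n-1 of Pascal's triangle) at offsets spaced by the coupling constant J, and a doubly coupled signal is the outer product of two such multiplets. A plain-Python sketch of the same arithmetic (toy J values chosen for illustration):

```python
# binomial line intensities, normalized so every multiplet has unit total intensity
pascal = [(1,), (1, 1), (1, 2, 1), (1, 3, 3, 1)]

def multiplet(multiplicity, j=1.0):
    raw = pascal[multiplicity - 1]
    total = sum(raw)
    n = len(raw)
    shifts = [j * (k - (n - 1) / 2) for k in range(n)]   # centred line positions
    intensities = [x / total for x in raw]
    return shifts, intensities

print(multiplet(3))  # triplet: ([-1.0, 0.0, 1.0], [0.25, 0.5, 0.25])

# doublet of doublets: outer product of two doublets with different couplings
s1, i1 = multiplet(2, j=2.0)
s2, i2 = multiplet(2, j=0.5)
dd = [(a + b, x * y) for a, x in zip(s1, i1) for b, y in zip(s2, i2)]
print(dd)  # four lines of equal intensity 0.25 at -1.25, -0.75, 0.75, 1.25
```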
+ 
+ def value_to_index(values, table):
+     # linearly map values onto (fractional) indices of the evenly spaced table
+     span = table[-1] - table[0]
+     indices = (values - table[0]) / span * (len(table) - 1)
+     return indices
+ 
+ 
+ def generate_theoretical_spectrum(
+     number_of_signals_min, number_of_signals_max,
+     spectrum_width_min, spectrum_width_max,
+     relative_width_min, relative_width_max,
+     tff_min, tff_max,
+     thf_min, thf_max,
+     trf_min, trf_max,
+     relative_height_min, relative_height_max,
+     multiplicity_j1_min, multiplicity_j1_max,
+     multiplicity_j2_min, multiplicity_j2_max,
+     atom_groups_data,
+     frq_frq
+ ):
+     number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max + 1, [])
+     atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals])
+     width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max)
+     height_spectrum = random_loguniform(thf_min, thf_max)
+ 
+     peak_parameters_data = []
+     theoretical_spectrum = None
+     for atom_group_index in atom_group_indices:
+         relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
+         position = random_value(tff_min, tff_max)
+         j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
+         j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
+         width = width_spectrum * random_loguniform(relative_width_min, relative_width_max)
+         height = height_spectrum * relative_intensity * random_loguniform(relative_height_min, relative_height_max)
+         gaussian_contribution = random_value(trf_min, trf_max)
+ 
+         peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin=width, trf_lin=gaussian_contribution, j1=j1, j2=j2)
+         peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
+         peak_parameters_data.append(peaks_parameters)
+         spectrum_contribution = calculate_theoretical_spectrum(peaks_parameters, frq_frq)
+         if theoretical_spectrum is None:
+             theoretical_spectrum = spectrum_contribution
+         else:
+             theoretical_spectrum += spectrum_contribution
+     return theoretical_spectrum, peak_parameters_data
+ 
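`value_to_index` converts a frequency into a fractional pixel position on the spectrum axis, using only the first and last table entries; the fractional result is kept (rounding happens later when building the peak mask). A quick sketch (the 0.25 Hz step is an arbitrary toy value):

```python
import torch

def value_to_index(values, table):
    # fractional index of `values` on the evenly spaced axis `table`
    span = table[-1] - table[0]
    return (values - table[0]) / span * (len(table) - 1)

axis = torch.arange(-1024, 1024) * 0.25   # a 2048-pixel frequency axis
idx = value_to_index(torch.tensor([axis[0], axis[1024], axis[-1]]), axis)
print(idx)  # fractional pixel indices: 0, 1024, 2047
```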
+ 
+ 
+ def theoretical_generator(
+     atom_groups_data,
+     pixels=2048, frq_step=11160.7142857 / 32768,
+     number_of_signals_min=1, number_of_signals_max=8,
+     spectrum_width_min=0.2, spectrum_width_max=1,
+     relative_width_min=1, relative_width_max=2,
+     relative_height_min=1, relative_height_max=1,
+     relative_frequency_min=-0.4, relative_frequency_max=0.4,
+     thf_min=1/16, thf_max=16,
+     trf_min=0, trf_max=1,
+     multiplicity_j1_min=0, multiplicity_j1_max=15,
+     multiplicity_j2_min=0, multiplicity_j2_max=15,
+ ):
+     tff_min = relative_frequency_min * pixels * frq_step
+     tff_max = relative_frequency_max * pixels * frq_step
+     frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
+ 
+     while True:
+         yield generate_theoretical_spectrum(
+             number_of_signals_min=number_of_signals_min,
+             number_of_signals_max=number_of_signals_max,
+             spectrum_width_min=spectrum_width_min,
+             spectrum_width_max=spectrum_width_max,
+             relative_width_min=relative_width_min,
+             relative_width_max=relative_width_max,
+             relative_height_min=relative_height_min,
+             relative_height_max=relative_height_max,
+             tff_min=tff_min, tff_max=tff_max,
+             thf_min=thf_min, thf_max=thf_max,
+             trf_min=trf_min, trf_max=trf_max,
+             multiplicity_j1_min=multiplicity_j1_min,
+             multiplicity_j1_max=multiplicity_j1_max,
+             multiplicity_j2_min=multiplicity_j2_min,
+             multiplicity_j2_max=multiplicity_j2_max,
+             atom_groups_data=atom_groups_data,
+             frq_frq=frq_frq
+         )
+ 
+ 
+ class ResponseLibrary:
+     def __init__(self, response_files, normalize=True):
+         self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0, -4) for f in response_files]
+         if normalize:
+             self.data = [data / torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
+         lengths = [len(data) for data in self.data]
+         self.start_indices = torch.cumsum(torch.tensor([0] + lengths[:-1]), 0)
+         self.total_length = sum(lengths)
+ 
+     def __getitem__(self, idx):
+         if idx >= self.total_length:
+             raise IndexError(f'index {idx} out of range')
+         tensor_index = torch.searchsorted(self.start_indices, idx, right=True) - 1
+         return self.data[tensor_index][idx - self.start_indices[tensor_index]]
+ 
+     def __len__(self):
+         return self.total_length
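`ResponseLibrary` presents several response-function files as one flat, indexable collection: `__getitem__` locates the source tensor with a `searchsorted` over cumulative start offsets, then indexes into it locally. The index arithmetic can be sketched in isolation (the three toy chunks stand in for loaded files):

```python
import torch

# three "files" with 2, 3 and 4 entries each
chunks = [torch.arange(2), torch.arange(3) + 10, torch.arange(4) + 100]
lengths = [len(c) for c in chunks]
start_indices = torch.cumsum(torch.tensor([0] + lengths[:-1]), 0)  # tensor([0, 2, 5])

def lookup(idx):
    # which chunk does the global index fall into?
    t = torch.searchsorted(start_indices, idx, right=True) - 1
    return chunks[t][idx - start_indices[t]]

print([int(lookup(i)) for i in range(sum(lengths))])
# [0, 1, 10, 11, 12, 100, 101, 102, 103]
```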
+ 
+ 
+ def generator(
+     theoretical_generator_params,
+     response_function_library,
+     response_function_stretch_min=0.5,
+     response_function_stretch_max=2.0,
+     response_function_noise=0.,
+     spectrum_noise_min=0.,
+     spectrum_noise_max=1/64,
+     include_spectrum_data=False,
+     include_peak_mask=False,
+     include_response_function=False,
+     flip_response_function=False
+ ):
+     for theoretical_spectrum, theoretical_spectrum_data in theoretical_generator(**theoretical_generator_params):
+         # get a random response function from the library
+         response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
+         # stretch the response function by a random log-uniform factor
+         padding_size = (response_function.shape[-1] - 1) // 2
+         padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max) * padding_size)
+         response_function = torch.nn.functional.interpolate(response_function, size=2 * padding_size + 1, mode='linear')
+         response_function /= response_function.sum()  # normalize sum of response function to 1
+         # add noise to the response function
+         response_function += torch.randn(response_function.shape) * response_function_noise
+         response_function /= response_function.sum()  # renormalize after adding noise
+         if flip_response_function and (torch.rand(1).item() < 0.5):
+             response_function = response_function.flip(-1)
+         # disturbed spectrum: theoretical spectrum convolved with the response function
+         disturbed_spectrum = torch.nn.functional.conv1d(theoretical_spectrum, response_function, padding=padding_size)
+         # add measurement noise
+         noised_spectrum = disturbed_spectrum + torch.randn(disturbed_spectrum.shape) * random_value(spectrum_noise_min, spectrum_noise_max)
+ 
+         out = {
+             'theoretical_spectrum': theoretical_spectrum,
+             'disturbed_spectrum': disturbed_spectrum,
+             'noised_spectrum': noised_spectrum,
+         }
+         if include_response_function:
+             out['response_function'] = response_function
+         if include_spectrum_data:
+             out["theoretical_spectrum_data"] = theoretical_spectrum_data
+         if include_peak_mask:
+             all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in theoretical_spectrum_data])
+             peaks_indices = all_peaks_rel.round().type(torch.int64)
+             out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0)
+ 
+         yield out
+ 
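The field-inhomogeneity distortion above is modeled as a convolution of the clean spectrum with the response function; with a kernel of length `2*padding_size + 1` and `padding=padding_size`, the spectrum length is preserved ("same" convolution). A toy demonstration with a hand-made 3-point response:

```python
import torch

spectrum = torch.zeros(1, 1, 64)
spectrum[0, 0, 32] = 1.0                         # a single unit "peak"

response = torch.tensor([[[0.25, 0.5, 0.25]]])   # toy 3-point response, sums to 1
padding = (response.shape[-1] - 1) // 2

distorted = torch.nn.functional.conv1d(spectrum, response, padding=padding)
print(distorted.shape)         # torch.Size([1, 1, 64]) -- length preserved
print(distorted[0, 0, 31:34])  # the peak is smeared: tensor([0.2500, 0.5000, 0.2500])
```

Because the response sums to 1, the total intensity of the spectrum is unchanged by the distortion.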
+ 
+ 
+ def collate_with_spectrum_data(batch):
+     tensor_keys = set(batch[0].keys())
+     tensor_keys.remove('theoretical_spectrum_data')
+     out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys}
+     out["theoretical_spectrum_data"] = [item["theoretical_spectrum_data"] for item in batch]
+     return out
+ 
+ 
+ def get_datapipe(
+     response_functions_files,
+     atom_groups_data_file=None,
+     batch_size=64,
+     pixels=2048, frq_step=11160.7142857 / 32768,
+     number_of_signals_min=1, number_of_signals_max=8,
+     spectrum_width_min=0.2, spectrum_width_max=1,
+     relative_width_min=1, relative_width_max=2,
+     relative_height_min=1, relative_height_max=1,
+     relative_frequency_min=-0.4, relative_frequency_max=0.4,
+     thf_min=1/16, thf_max=16,
+     trf_min=0, trf_max=1,
+     multiplicity_j1_min=0, multiplicity_j1_max=15,
+     multiplicity_j2_min=0, multiplicity_j2_max=15,
+     response_function_stretch_min=0.5,
+     response_function_stretch_max=2.0,
+     response_function_noise=0.,
+     spectrum_noise_min=0.,
+     spectrum_noise_max=1/64,
+     include_spectrum_data=False,
+     include_peak_mask=False,
+     include_response_function=False,
+     flip_response_function=False
+ ):
+     # default to singlets when no atom-groups file is given
+     if atom_groups_data_file is None:
+         atom_groups_data = np.ones((1, 3), dtype=int)
+     else:
+         atom_groups_data = np.loadtxt(atom_groups_data_file, usecols=(1, 2, 3), dtype=int)
+     response_function_library = ResponseLibrary(response_functions_files)
+     g = generator(
+         theoretical_generator_params=dict(
+             atom_groups_data=atom_groups_data,
+             pixels=pixels, frq_step=frq_step,
+             number_of_signals_min=number_of_signals_min, number_of_signals_max=number_of_signals_max,
+             spectrum_width_min=spectrum_width_min, spectrum_width_max=spectrum_width_max,
+             relative_width_min=relative_width_min, relative_width_max=relative_width_max,
+             relative_height_min=relative_height_min, relative_height_max=relative_height_max,
+             relative_frequency_min=relative_frequency_min, relative_frequency_max=relative_frequency_max,
+             thf_min=thf_min, thf_max=thf_max,
+             trf_min=trf_min, trf_max=trf_max,
+             multiplicity_j1_min=multiplicity_j1_min, multiplicity_j1_max=multiplicity_j1_max,
+             multiplicity_j2_min=multiplicity_j2_min, multiplicity_j2_max=multiplicity_j2_max
+         ),
+         response_function_library=response_function_library,
+         response_function_stretch_min=response_function_stretch_min,
+         response_function_stretch_max=response_function_stretch_max,
+         response_function_noise=response_function_noise,
+         spectrum_noise_min=spectrum_noise_min,
+         spectrum_noise_max=spectrum_noise_max,
+         include_spectrum_data=include_spectrum_data,
+         include_peak_mask=include_peak_mask,
+         include_response_function=include_response_function,
+         flip_response_function=flip_response_function
+     )
+ 
+     pipe = torchdata.datapipes.iter.IterableWrapper(g, deepcopy=False)
+     pipe = pipe.batch(batch_size)
+     pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)
+ 
+     return pipe
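The custom collate above exists because a batch mixes stackable tensors with a per-sample list of peak-parameter dicts: everything except `theoretical_spectrum_data` is stacked, while that key stays a plain Python list. A small exercise of the same function on a toy batch (the sample dicts are illustrative):

```python
import torch

def collate_with_spectrum_data(batch):
    tensor_keys = set(batch[0].keys())
    tensor_keys.remove('theoretical_spectrum_data')
    # tensors are stacked into a batch dimension...
    out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys}
    # ...while the per-sample peak metadata stays a list of dicts
    out["theoretical_spectrum_data"] = [item["theoretical_spectrum_data"] for item in batch]
    return out

batch = [
    {'noised_spectrum': torch.zeros(1, 8), 'theoretical_spectrum_data': [{'tff_lin': torch.tensor([0.])}]},
    {'noised_spectrum': torch.ones(1, 8), 'theoretical_spectrum_data': [{'tff_lin': torch.tensor([1.])}]},
]
out = collate_with_spectrum_data(batch)
print(out['noised_spectrum'].shape)           # torch.Size([2, 1, 8])
print(len(out['theoretical_spectrum_data']))  # 2
```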
src/models.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ 
+ 
+ class ConvEncoder(torch.nn.Module):
+     def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
+         super().__init__()
+         if output_dim is None:
+             output_dim = hidden_dim
+         self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
+         self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
+         self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
+         self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
+         self.dropout = torch.nn.Dropout(dropout)
+ 
+     def forward(self, feature):  # (samples, 1, 2048)
+         feature = self.dropout(self.conv4(feature))  # (samples, 64, 2042)
+         feature = feature.relu()
+         feature = self.dropout(self.conv3(feature))  # (samples, 64, 2036)
+         feature = feature.relu()
+         feature = self.dropout(self.conv2(feature))  # (samples, 64, 2030)
+         feature = feature.relu()
+         feature = self.dropout(self.conv1(feature))  # (samples, 64, 2024)
+         return feature
+ 
+ 
+ class ConvDecoder(torch.nn.Module):
+     # note: output_dim and dropout are accepted for symmetry with ConvEncoder,
+     # but the last layer always emits a single channel
+     def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
+         super().__init__()
+         if output_dim is None:
+             output_dim = hidden_dim
+         self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
+         self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
+         self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
+         self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)
+ 
+     def forward(self, feature):  # (samples, input_dim, 2024)
+         feature = self.convTranspose1(feature)  # (samples, 64, 2030)
+         feature = feature.relu()
+         feature = self.convTranspose2(feature)  # (samples, 64, 2036)
+         feature = feature.relu()
+         feature = self.convTranspose3(feature)  # (samples, 64, 2042)
+         feature = feature.relu()
+         feature = self.convTranspose4(feature)  # (samples, 1, 2048)
+         return feature
+ 
+ 
+ class ResponseHead(torch.nn.Module):
+     def __init__(self, input_dim, output_length, hidden_dims=[128]):
+         super().__init__()
+         response_head_dims = [input_dim] + hidden_dims + [output_length]
+         response_head_layers = [torch.nn.Linear(response_head_dims[0], response_head_dims[1])]
+         for dims_in, dims_out in zip(response_head_dims[1:-1], response_head_dims[2:]):
+             response_head_layers.extend([
+                 torch.nn.GELU(),
+                 torch.nn.Linear(dims_in, dims_out)
+             ])
+         self.response_head = torch.nn.Sequential(*response_head_layers)
+ 
+     def forward(self, feature):
+         return self.response_head(feature)
+ 
+ 
+ class ShimNetWithSCRF(torch.nn.Module):
+     def __init__(self,
+                  encoder_hidden_dims=64,
+                  encoder_dropout=0,
+                  bottleneck_dim=64,
+                  rensponse_length=61,
+                  resnponse_head_dims=[128],
+                  decoder_hidden_dims=64
+                  ):
+         super().__init__()
+         self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
+         self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
+         torch.nn.init.xavier_normal_(self.query)
+ 
+         self.decoder = ConvDecoder(input_dim=2 * bottleneck_dim, hidden_dim=decoder_hidden_dims)
+ 
+         self.rensponse_length = rensponse_length
+         self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)
+ 
+     def forward(self, feature):  # (samples, 1, 2048)
+         feature = self.encoder(feature)  # (samples, 64, 2024)
+         energy = self.query @ feature  # (samples, 1, 2024)
+         weight = torch.nn.functional.softmax(energy, 2)  # (samples, 1, 2024)
+         global_features = feature @ weight.transpose(1, 2)  # (samples, 64, 1)
+ 
+         response = self.response_head(global_features.squeeze(-1))
+ 
+         feature, global_features = torch.broadcast_tensors(feature, global_features)  # (samples, 64, 2024)
+         feature = torch.cat([feature, global_features], 1)  # (samples, 128, 2024)
+         denoised_spectrum = self.decoder(feature)  # (samples, 1, 2048)
+ 
+         return {
+             'denoised': denoised_spectrum,
+             'response': response,
+             'attention': weight.squeeze(1)
+         }
+ 
+ 
+ class Predictor:
+     def __init__(self, model=None, weights_file=None):
+         self.model = model
+         if weights_file is not None:
+             self.model.load_state_dict(torch.load(weights_file, map_location='cpu', weights_only=True))
+         self.model.eval()  # disable dropout for inference
+ 
+     def __call__(self, nsf_frq):
+         with torch.no_grad():
+             msf_frq = self.model(nsf_frq[None, None])["denoised"]
+         return msf_frq[0, 0]
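`ShimNetWithSCRF` pools the encoder features into a single global descriptor with a learned query vector: the query scores every spectral position, a softmax turns the scores into attention weights, and the weighted sum of feature columns becomes the global descriptor fed to the response head. A shape-level sketch of that pooling (batch size and random tensors are illustrative; the dims match the comments above):

```python
import torch

batch, dim, positions = 4, 64, 2024
feature = torch.randn(batch, dim, positions)  # stands in for the encoder output
query = torch.randn(1, 1, dim)                # stands in for the learned pooling query

energy = query @ feature                             # (batch, 1, positions)
weight = torch.softmax(energy, dim=2)                # attention over spectral positions
global_features = feature @ weight.transpose(1, 2)   # (batch, dim, 1)

print(global_features.shape)       # torch.Size([4, 64, 1])
print(weight.sum(2))               # attention weights sum to 1 per sample
```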
train.py ADDED
@@ -0,0 +1,141 @@
+ import torch, torchaudio
+ import numpy as np
+ from pathlib import Path
+ from omegaconf import OmegaConf
+ from hydra.utils import instantiate
+ import sys
+ import matplotlib.pyplot as plt
+ 
+ import matplotlib
+ matplotlib.use('Agg')
+ 
+ # silence deprecation_warning() from datapipes
+ import warnings
+ warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
+ 
+ from src import models
+ from src.generators import get_datapipe
+ 
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ if len(sys.argv) < 2:
+     print("Please provide the run directory as an argument.")
+     sys.exit(1)
+ run_dir = Path(sys.argv[1])
+ 
+ config = OmegaConf.load(run_dir / "config.yaml")
+ 
+ if (run_dir / "train.txt").is_file():
+     minimum = np.min(np.loadtxt(run_dir / "train.txt")[:, 2])
+ else:
+     minimum = float("inf")
+ 
+ # initialization
+ model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
+ model_weights_file = run_dir / 'model.pt'
+ optimizer = torch.optim.Adam(model.parameters())
+ optimizer_weights_file = run_dir / 'optimizer.pt'
+ 
+ 
+ def evaluate_model(stage=0, epoch=0):
+     plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
+     plot_dir.mkdir(exist_ok=True, parents=True)
+ 
+     torch.save(model.state_dict(), plot_dir / "model.pt")
+     torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
+ 
+     num_plots = config.logging.num_plots
+     pipe = get_datapipe(
+         **config.data,
+         include_response_function=True,
+         batch_size=num_plots
+     )
+     batch = next(iter(pipe))
+ 
+     with torch.no_grad():
+         out = model(batch['noised_spectrum'].to(device))
+         noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu()
+ 
+     for i in range(num_plots):
+         plt.figure(figsize=(30, 6))
+         plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i, 0])
+         plt.plot(out['denoised'].cpu().numpy()[i, 0])
+         plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")
+ 
+         plt.figure(figsize=(30, 6))
+         plt.plot(batch['noised_spectrum'].cpu().numpy()[i, 0])
+         plt.plot(noised_est.cpu().numpy()[i, 0])
+         plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")
+ 
+         plt.figure(figsize=(10, 6))
+         plt.plot(batch['response_function'].cpu().numpy()[i, 0, 0])
+         plt.plot(out['response'].cpu().numpy()[i])
+         plt.savefig(plot_dir / f"{i:03d}_response.png")
+ 
+         if "attention" in out:
+             plt.figure(figsize=(10, 6))
+             plt.plot(out['attention'].cpu().numpy()[i])
+             plt.savefig(plot_dir / f"{i:03d}_attention.png")
+ 
+     plt.close("all")
+ 
+ 
+ for i_stage, training_stage in enumerate(config.training):
+     if model_weights_file.is_file():
+         model.load_state_dict(torch.load(model_weights_file, weights_only=True))
+ 
+     if optimizer_weights_file.is_file():
+         optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
+     optimizer.param_groups[0]['lr'] = training_stage.learning_rate
+ 
+     pipe = get_datapipe(
+         **config.data,
+         include_response_function=True,
+         batch_size=training_stage.batch_size
+     )
+ 
+     losses_history = []
+     losses_history_limit = 64 * 100 // training_stage.batch_size
+ 
+     last_evaluation = 0
+     for epoch, batch in pipe.enumerate():
+ 
+         # periodic evaluation and checkpointing
+         iters_done = epoch * training_stage.batch_size
+         if (iters_done - last_evaluation) > config.logging.step:
+             evaluate_model(i_stage, epoch)
+             last_evaluation = iters_done
+ 
+         if iters_done > training_stage.max_iters:
+             evaluate_model(i_stage, epoch)
+             break
+ 
+         # run model
+         out = model(batch['noised_spectrum'].to(device))
+         # calculate losses
+         loss_response = torch.nn.functional.mse_loss(out['response'], batch['response_function'].squeeze(dim=(1, 2)).to(device))
+         loss_clean = torch.nn.functional.mse_loss(out['denoised'], batch['theoretical_spectrum'].to(device))
+         noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same")
+         loss_noised = torch.nn.functional.mse_loss(noised_est, batch['noised_spectrum'].to(device))
+         loss = config.losses_weights.response * loss_response + config.losses_weights.clean * loss_clean + config.losses_weights.noised * loss_noised
+ 
+         # logging
+         losses_history.append(loss_clean.item())
+         losses_history = losses_history[-losses_history_limit:]
+         loss_avg = sum(losses_history) / len(losses_history)
+         message = f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} {loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}"
+         with open(run_dir / 'train.txt', 'a') as f:
+             f.write(message + '\n')
+         print(message, flush=True)
+ 
+         # save best
+         if loss_avg < minimum:
+             minimum = loss_avg
+             torch.save(model.state_dict(), model_weights_file)
+             torch.save(optimizer.state_dict(), optimizer_weights_file)
+ 
+         # update weights
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
+         optimizer.step()
+         optimizer.zero_grad()
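The training objective combines three MSE terms: the predicted response against the true response, the denoised spectrum against the clean spectrum, and a self-consistency term that re-distorts the predicted clean spectrum with the predicted response and compares it to the model input. The script uses `torchaudio.functional.convolve` for the per-sample convolution; the same operation can be sketched with a grouped `conv1d` (random stand-in tensors, toy sizes, and hypothetical loss weights `w` in place of `config.losses_weights`):

```python
import torch

torch.manual_seed(0)
B, L, K = 2, 64, 7
denoised = torch.randn(B, 1, L)                       # stands in for out['denoised']
response = torch.softmax(torch.randn(B, K), dim=-1)   # stands in for out['response']
target_clean = torch.randn(B, 1, L)
target_response = torch.softmax(torch.randn(B, K), dim=-1)
noised = torch.randn(B, 1, L)

# per-sample "same" convolution via a grouped conv1d
# (the flip makes it a true convolution rather than cross-correlation)
noised_est = torch.nn.functional.conv1d(
    denoised.transpose(0, 1),         # (1, B, L): samples as channels
    response.flip(-1).unsqueeze(1),   # (B, 1, K): one kernel per sample
    padding=K // 2, groups=B,
).transpose(0, 1)                     # back to (B, 1, L)

# weighted sum of the three terms; w is a hypothetical stand-in for config.losses_weights
w = {'response': 1.0, 'clean': 1.0, 'noised': 0.1}
loss = (w['response'] * torch.nn.functional.mse_loss(response, target_response)
        + w['clean'] * torch.nn.functional.mse_loss(denoised, target_clean)
        + w['noised'] * torch.nn.functional.mse_loss(noised_est, noised))
print(noised_est.shape)  # torch.Size([2, 1, 64])
```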