Marek Bukowicki committed on
Commit · 5b32793
Parent(s): 0fef3e5
1st commit
Files changed:
- .gitignore +11 -0
- Readme.md +176 -0
- calibration_loop/maclib/shimator_loop +33 -0
- calibration_loop/sweep_shims_lineshape_Z1Z2.py +39 -0
- configs/shimnet_600.yaml +46 -0
- configs/shimnet_700.yaml +46 -0
- configs/shimnet_template.yaml +50 -0
- download_files.py +90 -0
- extract_scrf_from_fids.py +132 -0
- predict.py +89 -0
- requirements-cpu.txt +9 -0
- requirements.txt +9 -0
- sample_data/2ethylonaphthalene_bestshims_700MHz.csv +0 -0
- sample_data/2ethylonaphthalene_up_1mm_700MHz.csv +0 -0
- sample_data/Azarone_20ul_700MHz.csv +0 -0
- sample_data/Azarone_X_supressed_600MHz.csv +0 -0
- sample_data/Azarone_Z1Z2Z3Z4_supressed_600MHz.csv +0 -0
- sample_data/Azarone_Z1Z2_supressed_600MHz.csv +0 -0
- sample_data/Azarone_besteshims_supressed_600MHz.csv +0 -0
- sample_data/Azarone_bestshims_700MHz.csv +0 -0
- sample_data/CresolRed_after_styrene_600MHz.csv +0 -0
- sample_data/CresolRed_bestshims_600MHz.csv +0 -0
- sample_data/Geraniol_bestshims_600MHz.csv +0 -0
- sample_data/Geraniol_up_1mm_600MHz.csv +0 -0
- src/generators.py +280 -0
- src/models.py +105 -0
- train.py +141 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
```
# python's temps
**/.ipynb_checkpoints/
**/__pycache__/

# training outputs
runs/
# data files
data/
# typically weights and data
*.pt
```
Readme.md
ADDED
@@ -0,0 +1,176 @@
# ShimNet

ShimNet is a data-driven AI solution for improving high-resolution nuclear magnetic resonance (NMR) spectra distorted by an inhomogeneous magnetic field (less-than-optimal shimming). To use it, experimental training data has to be collected first (see **Training data collection** below). Example data can also be downloaded (see below).

Paper: ...
## Installation

Python 3.9+ is required.

GPU version (for training and inference):
```
pip install -r requirements.txt
```

CPU version (for inference; not recommended for training):
```
pip install -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
```

## Usage
To correct the spectra presented in the paper:
1. Download the weights (model parameters):
```
python download_files.py
```
or fetch them directly from [Google Drive 700MHz](https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke) and [Google Drive 600MHz](https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N) and place them in the `weights` directory.

2. Run the correction (e.g. on `Azarone_20ul_700MHz.csv`):
```
python predict.py sample_data/Azarone_20ul_700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
```
The output will be written to the `output/Azarone_20ul_700MHz_processed.csv` file.

Multiple files may be processed using the `*` glob syntax:
```
python predict.py sample_data/*700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
```

For 600 MHz data use `--config configs/shimnet_600.yaml` and `--weights weights/shimnet_600MHz.pt`, e.g.:

```
python predict.py sample_data/CresolRed_after_styrene_600MHz.csv -o output --config configs/shimnet_600.yaml --weights weights/shimnet_600MHz.pt
```

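The `_processed` output naming used above can be sketched as a small hypothetical helper (this function is not part of the repository; it only illustrates the convention):

```python
from pathlib import Path

def output_path(input_file: str, output_dir: str = "output") -> Path:
    """Mirror the naming convention above: <stem>_processed.<ext> inside the output directory."""
    p = Path(input_file)
    return Path(output_dir) / f"{p.stem}_processed{p.suffix}"

print(output_path("sample_data/Azarone_20ul_700MHz.csv"))
# → output/Azarone_20ul_700MHz_processed.csv
```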
### input format

The spectrum file for reconstruction should consist of two columns separated by a space, with no end-of-line character after the last line of the file (example below):
```csv
-1.97134 0.0167137
-1.97085 -0.00778748
-1.97036 -0.0109595
-1.96988 0.00825978
-1.96939 0.0133886
```

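A file in this format can be parsed with NumPy; the first column is the frequency axis and the second the intensity. A minimal sketch using the five example rows above (inlined here so the snippet is self-contained):

```python
import io
import numpy as np

# The five example rows above, loaded as a (5, 2) array:
# column 0 is the frequency axis, column 1 the intensity.
example = io.StringIO(
    "-1.97134 0.0167137\n"
    "-1.97085 -0.00778748\n"
    "-1.97036 -0.0109595\n"
    "-1.96988 0.00825978\n"
    "-1.96939 0.0133886"
)
data = np.loadtxt(example)
freqs, intensities = data[:, 0], data[:, 1]
print(data.shape)  # → (5, 2)
```

In practice you would pass the file path to `np.loadtxt` instead of the in-memory buffer.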
## Train on your data

For the model to function properly, it should be trained on calibration data from the spectrometer used for the measurements. To train a model on data from your spectrometer, please follow the instructions below.

### Training data collection

Below we describe training data collection for Agilent/Varian spectrometers; a similar procedure can be implemented for instruments from other vendors.
To collect ShimNet training data, use the Python script `sweep_shims_lineshape_Z1Z2.py` from the `calibration_loop` folder to drive the spectrometer:
1. Install the TReNDS package ( trends.spektrino.com )
2. Open VnmrJ and type: 'listenon'
3. Insert the lineshape sample (1% CHCl3 in deuterated acetone), set standard PROTON parameters, and set nt=1 (do not modify sw and at!)
4. Shim the sample and collect the data. Save the optimally shimmed dataset
5. Edit the `sweep_shims_lineshape_Z1Z2.py` script:
6. Put the optimum z1 and z2 shim values into the script as optiz1 and optiz2
7. Define the calibration range as range_z1 and range_z2 (the default is ok)
8. Start the Python script:
```
python3 ./sweep_shims_lineshape_Z1Z2.py
```
The spectrometer will start collecting spectra.

### SCRF extraction
Shim Coil Response Functions (SCRFs) are extracted from the spectra with the `extract_scrf_from_fids.py` script:
```
python extract_scrf_from_fids.py
```

The script uses hardcoded paths to the NMR signals (FIDs) in Agilent/Varian format: a directory with the optimal measurement (`opti_fid_path`) and a directory with the calibration loop measurements (`data_dir`):
```python
# input
data_dir = "../../sample_run/loop"
opti_fid_path = "../../sample_run/opti.fid"
```

The output files are also hardcoded:
```python
# output
spectra_file = "../../sample_run/total.npy"
spectra_file_names = "../../sample_run/total.csv"
opi_spectrum_file = "../../sample_run/opti.npy"
responses_file = "../../sample_run/scrf_61.pt"
```
Only the `responses_file` is used in ShimNet training.

If the measurements are stored in a format other than Varian, you may need to change this line:
```python
dic, data = ng.varian.read(varian_fid_path)
```
(see the nmrglue package documentation for details)

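The forward model behind SCRF extraction (which `fit_kernel` in `extract_scrf_from_fids.py` inverts by gradient descent) treats each mis-shimmed spectrum as the optimally shimmed spectrum convolved with the shim coil response function. A minimal NumPy sketch with a made-up line shape and a hypothetical Gaussian SCRF:

```python
import numpy as np

# Ideal (well-shimmed) line: a narrow Lorentzian peak (made-up shape).
x = np.linspace(-50.0, 50.0, 513)
ideal = 1.0 / (1.0 + (x / 1.5) ** 2)

# A hypothetical 61-point SCRF, normalized so total intensity is preserved.
kernel = np.exp(-0.5 * ((np.arange(61) - 30) / 6.0) ** 2)
kernel /= kernel.sum()

# The mis-shimmed spectrum is (approximately) ideal ⊛ SCRF.
distorted = np.convolve(ideal, kernel, mode="same")

# Broadening conserves total intensity but lowers and widens the peak.
print(distorted.max() < ideal.max())  # → True
```

SCRF extraction runs this model in reverse: given `ideal` (the optimal measurement) and `distorted` (a loop measurement), it fits the `kernel` that best maps one onto the other.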
### Training

1. Download the multiplets database:
```
python download_files.py --multiplets
```
2. Configure the run:
   - create a run directory, e.g. `runs/my_lab_spectrometer_2025`
   - create a configuration file:
     1. copy `configs/shimnet_template.yaml` to the run directory and rename it to `config.yaml`:
        ```bash
        cp configs/shimnet_template.yaml runs/my_lab_spectrometer_2025/config.yaml
        ```
     2. edit the SCRF file path in the config file:
        ```yaml
        response_functions_files:
          - path/to/srcf_file
        ```
        e.g.
        ```yaml
        response_functions_files:
          - ../../sample_run/scrf_61.pt
        ```
     3. adjust the spectrometer frequency step `frq_step` to match your data (spectral width in Hz divided by the number of points in the spectrum):
        ```yaml
        frq_step: 0.34059797
        ```
     4. adjust the spectrometer frequency in the metadata:
        ```yaml
        metadata: # additional metadata, not used in the training process
          spectrometer_frequency: 700.0 # MHz
        ```
3. Run the training:
```
python train.py runs/my_lab_spectrometer_2025
```
Training results will appear in the `runs/my_lab_spectrometer_2025` directory.
Model parameters are stored in the `runs/my_lab_spectrometer_2025/model.pt` file.
4. Use the trained model:

   Use the `--config runs/my_lab_spectrometer_2025/config.yaml` and `--weights runs/my_lab_spectrometer_2025/model.pt` flags, e.g.
```
python predict.py my_sample1.csv -o my_output --config runs/my_lab_spectrometer_2025/config.yaml --weights runs/my_lab_spectrometer_2025/model.pt
```

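The `frq_step` rule from the configuration step above (spectral width in Hz divided by the number of points in the spectrum) is a single division; for example, with hypothetical acquisition values (not taken from the repository):

```python
# frq_step = spectral width in Hz / number of points in the spectrum.
# Hypothetical example values:
spectral_width_hz = 11160.7  # sw
num_points = 32768           # points in the processed spectrum

frq_step = spectral_width_hz / num_points
print(f"{frq_step:.6f}")  # → 0.340598
```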
## Repeat training on our data

If you want to train the network using the calibration data from our paper, follow the procedure below.

1. Download the multiplets database and our SCRF file:
```
python download_files.py --multiplets --SCRF --no-weights
```
2. Configure the run:
```bash
mkdir -p runs/repeat_paper_training
cp configs/shimnet_700.yaml runs/repeat_paper_training/config.yaml
```
3. Run the training:
```
python train.py runs/repeat_paper_training
```
Training results will appear in the `runs/repeat_paper_training` directory.

calibration_loop/maclib/shimator_loop
ADDED
@@ -0,0 +1,33 @@
```
" macro shimnet loop "
" arguments: optimal z1, optimal z2, step in z1 loop, step in z2 loop, half number of steps in z1, half number of steps in z2"

$svfdir='/home/nmrbox/kkazimierczuk/shimnet/data/'

$opti_z1=$1
$opti_z2=$2
$step_z1=$3
$step_z2=$4
$steps_z1=$5
$steps_z2=$6
$nn=''
$bb=''
$msg=''
format((($steps_z1*2+1)*($steps_z2*2+1)*(d1+at)*nt)/3600,5,1):$msg
$msg='Time: '+$msg+' h'
banner($msg)
$j=0
repeat
  $i=0
  $z2=$opti_z2-$steps_z2*$step_z2+$j*$step_z2
  format($z2,5,1):$bb
  repeat
    $z1=$opti_z1-$steps_z1*$step_z1+$i*$step_z1
    "su"
    "go"
    format($z1,5,1):$nn
    $filepath=$svfdir + 'z1_'+$nn+'z2_'+$bb+'.fid'
    svf($filepath,'force')
    $i=$i+1
  until $z1>$opti_z1+$steps_z1*$step_z1-1
  $j=$j+1
until $z2>$opti_z2+$steps_z2*$step_z2-1
```
calibration_loop/sweep_shims_lineshape_Z1Z2.py
ADDED
@@ -0,0 +1,39 @@
```python
from TReND.send2vnmr import *
import shutil
import time
import os
import numpy as np
from subprocess import call

# This is the script to collect ShimNet training data on Agilent spectrometers. Execution:
# 1. Open VnmrJ and type: 'listenon'
# 2. Put the lineshape sample, set standard PROTON parameters and set one scan (do not modify sw and at!)
# 3. Shim the sample and collect the data. Save the optimally shimmed dataset
# 4. Put optimum z1 and z2 shim values as optiz1 and optiz2 below
# 5. Define the calibration range as range_z1 and range_z2 (default is ok)
# 6. Start the python script:
#    python3 ./sweep_shims_lineshape_Z1Z2.py
#    the spectrometer will start collecting spectra

Classic_path = '/home/nmr700/shimnet6/lshp'
optiz1 = 8868  # put optimum shim values here
optiz2 = -297

range_z1 = 100  # put shim calibration ranges here
range_z2 = 100  # put shim calibration ranges here

z1_sweep = np.arange(optiz1 - range_z1, optiz1 + range_z1 + 1, 2.0)
z2_sweep = np.arange(optiz2 - range_z2, optiz2 + range_z2 + 1, 2.0)

for i in range(1, np.shape(z1_sweep)[0] + 1, 1):
    for j in range(1, np.shape(z2_sweep)[0] + 1, 1):
        wait_until_idle()

        # set the z1 and z2 shim currents on the spectrometer
        Run_macro("sethw('z1'," + str(z1_sweep[i - 1]) + ")")
        Run_macro("sethw('z2'," + str(z2_sweep[j - 1]) + ")")

        go_if_idle()
        wait_until_idle()
        time.sleep(0.5)

        # save the acquired FID with the shim values encoded in the file name
        Save_experiment(Classic_path + '_z1_' + str(int(z1_sweep[i - 1])) + '_z2_' + str(int(z2_sweep[j - 1])) + '.fid')
```
configs/shimnet_600.yaml
ADDED
@@ -0,0 +1,46 @@
```yaml
model:
  name: ShimNetWithSCRF
  kwargs:
    rensponse_length: 81
    resnponse_head_dims:
      - 128
training:
  - batch_size: 64
    learning_rate: 0.001
    max_iters: 1600000
  - batch_size: 512
    learning_rate: 0.001
    max_iters: 25600000
  - batch_size: 512
    learning_rate: 0.0005
    max_iters: 12800000
losses_weights:
  clean: 1.0
  noised: 1.0
  response: 1.0
data:
  response_functions_files:
    # Paste path to your SCRF file here
    # - Can be absolute path
    # - Can be relative to repository root
    - data/600hz-2-Lnative/scrf_81.pt
  atom_groups_data_file: data/multiplets_10000_parsed.txt
  response_function_stretch_min: 1.0
  response_function_stretch_max: 1.0
  response_function_noise: 0.0
  multiplicity_j1_min: 0.0
  multiplicity_j1_max: 15
  multiplicity_j2_min: 0.0
  multiplicity_j2_max: 15
  number_of_signals_min: 2
  number_of_signals_max: 5
  thf_min: 0.5
  thf_max: 2
  relative_height_min: 0.5
  relative_height_max: 4
  frq_step: 0.30048
logging:
  step: 1000000
  num_plots: 32
metadata:
  spectrometer_frequency: 600.0
```
configs/shimnet_700.yaml
ADDED
@@ -0,0 +1,46 @@
```yaml
model:
  name: ShimNetWithSCRF
  kwargs:
    rensponse_length: 61
    resnponse_head_dims:
      - 128
training:
  - batch_size: 64
    learning_rate: 0.001
    max_iters: 1600000
  - batch_size: 512
    learning_rate: 0.001
    max_iters: 6400000
  - batch_size: 512
    learning_rate: 0.0005
    max_iters: 12800000
  - batch_size: 800
    learning_rate: 0.0005
    max_iters: 12800000
losses_weights:
  clean: 1.0
  noised: 1.0
  response: 1.0
data:
  response_functions_files:
    - data/scrf_61_700MHz.pt
  atom_groups_data_file: data/multiplets_10000_parsed.txt
  response_function_stretch_min: 1.0
  response_function_stretch_max: 1.0
  response_function_noise: 0.0
  multiplicity_j1_min: 0.0
  multiplicity_j1_max: 15
  multiplicity_j2_min: 0.0
  multiplicity_j2_max: 15
  number_of_signals_min: 2
  number_of_signals_max: 5
  thf_min: 0.5
  thf_max: 2
  relative_height_min: 0.5
  relative_height_max: 4
  frq_step: 0.34059797
logging:
  step: 1000000
  num_plots: 32
metadata:
  spectrometer_frequency: 700.0
```
configs/shimnet_template.yaml
ADDED
@@ -0,0 +1,50 @@
```yaml
model:
  name: ShimNetWithSCRF
  kwargs:
    rensponse_length: 61
    resnponse_head_dims:
      - 128
training:
  - batch_size: 64
    learning_rate: 0.001
    max_iters: 1600000
  - batch_size: 512
    learning_rate: 0.001
    max_iters: 6400000
  - batch_size: 512
    learning_rate: 0.0005
    max_iters: 12800000
  - batch_size: 800
    learning_rate: 0.0005
    max_iters: 12800000
losses_weights:
  clean: 1.0
  noised: 1.0
  response: 1.0
data:
  response_functions_files:
    # Specify the path to your SCRF file/files here.
    # - It can be an absolute path.
    # - It can be relative to the repository root.
    # Multiple files can be listed in the following rows (YAML list format).
    - path/to/srcf_file
  atom_groups_data_file: data/multiplets_10000_parsed.txt
  response_function_stretch_min: 1.0 # relative
  response_function_stretch_max: 1.0 # relative
  response_function_noise: 0.0 # arbitrary height units
  multiplicity_j1_min: 0.0 # Hz
  multiplicity_j1_max: 15 # Hz
  multiplicity_j2_min: 0.0 # Hz
  multiplicity_j2_max: 15 # Hz
  number_of_signals_min: 2
  number_of_signals_max: 5
  thf_min: 0.5 # arbitrary height units
  thf_max: 2 # arbitrary height units
  relative_height_min: 0.5 # relative
  relative_height_max: 4 # relative
  frq_step: 0.34059797 # Hz per point
logging:
  step: 1000000
  num_plots: 32
metadata: # additional metadata, not used in the training process
  spectrometer_frequency: 700.0 # MHz
```
download_files.py
ADDED
@@ -0,0 +1,90 @@
```python
import urllib.request
from pathlib import Path
import argparse

ALL_FILES_TO_DOWNLOAD = {
    "weights": [{
        "url": "https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke",
        "destination": "weights/shimnet_700MHz.pt"
    },
    {
        "url": "https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N",
        "destination": "weights/shimnet_600MHz.pt"
    }],
    "SCRF": [{
        "url": "https://drive.google.com/uc?export=download&id=113al7A__yYALx_2hkESuzFIDU3feVtNY",
        "destination": "data/scrf_61_700MHz.pt"
    }],
    "multiplets": [{
        "url": "https://drive.google.com/uc?export=download&id=1QGvV-Au50ZxaP1vFsmR_auI299Dw-Wrt",
        "destination": "data/multiplets_10000_parsed.txt"
    }],
    "development": []
}

def parse_args():
    parser = argparse.ArgumentParser(
        description='Download files: weights (default), SCRF (optional), multiplet data (optional)',
    )
    parser.add_argument(
        '--weights',
        action='store_true',
        default=True,
        help='Download weights file (default behavior). Use --no-weights to opt out.',
    )
    parser.add_argument(
        '--no-weights',
        action='store_false',
        dest='weights',
        help='Do not download weights file.',
    )
    parser.add_argument('--SCRF', action='store_true', help='Download SCRF file')
    parser.add_argument('--multiplets', action='store_true', help='Download multiplets data file')
    parser.add_argument('--development', action='store_true', help='Download development weights file')

    parser.add_argument('--all', action='store_true', help='Download all available files')

    args = parser.parse_args()
    # Set all individual flags if --all is specified
    if args.all:
        args.weights = True
        args.SCRF = True
        args.multiplets = True
        args.development = True

    return args

def download_file(url, target):
    target = Path(target)
    if target.exists():
        response = input(f"File {target} already exists. Overwrite? (y/n): ")
        if response.lower() != 'y':
            print(f"Download of {target} cancelled")
            return
    target.parent.mkdir(parents=True, exist_ok=True)
    try:
        urllib.request.urlretrieve(url, target)
        print(f"Downloaded {target}")
    except Exception as e:
        print(f"Failed to download file from {url}:\n {e}")


if __name__ == "__main__":
    args = parse_args()

    main_dir = Path(__file__).parent
    if args.weights:
        for file_data in ALL_FILES_TO_DOWNLOAD["weights"]:
            download_file(file_data["url"], main_dir / file_data["destination"])

    if args.SCRF:
        for file_data in ALL_FILES_TO_DOWNLOAD["SCRF"]:
            download_file(file_data["url"], main_dir / file_data["destination"])

    if args.multiplets:
        for file_data in ALL_FILES_TO_DOWNLOAD["multiplets"]:
            download_file(file_data["url"], main_dir / file_data["destination"])

    if args.development:
        for file_data in ALL_FILES_TO_DOWNLOAD["development"]:
            download_file(file_data["url"], main_dir / file_data["destination"])
```
extract_scrf_from_fids.py
ADDED
@@ -0,0 +1,132 @@
```python
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import nmrglue as ng
from tqdm import tqdm

# input
data_dir = "../../../dane/600Hz_20241227/loop"
opti_fid_path = "../../../dane/600Hz_20241227/opti"

# output
spectra_file = "data/600Hz_20241227-Lnative/total.npy"
spectra_file_names = "data/600Hz_20241227-Lnative/total.csv"
opi_spectrum_file = "data/600Hz_20241227-Lnative/opti.npy"
responses_file = "data/600Hz_20241227-Lnative/scrf_61.pt"
losses_file = "data/600Hz_20241227-Lnative/losses_scrf_61.pt"

# create directories for output files
for file_path in [spectra_file, spectra_file_names, opi_spectrum_file, responses_file, losses_file]:
    Path(file_path).parent.mkdir(parents=True, exist_ok=True)

# settings: fid to spectrum
ph0_correction = -190.49
ph1_correction = 0
autophase_fn = "acme"
target_length = None

# settings: SCRF extraction
calibration_peak_center = "auto"  # the center of the calibration peak (point index, e.g. 12277), or "auto"
calibration_window_halfwidth = 128  # the half width of the calibration window
steps = 6000
kernel_size = 61
kernel_sqrt = False  # True to allow only positive values

# fid to spectra processing

def fid_to_spectrum(varian_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=None, sin_pod=False):
    dic, data = ng.varian.read(varian_fid_path)
    data[0] *= 0.5
    if sin_pod:
        data = ng.proc_base.sp(data, end=0.98)

    if target_length is not None:
        if (pad_length := target_length - len(data)) > 0:
            data = ng.proc_base.zf(data, pad_length)
        else:
            data = data[:target_length]

    spec = ng.proc_base.fft(data)
    spec = ng.process.proc_autophase.autops(spec, autophase_fn, p0=ph0_correction, p1=ph1_correction, disp=False)

    return spec

# process optimal measurement fid to spectrum
opti_spectrum_full = fid_to_spectrum(opti_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)

if calibration_peak_center == "auto":
    calibration_peak_center = np.argmax(abs(opti_spectrum_full))
fitting_range = (calibration_peak_center - calibration_window_halfwidth, calibration_peak_center + calibration_window_halfwidth + 1)

opti_spectrum = opti_spectrum_full[fitting_range[0]:fitting_range[1]]
np.save(opi_spectrum_file, opti_spectrum)
print(f"Optimal spectrum extracted to {opi_spectrum_file}")

# process loop fids to spectra
spec_list = []
spec_names = []

print("Extracting spectra from fids...")
for fid_path in tqdm(list(Path(data_dir).rglob('*.fid'))):
    spec = fid_to_spectrum(fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)[fitting_range[0]:fitting_range[1]]

    spec_list.append(spec)
    spec_names.append(fid_path.name)

total = np.array(spec_list)
np.save(spectra_file, total)
pd.DataFrame(spec_names).to_csv(spectra_file_names, header=False)
# total = np.load(spectra_file)
print(f"Spectra extracted to {spectra_file}")

# process SCRF extraction
def fit_kernel(base, target, kernel_size, kernel_sqrt=True, steps=20000, verbose=False):

    kernel = torch.ones((1, 1, kernel_size), dtype=base.dtype)
    if kernel_sqrt:
        kernel /= torch.sqrt(torch.sum(kernel**2))
    else:
        kernel /= kernel_size
    kernel.requires_grad = True

    optimizer = torch.optim.Adam([kernel])

    for epoch in range(steps):
        if kernel_sqrt:
            spe_est = torch.conv1d(base, kernel**2, padding='same')
        else:
            spe_est = torch.conv1d(base, kernel, padding='same')
        loss = torch.mean(abs(target - spe_est)**2)  # torch.nn.functional.mse_loss(spe_est, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if verbose and (epoch + 1) % 100 == 0:
            print(epoch, loss.item())
    if kernel_sqrt:
        return kernel.detach()**2, loss.item()
    else:
        return kernel.detach(), loss.item()

responses = torch.empty(len(total), 1, 1, 1, 1, kernel_size)
losses = torch.empty(len(total))
base = torch.tensor(opti_spectrum.real).unsqueeze(0)
targets = torch.tensor(total.real)

# normalization
base /= base.sum()
targets /= targets.sum(dim=(-1,), keepdim=True)

print("\nExtracting SCRFs...")
for i, target in tqdm(enumerate(targets), total=len(targets)):
    kernel, loss = fit_kernel(base, target.unsqueeze(0), kernel_size, kernel_sqrt=kernel_sqrt, steps=steps)
    responses[i, 0, 0] = kernel
    losses[i] = loss

torch.save(responses, responses_file)
torch.save(losses, losses_file)
print(f"SCRFs extracted to {responses_file}, losses saved to {losses_file}")
```
predict.py ADDED
@@ -0,0 +1,89 @@

import torch
torch.set_grad_enabled(False)
import numpy as np
import argparse
from pathlib import Path
import sys, os
from omegaconf import OmegaConf

from src.models import ShimNetWithSCRF, Predictor

# silence deprecation warnings
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

class Defaults:
    SCALE = 16.0

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_files", help="Input files", nargs="+")
    parser.add_argument("--config", help="config file .yaml")
    parser.add_argument("--weights", help="model weights")
    parser.add_argument("-o", "--output_dir", default=".", help="Output directory")
    parser.add_argument("--input_spectrometer_frequency", default=None, type=float, help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data")
    args = parser.parse_args()
    return args

# functions
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
    """resample input spectrum to match the model's frequency range"""
    freqs = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
    spectrum = np.interp(freqs, input_freqs, input_spectrum)
    return freqs, spectrum

def resample_output_spectrum(input_freqs, freqs, prediction):
    """resample prediction to match the input spectrum's frequency range"""
    prediction = np.interp(input_freqs, freqs, prediction)
    return prediction

def initialize_predictor(config, weights_file):
    model = ShimNetWithSCRF(**config.model.kwargs)
    predictor = Predictor(model, weights_file)
    return predictor

# run
if __name__ == "__main__":
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    config = OmegaConf.load(args.config)
    model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
    predictor = initialize_predictor(config, args.weights)

    for input_file in args.input_files:
        print(f"processing {input_file} ...")

        # load data
        input_data = np.loadtxt(input_file)
        input_freqs_input_ppm, input_spectrum = input_data[:,0], input_data[:,1]

        # convert input frequencies to model's frequency - correct for zero filling and spectrometer frequency
        if args.input_spectrometer_frequency is not None:
            input_freqs_model_ppm = input_freqs_input_ppm * args.input_spectrometer_frequency / config.metadata.spectrometer_frequency
        else:
            input_freqs_model_ppm = input_freqs_input_ppm

        freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)

        spectrum = torch.tensor(spectrum).float()
        # scale height of the spectrum
        scaling_factor = Defaults.SCALE / spectrum.max()
        spectrum *= scaling_factor

        # correct spectrum
        prediction = predictor(spectrum).numpy()

        # rescale height
        prediction /= scaling_factor

        # resample the output to match the input spectrum
        output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)

        # save result
        output_file = output_dir / f"{Path(input_file).stem}_processed{Path(input_file).suffix}"

        np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
        print(f"saved to {output_file}")
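predict.py maps the input ppm axis into the model's frame by the ratio of spectrometer frequencies, resamples onto the model's uniform grid with `np.interp`, and then resamples the prediction back onto the original axis. A toy round trip of that axis handling, under assumed frequencies (700 MHz data, 600 MHz model) and a synthetic Gaussian line:

```python
import numpy as np

# Hypothetical setup: spectrum recorded at 700 MHz, model trained at 600 MHz.
input_freq_mhz, model_freq_mhz = 700.0, 600.0
input_ppm = np.linspace(-1.0, 1.0, 11)
spectrum = np.exp(-input_ppm**2 / 0.1)  # synthetic line shape

# Convert the ppm axis into the model's frame (ratio of spectrometer frequencies).
model_ppm = input_ppm * input_freq_mhz / model_freq_mhz

# Resample onto a uniform model grid, then back onto the input axis,
# mirroring resample_input_spectrum / resample_output_spectrum.
step = 0.05
grid = np.arange(model_ppm.min(), model_ppm.max(), step)
resampled = np.interp(grid, model_ppm, spectrum)
roundtrip = np.interp(model_ppm, grid, resampled)
```

Because both directions use linear interpolation, the round trip is only approximate; the error shrinks as the model grid gets finer relative to the input sampling.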
requirements-cpu.txt ADDED
@@ -0,0 +1,9 @@

torch==2.4.1+cpu
torchaudio==2.4.1+cpu
nmrglue==0.11
torchdata==0.9.0
numpy==2.0.2
matplotlib==3.9.3
pandas==2.2.3
tqdm==4.67.1
hydra-core==1.3.2
requirements.txt ADDED
@@ -0,0 +1,9 @@

torch==2.4.1
torchaudio==2.4.1
nmrglue==0.11
torchdata==0.9.0
numpy==2.0.2
matplotlib==3.9.3
pandas==2.2.3
tqdm==4.67.1
hydra-core==1.3.2
sample_data CSV files ADDED (diffs too large to render; see raw diffs):
sample_data/2ethylonaphthalene_bestshims_700MHz.csv
sample_data/2ethylonaphthalene_up_1mm_700MHz.csv
sample_data/Azarone_20ul_700MHz.csv
sample_data/Azarone_X_supressed_600MHz.csv
sample_data/Azarone_Z1Z2Z3Z4_supressed_600MHz.csv
sample_data/Azarone_Z1Z2_supressed_600MHz.csv
sample_data/Azarone_besteshims_supressed_600MHz.csv
sample_data/Azarone_bestshims_700MHz.csv
sample_data/CresolRed_after_styrene_600MHz.csv
sample_data/CresolRed_bestshims_600MHz.csv
sample_data/Geraniol_bestshims_600MHz.csv
sample_data/Geraniol_up_1mm_600MHz.csv
src/generators.py ADDED
@@ -0,0 +1,280 @@

import numpy as np
import torch
import torchdata
# from itertools import islice

def random_value(min_value, max_value):
    return (min_value + torch.rand(1) * (max_value - min_value)).item()

def random_loguniform(min_value, max_value):
    return (min_value * torch.exp(torch.rand(1) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()

def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq: torch.Tensor):
    # extract parameters
    tff_lin = peaks_parameters["tff_lin"]
    twf_lin = peaks_parameters["twf_lin"]
    thf_lin = peaks_parameters["thf_lin"]
    trf_lin = peaks_parameters["trf_lin"]

    lwf_lin = twf_lin
    lhf_lin = thf_lin * (1. - trf_lin)
    gwf_lin = twf_lin
    gdf_lin = gwf_lin / torch.tensor(2.).log().mul(2.).sqrt()
    ghf_lin = thf_lin * trf_lin
    # calculate Lorentzian peaks contributions
    lsf_linfrq = lwf_lin[:, None] ** 2 / (lwf_lin[:, None] ** 2 + (frq_frq - tff_lin[:, None]) ** 2) * lhf_lin[:, None]
    # calculate Gaussian peaks contributions
    gsf_linfrq = torch.exp(-(frq_frq - tff_lin[:, None]) ** 2 / gdf_lin[:, None] ** 2 / 2.) * ghf_lin[:, None]
    tsf_linfrq = lsf_linfrq + gsf_linfrq
    # sum peaks contributions
    tsf_frq = tsf_linfrq.sum(0, keepdim=True)
    return tsf_frq


pascal_triangle = [(1,), (1,1), (1,2,1), (1,3,3,1), (1,4,6,4,1), (1,5,10,10,5,1), (1,6,15,20,15,6,1), (1,7,21,35,35,21,7,1)]
normalized_pascal_triangle = [torch.tensor(x)/sum(x) for x in pascal_triangle]

def pascal_multiplicity(multiplicity):
    intensities = normalized_pascal_triangle[multiplicity-1]
    n_peaks = len(intensities)
    shifts = torch.arange(n_peaks)-((n_peaks-1)/2)
    return shifts, intensities

def double_multiplicity(multiplicity1, multiplicity2, j1=1, j2=1):
    shifts1, intensities1 = pascal_multiplicity(multiplicity1)
    shifts2, intensities2 = pascal_multiplicity(multiplicity2)

    shifts = (j1*shifts1.reshape(-1,1) + j2*shifts2.reshape(1,-1)).flatten()
    intensities = (intensities1.reshape(-1,1) * intensities2.reshape(1,-1)).flatten()
    return shifts, intensities

def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_lin, j1, j2):
    shifts, intensities = double_multiplicity(multiplicity[0], multiplicity[1], j1, j2)
    n_peaks = len(shifts)

    return {
        "tff_lin": shifts + tff_lin,
        "thf_lin": intensities * thf_lin,
        "twf_lin": torch.full((n_peaks,), twf_lin),
        "trf_lin": torch.full((n_peaks,), trf_lin),
    }

def value_to_index(values, table):
    span = table[-1] - table[0]
    indices = ((values - table[0])/span * (len(table)-1))  # .round().type(torch.int64)
    return indices

def generate_theoretical_spectrum(
    number_of_signals_min, number_of_signals_max,
    spectrum_width_min, spectrum_width_max,
    relative_width_min, relative_width_max,
    tff_min, tff_max,
    thf_min, thf_max,
    trf_min, trf_max,
    relative_height_min, relative_height_max,
    multiplicity_j1_min, multiplicity_j1_max,
    multiplicity_j2_min, multiplicity_j2_max,
    atom_groups_data,
    frq_frq
):
    number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [])
    atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals])
    width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max)
    height_spectrum = random_loguniform(thf_min, thf_max)

    peak_parameters_data = []
    theoretical_spectrum = None
    for atom_group_index in atom_group_indices:
        relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
        position = random_value(tff_min, tff_max)
        j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
        j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
        width = width_spectrum*random_loguniform(relative_width_min, relative_width_max)
        height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max)
        gaussian_contribution = random_value(trf_min, trf_max)

        peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin=width, trf_lin=gaussian_contribution, j1=j1, j2=j2)
        peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
        peak_parameters_data.append(peaks_parameters)
        spectrum_contribution = calculate_theoretical_spectrum(peaks_parameters, frq_frq)
        if theoretical_spectrum is None:
            theoretical_spectrum = spectrum_contribution
        else:
            theoretical_spectrum += spectrum_contribution
    return theoretical_spectrum, peak_parameters_data


def theoretical_generator(
    atom_groups_data,
    pixels=2048, frq_step=11160.7142857 / 32768,
    number_of_signals_min=1, number_of_signals_max=8,
    spectrum_width_min=0.2, spectrum_width_max=1,
    relative_width_min=1, relative_width_max=2,
    relative_height_min=1, relative_height_max=1,
    relative_frequency_min=-0.4, relative_frequency_max=0.4,
    thf_min=1/16, thf_max=16,
    trf_min=0, trf_max=1,
    multiplicity_j1_min=0, multiplicity_j1_max=15,
    multiplicity_j2_min=0, multiplicity_j2_max=15,
):
    tff_min = relative_frequency_min * pixels * frq_step
    tff_max = relative_frequency_max * pixels * frq_step
    frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step

    while True:
        yield generate_theoretical_spectrum(
            number_of_signals_min=number_of_signals_min,
            number_of_signals_max=number_of_signals_max,
            spectrum_width_min=spectrum_width_min,
            spectrum_width_max=spectrum_width_max,
            relative_width_min=relative_width_min,
            relative_width_max=relative_width_max,
            relative_height_min=relative_height_min,
            relative_height_max=relative_height_max,
            tff_min=tff_min, tff_max=tff_max,
            thf_min=thf_min, thf_max=thf_max,
            trf_min=trf_min, trf_max=trf_max,
            multiplicity_j1_min=multiplicity_j1_min,
            multiplicity_j1_max=multiplicity_j1_max,
            multiplicity_j2_min=multiplicity_j2_min,
            multiplicity_j2_max=multiplicity_j2_max,
            atom_groups_data=atom_groups_data,
            frq_frq=frq_frq
        )

class ResponseLibrary:
    def __init__(self, reponse_files, normalize=True):
        self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in reponse_files]
        if normalize:
            self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
        lengths = [len(data) for data in self.data]
        self.start_indices = torch.cumsum(torch.tensor([0] + lengths[:-1]), 0)
        self.total_length = sum(lengths)

    def __getitem__(self, idx):
        if idx >= self.total_length:
            raise ValueError(f'index {idx} out of range')
        tensor_index = torch.searchsorted(self.start_indices, idx, right=True) - 1
        return self.data[tensor_index][idx - self.start_indices[tensor_index]]

    def __len__(self):
        return self.total_length

def generator(
    theoretical_generator_params,
    response_function_library,
    response_function_stretch_min=0.5,
    response_function_stretch_max=2.0,
    response_function_noise=0.,
    spectrum_noise_min=0.,
    spectrum_noise_max=1/64,
    include_spectrum_data=False,
    include_peak_mask=False,
    include_response_function=False,
    flip_response_function=False
):
    for theoretical_spectrum, theoretical_spectrum_data in theoretical_generator(**theoretical_generator_params):
        # get response function
        response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
        # stretch response function
        padding_size = (response_function.shape[-1] - 1)//2
        padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size)  # torch.randint(round(padding_size*response_function_stretch_min), round(padding_size*response_function_stretch_max), [1]).item()
        response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
        response_function /= response_function.sum()  # normalize sum of response function to 1
        # add noise to response function
        response_function += torch.randn(response_function.shape) * response_function_noise
        response_function /= response_function.sum()  # normalize sum of response function to 1
        if flip_response_function and (torch.rand(1).item() < 0.5):
            response_function = response_function.flip(-1)
        # disturbed spectrum
        disturbed_spectrum = torch.nn.functional.conv1d(theoretical_spectrum, response_function, padding=padding_size)
        # add noise
        noised_spectrum = disturbed_spectrum + torch.randn(disturbed_spectrum.shape) * random_value(spectrum_noise_min, spectrum_noise_max)

        out = {
            # 'response_function': response_function,
            'theoretical_spectrum': theoretical_spectrum,
            'disturbed_spectrum': disturbed_spectrum,
            'noised_spectrum': noised_spectrum,
        }
        if include_response_function:
            out['response_function'] = response_function
        if include_spectrum_data:
            out["theoretical_spectrum_data"] = theoretical_spectrum_data
        if include_peak_mask:
            all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in theoretical_spectrum_data])
            peaks_indices = all_peaks_rel.round().type(torch.int64)
            out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0)

        yield out


def collate_with_spectrum_data(batch):
    tensor_keys = set(batch[0].keys())
    tensor_keys.remove('theoretical_spectrum_data')
    out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys}
    out["theoretical_spectrum_data"] = [item["theoretical_spectrum_data"] for item in batch]
    return out

def get_datapipe(
    response_functions_files,
    atom_groups_data_file=None,
    batch_size=64,
    pixels=2048, frq_step=11160.7142857 / 32768,
    number_of_signals_min=1, number_of_signals_max=8,
    spectrum_width_min=0.2, spectrum_width_max=1,
    relative_width_min=1, relative_width_max=2,
    relative_height_min=1, relative_height_max=1,
    relative_frequency_min=-0.4, relative_frequency_max=0.4,
    thf_min=1/16, thf_max=16,
    trf_min=0, trf_max=1,
    multiplicity_j1_min=0, multiplicity_j1_max=15,
    multiplicity_j2_min=0, multiplicity_j2_max=15,
    response_function_stretch_min=0.5,
    response_function_stretch_max=2.0,
    response_function_noise=0.,
    spectrum_noise_min=0.,
    spectrum_noise_max=1/64,
    include_spectrum_data=False,
    include_peak_mask=False,
    include_response_function=False,
    flip_response_function=False
):
    # singlets
    if atom_groups_data_file is None:
        atom_groups_data = np.ones((1,3), dtype=int)
    else:
        atom_groups_data = np.loadtxt(atom_groups_data_file, usecols=(1,2,3), dtype=int)
    response_function_library = ResponseLibrary(response_functions_files)
    g = generator(
        theoretical_generator_params=dict(
            atom_groups_data=atom_groups_data,
            pixels=pixels, frq_step=frq_step,
            number_of_signals_min=number_of_signals_min, number_of_signals_max=number_of_signals_max,
            spectrum_width_min=spectrum_width_min, spectrum_width_max=spectrum_width_max,
            relative_width_min=relative_width_min, relative_width_max=relative_width_max,
            relative_height_min=relative_height_min, relative_height_max=relative_height_max,
            relative_frequency_min=relative_frequency_min, relative_frequency_max=relative_frequency_max,
            thf_min=thf_min, thf_max=thf_max,
            trf_min=trf_min, trf_max=trf_max,
            multiplicity_j1_min=multiplicity_j1_min, multiplicity_j1_max=multiplicity_j1_max,
            multiplicity_j2_min=multiplicity_j2_min, multiplicity_j2_max=multiplicity_j2_max
        ),
        response_function_library=response_function_library,
        response_function_stretch_min=response_function_stretch_min,
        response_function_stretch_max=response_function_stretch_max,
        response_function_noise=response_function_noise,
        spectrum_noise_min=spectrum_noise_min,
        spectrum_noise_max=spectrum_noise_max,
        include_spectrum_data=include_spectrum_data,
        include_peak_mask=include_peak_mask,
        include_response_function=include_response_function,
        flip_response_function=flip_response_function
    )

    pipe = torchdata.datapipes.iter.IterableWrapper(g, deepcopy=False)
    pipe = pipe.batch(batch_size)
    pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)

    return pipe
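The multiplet intensities above come from normalized rows of Pascal's triangle, and `double_multiplicity` combines two couplings by taking an outer product of the shifts and intensities. A standalone NumPy sketch of the same construction (torch-free, for illustration only):

```python
import numpy as np

def pascal_row(multiplicity):
    """Shifts and intensities of a simple multiplet; intensities are a
    normalized Pascal-triangle row (e.g. triplet -> 1:2:1 -> 0.25:0.5:0.25)."""
    row = np.array([1.0])
    for _ in range(multiplicity - 1):
        row = np.convolve(row, [1.0, 1.0])  # builds the next Pascal row
    shifts = np.arange(len(row)) - (len(row) - 1) / 2
    return shifts, row / row.sum()

# Doublet of doublets: outer product of two doublets with couplings j1, j2.
(s1, i1), (s2, i2) = pascal_row(2), pascal_row(2)
j1, j2 = 8.0, 2.0
shifts = (j1 * s1[:, None] + j2 * s2[None, :]).ravel()
intensities = (i1[:, None] * i2[None, :]).ravel()
```

With j1 = 8 and j2 = 2 this yields four equal-intensity lines at -5, -3, 3 and 5 (in the same frequency units as the couplings), the familiar doublet-of-doublets pattern.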
src/models.py ADDED
@@ -0,0 +1,105 @@

import torch

class ConvEncoder(torch.nn.Module):
    def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        if output_dim is None:
            output_dim = hidden_dim
        self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
        self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, feature):  # (samples, 1, 2048)
        feature = self.dropout(self.conv4(feature))  # (samples, 64, 2042)
        feature = feature.relu()
        feature = self.dropout(self.conv3(feature))  # (samples, 64, 2036)
        feature = feature.relu()
        feature = self.dropout(self.conv2(feature))  # (samples, 64, 2030)
        feature = feature.relu()
        feature = self.dropout(self.conv1(feature))  # (samples, 64, 2024)
        return feature

class ConvDecoder(torch.nn.Module):
    def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        if output_dim is None:
            output_dim = hidden_dim
        self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
        self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, 1, kernel_size)

    def forward(self, feature):  # (samples, input_dim, 2024)
        feature = self.convTranspose1(feature)  # (samples, 64, 2030)
        feature = feature.relu()
        feature = self.convTranspose2(feature)  # (samples, 64, 2036)
        feature = feature.relu()
        feature = self.convTranspose3(feature)  # (samples, 64, 2042)
        feature = feature.relu()
        feature = self.convTranspose4(feature)  # (samples, 1, 2048)
        return feature

class ResponseHead(torch.nn.Module):
    def __init__(self, input_dim, output_length, hidden_dims=[128]):
        super().__init__()
        response_head_dims = [input_dim] + hidden_dims + [output_length]
        response_head_layers = [torch.nn.Linear(response_head_dims[0], response_head_dims[1])]
        for dims_in, dims_out in zip(response_head_dims[1:-1], response_head_dims[2:]):
            response_head_layers.extend([
                torch.nn.GELU(),
                torch.nn.Linear(dims_in, dims_out)
            ])
        self.response_head = torch.nn.Sequential(*response_head_layers)

    def forward(self, feature):
        return self.response_head(feature)

class ShimNetWithSCRF(torch.nn.Module):
    def __init__(self,
                 encoder_hidden_dims=64,
                 encoder_dropout=0,
                 bottleneck_dim=64,
                 rensponse_length=61,
                 resnponse_head_dims=[128],
                 decoder_hidden_dims=64
                 ):
        super().__init__()
        self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
        self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
        torch.nn.init.xavier_normal_(self.query)

        self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims)

        self.rensponse_length = rensponse_length
        self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)

    def forward(self, feature):  # (samples, 1, 2048)
        feature = self.encoder(feature)  # (samples, 64, 2024)
        energy = self.query @ feature  # (samples, 1, 2024)
        weight = torch.nn.functional.softmax(energy, 2)  # (samples, 1, 2024)
        global_features = feature @ weight.transpose(1, 2)  # (samples, 64, 1)

        response = self.response_head(global_features.squeeze(-1))

        feature, global_features = torch.broadcast_tensors(feature, global_features)  # (samples, 64, 2024)
        feature = torch.cat([feature, global_features], 1)  # (samples, 128, 2024)
        denoised_spectrum = self.decoder(feature)  # (samples, 1, 2048)

        return {
            'denoised': denoised_spectrum,
            'response': response,
            'attention': weight.squeeze(1)
        }

class Predictor:
    def __init__(self, model=None, weights_file=None):
        self.model = model
        if weights_file is not None:
            self.model.load_state_dict(torch.load(weights_file, map_location='cpu', weights_only=True))

    def __call__(self, nsf_frq):
        with torch.no_grad():
            msf_frq = self.model(nsf_frq[None, None])["denoised"]
        return msf_frq[0, 0]
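The shape comments in `ConvEncoder` follow from the standard `Conv1d` output-length formula: each unpadded kernel-7 convolution removes 6 points, so the four stacked convolutions take 2048 down to 2024, and the four transposed convolutions in `ConvDecoder` add them back. A quick arithmetic check:

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of torch.nn.Conv1d for the given hyperparameters."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

length = 2048
for _ in range(4):  # four stacked kernel-7, no-padding convolutions in ConvEncoder
    length = conv1d_out_len(length, kernel_size=7)
print(length)  # 2024, matching the shape comments in the encoder
```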
train.py ADDED
@@ -0,0 +1,141 @@

import torch, torchaudio
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate
import datetime
import sys
import matplotlib.pyplot as plt


import matplotlib
matplotlib.use('Agg')

# silence deprecation_warning() from datapipes
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')

from src import models
from src.generators import get_datapipe

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if len(sys.argv) < 2:
    print("Please provide the run directory as an argument.")
    sys.exit(1)
run_dir = Path(sys.argv[1])

config = OmegaConf.load(run_dir / "config.yaml")

if (run_dir / "train.txt").is_file():
    minimum = np.min(np.loadtxt(run_dir / "train.txt")[:,2])
else:
    minimum = float("inf")

# initialization
model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
model_weights_file = run_dir / f'model.pt'
optimizer = torch.optim.Adam(model.parameters())
optimizer_weights_file = run_dir / f'optimizer.pt'

def evaluate_model(stage=0, epoch=0):
    plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
    plot_dir.mkdir(exist_ok=True, parents=True)

    torch.save(model.state_dict(), plot_dir / "model.pt")
    torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")

    num_plots = config.logging.num_plots
    pipe = get_datapipe(
        **config.data,
        include_response_function=True,
        batch_size=num_plots
    )
    batch = next(iter(pipe))

    with torch.no_grad():
        out = model(batch['noised_spectrum'].to(device))
        noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu()

    for i in range(num_plots):
        plt.figure(figsize=(30,6))
        plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i,0])
        plt.plot(out['denoised'].cpu().numpy()[i,0])
        plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")

        plt.figure(figsize=(30,6))
        plt.plot(batch['noised_spectrum'].cpu().numpy()[i,0])
        plt.plot(noised_est.cpu().numpy()[i,0])
        plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")

        plt.figure(figsize=(10,6))
        plt.plot(batch['response_function'].cpu().numpy()[i,0,0])
        plt.plot(out['response'].cpu().numpy()[i])
        plt.savefig(plot_dir / f"{i:03d}_response.png")

        if "attention" in out:
            plt.figure(figsize=(10, 6))
            plt.plot(out['attention'].cpu().numpy()[i])
            plt.savefig(plot_dir / f"{i:03d}_attention.png")

        plt.close("all")

for i_stage, training_stage in enumerate(config.training):
    if model_weights_file.is_file():
        model.load_state_dict(torch.load(model_weights_file, weights_only=True))

    if optimizer_weights_file.is_file():
        optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
    optimizer.param_groups[0]['lr'] = training_stage.learning_rate

    pipe = get_datapipe(
        **config.data,
        include_response_function=True,
        batch_size=training_stage.batch_size
    )

    losses_history = []
    losses_history_limit = 64*100 // training_stage.batch_size

    last_evaluation = 0
    for epoch, batch in pipe.enumerate():

        # logging
        iters_done = epoch*training_stage.batch_size
        if (iters_done - last_evaluation) > config.logging.step:
            evaluate_model(i_stage, epoch)
            last_evaluation = iters_done

        if iters_done > training_stage.max_iters:
            evaluate_model(i_stage, epoch)
            break

        # run model
        out = model(batch['noised_spectrum'].to(device))
        # calculate losses
        loss_response = torch.nn.functional.mse_loss(out['response'], batch['response_function'].squeeze(dim=(1,2)).to(device))
        loss_clean = torch.nn.functional.mse_loss(out['denoised'], batch['theoretical_spectrum'].to(device))
        noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same")
        loss_noised = torch.nn.functional.mse_loss(noised_est, batch['noised_spectrum'].to(device))
        loss = config.losses_weights.response*loss_response + config.losses_weights.clean*loss_clean + config.losses_weights.noised*loss_noised
|
| 120 |
+
|
| 121 |
+
# logging
|
| 122 |
+
losses_history.append(loss_clean.item())
|
| 123 |
+
losses_history = losses_history[-losses_history_limit:]
|
| 124 |
+
loss_avg = sum(losses_history)/len(losses_history)
|
| 125 |
+
message = f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} {loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}"
|
| 126 |
+
# message = '%7i %.3e %.3e %.3e' % (epoch, loss, regress, classify)
|
| 127 |
+
with open(run_dir / f'train.txt', 'a') as f:
|
| 128 |
+
f.write(message + '\n')
|
| 129 |
+
print(message, flush = True)
|
| 130 |
+
|
| 131 |
+
# save best
|
| 132 |
+
if loss_avg < minimum:
|
| 133 |
+
minimum = loss_avg
|
| 134 |
+
torch.save(model.state_dict(), model_weights_file)
|
| 135 |
+
torch.save(optimizer.state_dict(),optimizer_weights_file)
|
| 136 |
+
|
| 137 |
+
# update weights
|
| 138 |
+
loss.backward()
|
| 139 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
|
| 140 |
+
optimizer.step()
|
| 141 |
+
optimizer.zero_grad()
|