Spaces:
Sleeping
Sleeping
Marek Bukowicki commited on
Commit ·
64b4096
1
Parent(s): 12f8c5c
add shimnet code
Browse files- Dockerfile +13 -0
- LICENSE +21 -0
- README.md +221 -1
- calibration_loop/maclib/shimator_loop +33 -0
- calibration_loop/sweep_shims_lineshape_Z1Z2.py +39 -0
- configs/shimnet_600.yaml +46 -0
- configs/shimnet_700.yaml +46 -0
- configs/shimnet_template.yaml +50 -0
- download_files.py +95 -0
- extract_scrf_from_fids.py +132 -0
- predict-gui.py +197 -0
- predict.py +89 -0
- requirements-cpu.txt +9 -0
- requirements-gpu.txt +9 -0
- requirements-gui.txt +2 -0
- sample_data/2ethylonaphthalene_bestshims_700MHz.csv +0 -0
- sample_data/2ethylonaphthalene_up_1mm_700MHz.csv +0 -0
- sample_data/Azarone_20ul_700MHz.csv +0 -0
- sample_data/Azarone_X_supressed_600MHz.csv +0 -0
- sample_data/Azarone_Z1Z2Z3Z4_supressed_600MHz.csv +0 -0
- sample_data/Azarone_Z1Z2_supressed_600MHz.csv +0 -0
- sample_data/Azarone_besteshims_supressed_600MHz.csv +0 -0
- sample_data/Azarone_bestshims_700MHz.csv +0 -0
- sample_data/CresolRed_after_styrene_600MHz.csv +0 -0
- sample_data/CresolRed_after_styrene_cut_600MHz.csv +0 -0
- sample_data/CresolRed_bestshims_600MHz.csv +0 -0
- sample_data/CresolRed_cut_bestshims_600MHz.csv +0 -0
- sample_data/Geraniol_bestshims_600MHz.csv +0 -0
- sample_data/Geraniol_up_1mm_600MHz.csv +0 -0
- sample_data/SodiumButyrate_after_glucose_cut_700MHz.csv +0 -0
- sample_data/SodiumButyrate_cut_bestshims_700MHz.csv +0 -0
- src/generators.py +280 -0
- src/models.py +105 -0
- train.py +141 -0
- weights/shimnet_600MHz.pt +3 -0
- weights/shimnet_700MHz.pt +3 -0
Dockerfile
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build a CPU-only image that serves the ShimNet Gradio GUI.
FROM python:3.10

WORKDIR /usr/src/app

# Install dependencies first so Docker layer caching survives source edits.
COPY requirements-cpu.txt requirements-gui.txt ./
RUN pip install --no-cache-dir -r requirements-cpu.txt -r requirements-gui.txt --extra-index-url https://download.pytorch.org/whl/cpu

COPY . .

# Fetch model weights at build time; --overwrite skips the interactive
# "file exists" prompt, which would hang a non-interactive build.
RUN python download_files.py --overwrite

# Bind to all interfaces so the GUI is reachable from outside the container.
CMD [ "python", "./predict-gui.py", "--server_name", "0.0.0.0" ]
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 center4ml
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -9,4 +9,224 @@ license: mit
|
|
| 9 |
short_description: ShimNet Spectra Correction
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
short_description: ShimNet Spectra Correction
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# ShimNet
|
| 13 |
+
ShimNet is a data-driven AI solution to improve high-resolution nuclear magnetic resonance (NMR) spectra
|
| 14 |
+
distorted by the inhomogeneous magnetic field (less than optimal shimming). To use it, the experimental training data has to be collected (see **Data collection** below).
|
| 15 |
+
Example data can also be downloaded (see below).
|
| 16 |
+
|
| 17 |
+
Paper: [ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef866)
|
| 18 |
+
|
| 19 |
+
## Installation
|
| 20 |
+
|
| 21 |
+
Python 3.9+ (3.10+ for GUI)
|
| 22 |
+
|
| 23 |
+
GPU version (for training and inference)
|
| 24 |
+
```
|
| 25 |
+
pip install -r requirements-gpu.txt
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
CPU version (for inference, not recommended for training)
|
| 29 |
+
```
|
| 30 |
+
pip install -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Usage
|
| 34 |
+
To correct spectra presented in the paper:
|
| 35 |
+
1. download weights (model parameters):
|
| 36 |
+
```
|
| 37 |
+
python download_files.py
|
| 38 |
+
```
|
| 39 |
+
or directly from [Google Drive 700MHz](https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke) and [Google Drive 600MHz](https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N) and place it in `weights` directory
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
2. Run correction (e.g. `Azarone_20ul_700MHz.csv`):
|
| 43 |
+
```
|
| 44 |
+
python predict.py sample_data/Azarone_20ul_700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
|
| 45 |
+
```
|
| 46 |
+
The output will be `output/Azarone_20ul_700MHz_processed.csv` file
|
| 47 |
+
|
| 48 |
+
Multiple files may be processed using "*" syntax:
|
| 49 |
+
```
|
| 50 |
+
python predict.py sample_data/*700MHz.csv -o output --config configs/shimnet_700.yaml --weights weights/shimnet_700MHz.pt
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
For 600 MHz data use `--config configs/shimnet_600.yaml` and `--weights weights/shimnet_600MHz.pt`, e.g.:
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
python predict.py sample_data/CresolRed_after_styrene_600MHz.csv -o output --config configs/shimnet_600.yaml --weights weights/shimnet_600MHz.pt
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### input format
|
| 61 |
+
|
| 62 |
+
The spectrum file for reconstruction should contain two columns separated by a space, with no end-of-line character after the last line of the file (example below):
|
| 63 |
+
```csv
|
| 64 |
+
-1.97134 0.0167137
|
| 65 |
+
-1.97085 -0.00778748
|
| 66 |
+
-1.97036 -0.0109595
|
| 67 |
+
-1.96988 0.00825978
|
| 68 |
+
-1.96939 0.0133886
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## Train on your data
|
| 72 |
+
|
| 73 |
+
For the model to function properly, it should be trained on calibration data from the spectrometer used for the measurements. To train a model on data from your spectrometer, please follow the instructions below.
|
| 74 |
+
|
| 75 |
+
### Training data collection
|
| 76 |
+
|
| 77 |
+
Below we describe the training data collection for Agilent/Varian spectrometers. For machines of other vendors similar procedure can be implemented.
|
| 78 |
+
To collect ShimNet training data use Python script (sweep_shims_lineshape_Z1Z2.py) from the calibration_loop folder to drive the spectrometer:
|
| 79 |
+
1. Install TReNDS package ( trends.spektrino.com )
|
| 80 |
+
2. Open VnmrJ and type: 'listenon'
|
| 81 |
+
3. Put the lineshape sample (1% CHCl3 in deuterated acetone), set standard PROTON parameters, and set nt=1 (do not modify sw and at!)
|
| 82 |
+
4. Shim the sample and collect the data. Save the optimally shimmed dataset
|
| 83 |
+
5. Edit the sweep_shims_lineshape_Z1Z2.py script
|
| 84 |
+
6. Put optimum z1 and z2 shim values as optiz1 and optiz2 below
|
| 85 |
+
7. Define the calibration range as range_z1 and range_z2 (default is ok)
|
| 86 |
+
8. Start the python script:
|
| 87 |
+
```
|
| 88 |
+
python3 ./sweep_shims_lineshape_Z1Z2.py
|
| 89 |
+
```
|
| 90 |
+
The spectrometer will start collecting spectra
|
| 91 |
+
|
| 92 |
+
### SCRF extraction
|
| 93 |
+
Shim Coil Response Functions (SCRF) should be extracted from the spectra with `extract_scrf_from_fids.py` script.
|
| 94 |
+
```
|
| 95 |
+
python extract_scrf_from_fids.py
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
The script uses hardcoded paths to the NMR signals (fid-s) in Agilent/Varian format: a directory with optimal measurement (`opti_fid_path` available) and a directory with calibration loop measurements (`data_dir`):
|
| 99 |
+
```python
|
| 100 |
+
# input
|
| 101 |
+
data_dir = "../../sample_run/loop"
|
| 102 |
+
opti_fid_path = "../../sample_run/opti.fid"
|
| 103 |
+
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
The output files are also hardcoded:
|
| 107 |
+
```python
|
| 108 |
+
# output
|
| 109 |
+
spectra_file = "../../sample_run/total.npy"
|
| 110 |
+
spectra_file_names = "../../sample_run/total.csv"
|
| 111 |
+
opi_spectrum_file = "../../sample_run/opti.npy"
|
| 112 |
+
responses_file = "../../sample_run/scrf_61.pt"
|
| 113 |
+
```
|
| 114 |
+
where only the `responses_file` is used in ShimNet training.
|
| 115 |
+
|
| 116 |
+
If the measurements are stored in a format other than Varian, you may need to change this line:
|
| 117 |
+
```python
|
| 118 |
+
dic, data = ng.varian.read(varian_fid_path)
|
| 119 |
+
```
|
| 120 |
+
(see nmrglue package documentation for details)
|
| 121 |
+
|
| 122 |
+
### Training
|
| 123 |
+
|
| 124 |
+
1. Download multiplets database:
|
| 125 |
+
```
|
| 126 |
+
python download_files.py --multiplets
|
| 127 |
+
```
|
| 128 |
+
2. Configure run:
|
| 129 |
+
- create a run directory, e.g. `runs/my_lab_spectrometer_2025`
|
| 130 |
+
- create a configuration file:
|
| 131 |
+
1. copy `configs/shimnet_template.yaml` to the run directory and rename it to `config.yaml`
|
| 132 |
+
```bash
|
| 133 |
+
cp configs/shimnet_template.yaml runs/my_lab_spectrometer_2025/config.yaml
|
| 134 |
+
```
|
| 135 |
+
2. edit the SCRF in path in the config file:
|
| 136 |
+
```yaml
|
| 137 |
+
response_functions_files:
|
| 138 |
+
- path/to/srcf_file
|
| 139 |
+
```
|
| 140 |
+
e.g.
|
| 141 |
+
```yaml
|
| 142 |
+
response_functions_files:
|
| 143 |
+
- ../../sample_run/scrf_61.pt
|
| 144 |
+
```
|
| 145 |
+
3. adjust spectrometer frequency step `frq_step` to match your data (spectrometer range in Hz divided by number of points in spectrum):
|
| 146 |
+
```yaml
|
| 147 |
+
frq_step: 0.34059797
|
| 148 |
+
```
|
| 149 |
+
4. adjust spectrometer frequency in the metadata
|
| 150 |
+
```yaml
|
| 151 |
+
metadata: # additional metadata, not used in the training process
|
| 152 |
+
spectrometer_frequency: 700.0 # MHz
|
| 153 |
+
```
|
| 154 |
+
3. Run training:
|
| 155 |
+
```
|
| 156 |
+
python train.py runs/my_lab_spectrometer_2025
|
| 157 |
+
```
|
| 158 |
+
Training results will appear in `runs/my_lab_spectrometer_2025` directory.
|
| 159 |
+
Model parameters are stored in `runs/my_lab_spectrometer_2025/model.pt` file
|
| 160 |
+
4. Use trained model:
|
| 161 |
+
|
| 162 |
+
use `--config runs/my_lab_spectrometer_2025/config.yaml` and `--weights runs/my_lab_spectrometer_2025/model.pt` flags, e.g.
|
| 163 |
+
```
|
| 164 |
+
python predict.py my_sample1.csv -o my_output --config runs/my_lab_spectrometer_2025/config.yaml --weights runs/my_lab_spectrometer_2025/model.pt
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
## Repeat training on our data
|
| 168 |
+
|
| 169 |
+
If you want to train the network using the calibration data from our paper, follow the procedure below.
|
| 170 |
+
|
| 171 |
+
1. Download multiplets database and our SCRF files:
|
| 172 |
+
```
|
| 173 |
+
python download_files.py --multiplets --SCRF --no-weights
|
| 174 |
+
```
|
| 175 |
+
or directly download from Google Drive and store in `data/` directory: [Response Functions 600MHz](https://drive.google.com/file/d/1J-DsPtaITXU3TFrbxaZPH800U1uIiwje/view?usp=sharing), [Response Functions 700MHz](https://drive.google.com/file/d/113al7A__yYALx_2hkESuzFIDU3feVtNY/view?usp=sharing), [Multiplets data](https://drive.google.com/file/d/1QGvV-Au50ZxaP1vFsmR_auI299Dw-Wrt/view?usp=sharing)
|
| 176 |
+
|
| 177 |
+
2. Configure run
|
| 178 |
+
- For 600MHz spectrometer:
|
| 179 |
+
```bash
|
| 180 |
+
mkdir -p runs/repeat_paper_training_600MHz
|
| 181 |
+
cp configs/shimnet_600.yaml runs/repeat_paper_training_600MHz/config.yaml
|
| 182 |
+
```
|
| 183 |
+
- For 700 MHz spectrometer:
|
| 184 |
+
```bash
|
| 185 |
+
mkdir -p runs/repeat_paper_training_700MHz
|
| 186 |
+
cp configs/shimnet_700.yaml runs/repeat_paper_training_700MHz/config.yaml
|
| 187 |
+
```
|
| 188 |
+
3. Run training:
|
| 189 |
+
```
|
| 190 |
+
python train.py runs/repeat_paper_training_600MHz
|
| 191 |
+
```
|
| 192 |
+
or
|
| 193 |
+
```
|
| 194 |
+
python train.py runs/repeat_paper_training_700MHz
|
| 195 |
+
```
|
| 196 |
+
Training results will appear in `runs/repeat_paper_training_600MHz` or `runs/repeat_paper_training_700MHz` directory.
|
| 197 |
+
|
| 198 |
+
## GUI
|
| 199 |
+
|
| 200 |
+
### Installation
|
| 201 |
+
|
| 202 |
+
To use the ShimNet GUI, ensure you have Python 3.10 installed (not tested with Python 3.11+). After installing the ShimNet requirements (CPU/GPU), install the additional dependencies for the GUI:
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
pip install -r requirements-gui.txt
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
### Launching the GUI
|
| 209 |
+
|
| 210 |
+
The ShimNet GUI is built using Gradio. To start the application, run:
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
python predict-gui.py
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
Once the application starts, open your browser and navigate to:
|
| 217 |
+
|
| 218 |
+
```
|
| 219 |
+
http://127.0.0.1:7860
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
to access the GUI locally.
|
| 223 |
+
|
| 224 |
+
### Sharing the GUI
|
| 225 |
+
|
| 226 |
+
To make the GUI accessible over the internet, use the `--share` flag:
|
| 227 |
+
|
| 228 |
+
```bash
|
| 229 |
+
python predict-gui.py --share
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
A public web address will be displayed in the terminal, which you can use to access the GUI remotely or share with others.
|
calibration_loop/maclib/shimator_loop
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
" macro shimnet loop "
" arguments: optimal z1, optimal z2, step in z1 loop, step in z2 loop, half number of steps in z1, half number of steps in z2"
" Sweeps z1/z2 shim values on a grid around the optimum and saves one fid per grid point. "

" directory on the spectrometer host where each acquired fid is saved "
$svfdir='/home/nmrbox/kkazimierczuk/shimnet/data/'

$opti_z1=$1
$opti_z2=$2
$step_z1=$3
$step_z2=$4
$steps_z1=$5
$steps_z2=$6
$nn=''
$bb=''
$msg=''
" estimate the total acquisition time in hours: grid size * (d1+at) per scan * nt scans "
format((($steps_z1*2+1)*($steps_z2*2+1)*(d1+at)*nt)/3600,5,1):$msg
$msg='Time: '+$msg+' h'
banner($msg)
$j=0
repeat
$i=0
" current z2 value: sweep from (optimum - steps*step) upwards "
$z2=$opti_z2-$steps_z2*$step_z2+$j*$step_z2
format($z2,5,1):$bb
repeat
$z1=$opti_z1-$steps_z1*$step_z1+$i*$step_z1
" NOTE(review): the quoted 'su' and 'go' below are MAGICAL comments, i.e. the "
" setup/acquire commands appear commented out — presumably acquisition is "
" driven externally (e.g. by the Python sweep script); confirm before relying on this macro alone "
"su"
"go"
format($z1,5,1):$nn
" file name encodes the z1/z2 shim values of this grid point "
$filepath=$svfdir + 'z1_'+$nn+'z2_'+$bb+'.fid'
svf($filepath,'force')
$i=$i+1
until $z1>$opti_z1+$steps_z1*$step_z1-1
$j=$j+1
until $z2>$opti_z2+$steps_z2*$step_z2-1
calibration_loop/sweep_shims_lineshape_Z1Z2.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from TReND.send2vnmr import *
import shutil
import time
import os
import numpy as np
from subprocess import call

# This is the script to collect ShimNet training data on Agilent spectrometers. Execution:
# 1. Open VnmrJ and type: 'listenon'
# 2. Put the lineshape sample, set standard PROTON parameters and set one scan (do not modify sw and at!)
# 3. Shim the sample and collect the data. Save the optimally shimmed dataset
# 4. Put optimum z1 and z2 shim values as optiz1 and optiz2 below
# 5. Define the calibration range as range_z1 and range_z2 (default is ok)
# 6. Start the python script:
#    python3 ./sweep_shims_lineshape_Z1Z2.py
# the spectrometer will start collecting spectra

# Base path (prefix) on the spectrometer host for the saved experiments.
Classic_path = '/home/nmr700/shimnet6/lshp'
optiz1= 8868 #put optimum shim values here
optiz2=-297

range_z1=100 #put optimum shim ranges here
range_z2=100 #put optimum shim ranges here

# Sweep grids: optimum +/- range, stepped by 2 shim units (upper bound inclusive).
z1_sweep=np.arange(optiz1-range_z1,optiz1+range_z1+1,2.0)
z2_sweep=np.arange(optiz2-range_z2,optiz2+range_z2+1,2.0)

# For every (z1, z2) pair: set both shim currents, acquire one spectrum,
# and save the experiment under a name that encodes the shim values.
for i in range(1,np.shape(z1_sweep)[0]+1, 1):
    for j in range(1,np.shape(z2_sweep)[0]+1, 1):
        wait_until_idle()

        # sethw applies the shim coil current on the spectrometer (via TReNDS).
        Run_macro("sethw('z1',"+str(z1_sweep[i-1])+ ")")
        Run_macro("sethw('z2',"+str(z2_sweep[j-1])+ ")")

        go_if_idle()
        wait_until_idle()
        time.sleep(0.5)  # short settle delay before saving

        Save_experiment(Classic_path + '_z1_'+ str(int(z1_sweep[i-1])) + '_z2_'+ str(int(z2_sweep[j-1]))+ '.fid')
configs/shimnet_600.yaml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model:
  name: ShimNetWithSCRF
  kwargs:
    # NOTE(review): the key spellings "rensponse_length" / "resnponse_head_dims"
    # look misspelled but are presumably what the model constructor expects —
    # confirm against src/models.py before renaming.
    rensponse_length: 81
    resnponse_head_dims:
    - 128
# Training curriculum: stages run in order, each with its own batch size,
# learning rate and iteration budget.
training:
- batch_size: 64
  learning_rate: 0.001
  max_iters: 1600000
- batch_size: 512
  learning_rate: 0.001
  max_iters: 25600000
- batch_size: 512
  learning_rate: 0.0005
  max_iters: 12800000
losses_weights:
  clean: 1.0
  noised: 1.0
  response: 1.0
data:
  response_functions_files:
  # Paste path to your SCRF file here
  # - Can be absolute path
  # - Can be relative to repository root
  - data/scrf_81_600MHz.pt
  atom_groups_data_file: data/multiplets_10000_parsed.txt
  response_function_stretch_min: 1.0 # relative
  response_function_stretch_max: 1.0 # relative
  response_function_noise: 0.0 # arbitrary height units
  multiplicity_j1_min: 0.0 # Hz
  multiplicity_j1_max: 15 # Hz
  multiplicity_j2_min: 0.0 # Hz
  multiplicity_j2_max: 15 # Hz
  number_of_signals_min: 2
  number_of_signals_max: 5
  thf_min: 0.5 # arbitrary height units
  thf_max: 2 # arbitrary height units
  relative_height_min: 0.5 # relative
  relative_height_max: 4 # relative
  frq_step: 0.30048 # Hz per point (spectral width / number of points)
logging:
  step: 1000000
  num_plots: 32
metadata: # additional metadata, not used in the training process
  spectrometer_frequency: 600.0 # MHz
configs/shimnet_700.yaml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model:
  name: ShimNetWithSCRF
  kwargs:
    # NOTE(review): the key spellings "rensponse_length" / "resnponse_head_dims"
    # look misspelled but are presumably what the model constructor expects —
    # confirm against src/models.py before renaming.
    rensponse_length: 61
    resnponse_head_dims:
    - 128
# Training curriculum: stages run in order, each with its own batch size,
# learning rate and iteration budget.
training:
- batch_size: 64
  learning_rate: 0.001
  max_iters: 1600000
- batch_size: 512
  learning_rate: 0.001
  max_iters: 6400000
- batch_size: 512
  learning_rate: 0.0005
  max_iters: 12800000
- batch_size: 800
  learning_rate: 0.0005
  max_iters: 12800000
losses_weights:
  clean: 1.0
  noised: 1.0
  response: 1.0
data:
  response_functions_files:
  - data/scrf_61_700MHz.pt
  atom_groups_data_file: data/multiplets_10000_parsed.txt
  response_function_stretch_min: 1.0 # relative
  response_function_stretch_max: 1.0 # relative
  response_function_noise: 0.0 # arbitrary height units
  multiplicity_j1_min: 0.0 # Hz
  multiplicity_j1_max: 15 # Hz
  multiplicity_j2_min: 0.0 # Hz
  multiplicity_j2_max: 15 # Hz
  number_of_signals_min: 2
  number_of_signals_max: 5
  thf_min: 0.5 # arbitrary height units
  thf_max: 2 # arbitrary height units
  relative_height_min: 0.5 # relative
  relative_height_max: 4 # relative
  frq_step: 0.34059797 # Hz per point (spectral width / number of points)
logging:
  step: 1000000
  num_plots: 32
metadata: # additional metadata, not used in the training process
  spectrometer_frequency: 700.0 # MHz
configs/shimnet_template.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model:
  name: ShimNetWithSCRF
  kwargs:
    # NOTE(review): the key spellings "rensponse_length" / "resnponse_head_dims"
    # look misspelled but are presumably what the model constructor expects —
    # confirm against src/models.py before renaming.
    rensponse_length: 61
    resnponse_head_dims:
    - 128
# Training curriculum: stages run in order, each with its own batch size,
# learning rate and iteration budget.
training:
- batch_size: 64
  learning_rate: 0.001
  max_iters: 1600000
- batch_size: 512
  learning_rate: 0.001
  max_iters: 6400000
- batch_size: 512
  learning_rate: 0.0005
  max_iters: 12800000
- batch_size: 800
  learning_rate: 0.0005
  max_iters: 12800000
losses_weights:
  clean: 1.0
  noised: 1.0
  response: 1.0
data:
  response_functions_files:
  # Specify the path to your SCRF file/files here.
  # - It can be an absolute path.
  # - It can be relative to the repository root.
  # Multiple files can be listed in the following rows (YAML list format).
  - path/to/srcf_file
  atom_groups_data_file: data/multiplets_10000_parsed.txt
  response_function_stretch_min: 1.0 # relative
  response_function_stretch_max: 1.0 # relative
  response_function_noise: 0.0 # arbitrary height units
  multiplicity_j1_min: 0.0 # Hz
  multiplicity_j1_max: 15 # Hz
  multiplicity_j2_min: 0.0 # Hz
  multiplicity_j2_max: 15 # Hz
  number_of_signals_min: 2
  number_of_signals_max: 5
  thf_min: 0.5 # arbitrary height units
  thf_max: 2 # arbitrary height units
  relative_height_min: 0.5 # relative
  relative_height_max: 4 # relative
  frq_step: 0.34059797 # Hz per point (spectral width / number of points)
logging:
  step: 1000000
  num_plots: 32
metadata: # additional metadata, not used in the training process
  spectrometer_frequency: 700.0 # MHz
download_files.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import urllib.request
from pathlib import Path
import argparse

# Registry of downloadable artifacts, grouped by the CLI flag that selects them.
# Each entry maps a direct-download URL to a repository-relative destination.
# Note: key renamed from the misspelled "mupltiplets" to "multiplets"; it is
# internal to this script and only read by the __main__ section below.
ALL_FILES_TO_DOWNLOAD = {
    "weights": [
        {
            "url": "https://drive.google.com/uc?export=download&id=17fTNWl7YW6mPbbZWga0EfdoF_6S8fCke",
            "destination": "weights/shimnet_700MHz.pt",
        },
        {
            "url": "https://drive.google.com/uc?export=download&id=1_VxOpFGJcFsOa5DHOW2GJbP8RvHCmC1N",
            "destination": "weights/shimnet_600MHz.pt",
        },
    ],
    "SCRF": [
        {
            "url": "https://drive.google.com/uc?export=download&id=113al7A__yYALx_2hkESuzFIDU3feVtNY",
            "destination": "data/scrf_61_700MHz.pt",
        },
        {
            "url": "https://drive.google.com/uc?export=download&id=1J-DsPtaITXU3TFrbxaZPH800U1uIiwje",
            "destination": "data/scrf_81_600MHz.pt",
        },
    ],
    "multiplets": [
        {
            "url": "https://drive.google.com/uc?export=download&id=1QGvV-Au50ZxaP1vFsmR_auI299Dw-Wrt",
            "destination": "data/multiplets_10000_parsed.txt",
        },
    ],
    "development": [],
}


def parse_args():
    """Parse command-line options selecting which file groups to download.

    Returns:
        argparse.Namespace with boolean attributes ``overwrite``, ``weights``
        (default True, disable with ``--no-weights``), ``SCRF``,
        ``multiplets``, ``development`` and ``all``.
    """
    parser = argparse.ArgumentParser(
        description='Download files: weights (default), SCRF (optional), multiplet data (optional)',
    )
    parser.add_argument('--overwrite', action='store_true', help='Overwrite existing files')
    parser.add_argument(
        '--weights',
        action='store_true',
        default=True,
        help='Download weights file (default behavior). Use --no-weights to opt out.',
    )
    parser.add_argument(
        '--no-weights',
        action='store_false',
        dest='weights',
        help='Do not download weights file.',
    )
    parser.add_argument('--SCRF', action='store_true', help='Download SCRF files - Shim Coil Response Functions')
    parser.add_argument('--multiplets', action='store_true', help='Download multiplets data file')
    parser.add_argument('--development', action='store_true', help='Download development weights file')

    parser.add_argument('--all', action='store_true', help='Download all available files')

    args = parser.parse_args()
    # --all is a shorthand that switches every individual group on.
    if args.all:
        args.weights = True
        args.SCRF = True
        args.multiplets = True
        args.development = True

    return args


def download_file(url, target, overwrite=False):
    """Download ``url`` to ``target``, creating parent directories as needed.

    If ``target`` exists and ``overwrite`` is False, the user is asked
    interactively before the file is replaced. Download failures are
    reported but deliberately not re-raised, so one failed file does not
    abort the remaining downloads.
    """
    target = Path(target)
    if target.exists() and not overwrite:
        response = input(f"File {target} already exists. Overwrite? (y/n): ")
        if response.lower() != 'y':
            print(f"Download of {target} cancelled")
            return
    target.parent.mkdir(parents=True, exist_ok=True)
    try:
        urllib.request.urlretrieve(url, target)
        print(f"Downloaded {target}")
    except Exception as e:
        print(f"Failed to download file from {url}:\n {e}")


if __name__ == "__main__":
    args = parse_args()

    # Resolve destinations relative to this script so the command works from
    # any current working directory.
    main_dir = Path(__file__).parent
    for group in ("weights", "SCRF", "multiplets", "development"):
        if getattr(args, group):
            for file_data in ALL_FILES_TO_DOWNLOAD[group]:
                download_file(file_data["url"], main_dir / file_data["destination"], overwrite=args.overwrite)
extract_scrf_from_fids.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
import nmrglue as ng
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
# input
|
| 12 |
+
data_dir = "../../../dane/600Hz_20241227/loop"
|
| 13 |
+
opti_fid_path = "../../../dane/600Hz_20241227/opti"
|
| 14 |
+
|
| 15 |
+
# output
|
| 16 |
+
spectra_file = "data/600Hz_20241227-Lnative/total.npy"
|
| 17 |
+
spectra_file_names = "data/600Hz_20241227-Lnative/total.csv"
|
| 18 |
+
opi_spectrum_file = "data/600Hz_20241227-Lnative/opti.npy"
|
| 19 |
+
responses_file = "data/600Hz_20241227-Lnative/scrf_61.pt"
|
| 20 |
+
losses_file = "data/600Hz_20241227-Lnative/losses_scrf_61.pt"
|
| 21 |
+
|
| 22 |
+
# create directories for output files
|
| 23 |
+
for file_path in [spectra_file, spectra_file_names, opi_spectrum_file, responses_file]:
|
| 24 |
+
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
# settings: fid to spectrum
|
| 27 |
+
ph0_correction =-190.49
|
| 28 |
+
ph1_correction = 0
|
| 29 |
+
autophase_fn="acme"
|
| 30 |
+
target_length = None
|
| 31 |
+
|
| 32 |
+
# settings: SCRF extraction
|
| 33 |
+
calibration_peak_center = "auto"#12277 # the center of the calibration peak (peak of)
|
| 34 |
+
calibration_window_halfwidth = 128 # the half width of the calibration window
|
| 35 |
+
steps = 6000
|
| 36 |
+
kernel_size = 61
|
| 37 |
+
kernel_sqrt = False # True to allow only positive values
|
| 38 |
+
|
| 39 |
+
# fid to spectra processing
|
| 40 |
+
|
| 41 |
+
def fid_to_spectrum(varian_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=None, sin_pod=False):
    """Read a Varian FID directory and return a phased complex spectrum.

    Parameters:
        varian_fid_path: path to the Varian .fid directory read by nmrglue.
        ph0_correction: starting zero-order phase passed to autophasing (p0).
        ph1_correction: starting first-order phase passed to autophasing (p1).
        autophase_fn: autophase scoring function name for nmrglue (e.g. "acme").
        target_length: if not None, zero-fill (when longer) or truncate (when
            shorter/equal) the FID to exactly this number of points.
        sin_pod: if True, apply sine-bell apodization before the FFT.

    Returns:
        Complex spectrum array after FFT and automatic phase correction.
    """
    dic, data = ng.varian.read(varian_fid_path)
    # Halve the first FID point — conventional correction so the FFT baseline
    # is not offset by the first sample (assumed; TODO confirm intent).
    data[0] *= 0.5
    if sin_pod:
        data = ng.proc_base.sp(data, end=0.98)

    if target_length is not None:
        # zf() takes the number of points to append, not the final length
        if (pad_length := target_length - len(data)) > 0:
            data = ng.proc_base.zf(data, pad_length)
        else:
            data = data[:target_length]

    spec=ng.proc_base.fft(data)
    # Automatic phasing starting from the supplied p0/p1 initial guesses.
    spec = ng.process.proc_autophase.autops(spec, autophase_fn, p0=ph0_correction, p1=ph1_correction, disp=False)

    return spec
|
| 57 |
+
|
| 58 |
+
# process optimal measurement fid to spectrum
|
| 59 |
+
opti_spectrum_full = fid_to_spectrum(opti_fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)
|
| 60 |
+
|
| 61 |
+
if calibration_peak_center == "auto":
|
| 62 |
+
calibration_peak_center = np.argmax(abs(opti_spectrum_full))
|
| 63 |
+
fitting_range = (calibration_peak_center - calibration_window_halfwidth, calibration_peak_center+calibration_window_halfwidth+1)
|
| 64 |
+
|
| 65 |
+
opti_spectrum = opti_spectrum_full[fitting_range[0]:fitting_range[1]]
|
| 66 |
+
np.save(opi_spectrum_file, opti_spectrum)
|
| 67 |
+
print(f"Optimal spectrum extracted to {opi_spectrum_file}")
|
| 68 |
+
|
| 69 |
+
# process loop fids to spectra
|
| 70 |
+
spec_list=[]
|
| 71 |
+
spec_names=[]
|
| 72 |
+
|
| 73 |
+
print("Extracting spectra from fids...")
|
| 74 |
+
for fid_path in tqdm(list(Path(data_dir).rglob('*.fid'))):
|
| 75 |
+
spec = fid_to_spectrum(fid_path, ph0_correction, ph1_correction, autophase_fn, target_length=target_length)[fitting_range[0]:fitting_range[1]]
|
| 76 |
+
|
| 77 |
+
spec_list.append(spec)
|
| 78 |
+
spec_names.append(fid_path.name)
|
| 79 |
+
|
| 80 |
+
total = np.array(spec_list)
|
| 81 |
+
np.save(spectra_file, total)
|
| 82 |
+
pd.DataFrame(spec_names).to_csv(spectra_file_names, header=False)
|
| 83 |
+
# total = np.load(spectra_file)
|
| 84 |
+
print(f"Spectra extracted to {spectra_file}")
|
| 85 |
+
|
| 86 |
+
# process SCRF extraction
|
| 87 |
+
def fit_kernel(base, target, kernel_size, kernel_sqrt=True, steps=20000, verbose=False):
    """Fit a 1D convolution kernel (SCRF) that maps `base` onto `target`.

    The kernel is optimized with Adam so that conv1d(base, kernel) matches
    `target` in the mean-squared-error sense.

    Parameters:
        base: reference spectrum tensor, shaped for torch.conv1d
            (e.g. (1, L) unbatched or (1, 1, L)).
        target: distorted spectrum with the same shape as the conv output.
        kernel_size: number of points in the fitted kernel.
        kernel_sqrt: if True, optimize the kernel's square root so the
            effective kernel is non-negative; the squared kernel is returned.
        steps: number of Adam optimization steps (must be >= 1).
        verbose: if True, print the loss every 100 steps.

    Returns:
        (kernel, loss): detached kernel of shape (1, 1, kernel_size) and the
        final loss value. Note: the reported loss is the one computed just
        before the last optimizer step.

    Raises:
        ValueError: if steps < 1 (previously this raised an opaque NameError
        because `loss` was never assigned).
    """
    if steps < 1:
        raise ValueError("steps must be >= 1")

    kernel = torch.ones((1, 1, kernel_size), dtype=base.dtype)
    # Initialize to a flat kernel normalized so its effective sum is 1.
    if kernel_sqrt:
        kernel /= torch.sqrt(torch.sum(kernel ** 2))
    else:
        kernel /= kernel_size
    kernel.requires_grad = True

    optimizer = torch.optim.Adam([kernel])

    for epoch in range(steps):
        # When kernel_sqrt is set, square the parameters so the effective
        # kernel can only contain non-negative values.
        effective_kernel = kernel ** 2 if kernel_sqrt else kernel
        spe_est = torch.conv1d(base, effective_kernel, padding='same')
        loss = torch.mean(abs(target - spe_est) ** 2)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if verbose and (epoch + 1) % 100 == 0:
            print(epoch, loss.item())

    if kernel_sqrt:
        return kernel.detach() ** 2, loss.item()
    else:
        return kernel.detach(), loss.item()
|
| 114 |
+
|
| 115 |
+
# Allocate outputs: one fitted kernel per measured spectrum. The extra
# singleton dims presumably match the layout consumed by
# ResponseLibrary.flatten(0, -4) in src/generators.py — verify before changing.
responses = torch.empty(len(total), 1, 1, 1, 1, kernel_size)
losses = torch.empty(len(total))
base = torch.tensor(opti_spectrum.real).unsqueeze(0)
targets = torch.tensor(total.real)

# normalization: scale base and each target to unit area so the fitted
# kernel is independent of overall spectrum intensity
base /= base.sum()
targets /= targets.sum(dim=(-1,), keepdim=True)

print("\nExtracting SCRFs...")
for i, target in tqdm(enumerate(targets), total=len(targets)):
    # Fit one SCRF kernel mapping the well-shimmed base onto this target.
    kernel, loss = fit_kernel(base, target.unsqueeze(0), kernel_size, kernel_sqrt=kernel_sqrt, steps=steps)
    responses[i, 0, 0] = kernel
    losses[i] = loss

torch.save(responses, responses_file)
torch.save(losses, losses_file)
print(f"SCRFs extracted to {responses_file}, losses saved to {losses_file}")
|
predict-gui.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
torch.set_grad_enabled(False)
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
+
|
| 9 |
+
from src.models import ShimNetWithSCRF, Predictor
|
| 10 |
+
from predict import Defaults, resample_input_spectrum, resample_output_spectrum, initialize_predictor
|
| 11 |
+
|
| 12 |
+
# silent deprecation warnings
|
| 13 |
+
import warnings
|
| 14 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
|
| 18 |
+
# Add argument parsing for server_name
|
| 19 |
+
parser = argparse.ArgumentParser(description="Launch ShimNet Spectra Correction App")
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--server_name",
|
| 22 |
+
type=str,
|
| 23 |
+
default="127.0.0.1",
|
| 24 |
+
help="Server name to bind the app (default: 127.0.0.1). Use 0.0.0.0 for external access."
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--share",
|
| 28 |
+
action="store_true",
|
| 29 |
+
help="If set, generates a public link to share the app."
|
| 30 |
+
)
|
| 31 |
+
args = parser.parse_args()
|
| 32 |
+
|
| 33 |
+
def process_file(input_file, config_file, weights_file, input_spectrometer_frequency=None,reference_spectrum=None):
    """Run ShimNet correction on one spectrum file and build a Plotly figure.

    Parameters:
        input_file: two-column text file (ppm, intensity).
        config_file: OmegaConf .yaml with data.frq_step, metadata.spectrometer_frequency
            and model.kwargs.
        weights_file: model weights (.pt) for the predictor.
        input_spectrometer_frequency: acquisition frequency in MHz; 0 or None
            means "same as the model's training frequency" (no rescaling).
        reference_spectrum: optional two-column text file plotted for comparison
            with a zoom slider.

    Returns:
        (fig, output_file): the Plotly figure and the path of the saved
        corrected spectrum (written to the current working directory).
    """
    # The GUI passes 0 from an empty number box — treat it as "not given".
    if input_spectrometer_frequency == 0:
        input_spectrometer_frequency = None
    # Load configuration and initialize predictor
    config = OmegaConf.load(config_file)
    model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
    predictor = initialize_predictor(config, weights_file)

    # Load input data
    input_data = np.loadtxt(input_file)
    input_freqs_input_ppm, input_spectrum = input_data[:, 0], input_data[:, 1]

    # Convert input frequencies to model's frequency (rescale ppm axis when the
    # acquisition frequency differs from the model's training frequency)
    if input_spectrometer_frequency is not None:
        input_freqs_model_ppm = input_freqs_input_ppm * input_spectrometer_frequency / config.metadata.spectrometer_frequency
    else:
        input_freqs_model_ppm = input_freqs_input_ppm

    # Resample input spectrum onto the model's uniform ppm grid
    spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)
    freqs, spectrum = spectrum

    # Scale and process spectrum: normalize peak height to Defaults.SCALE
    # before inference, undo the scaling afterwards
    spectrum_tensor = torch.tensor(spectrum).float()
    scaling_factor = Defaults.SCALE / spectrum_tensor.max()
    spectrum_tensor *= scaling_factor
    prediction = predictor(spectrum_tensor).numpy()
    prediction /= scaling_factor

    # Resample output spectrum back onto the caller's frequency axis
    output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)

    # Prepare output data for download
    output_data = np.column_stack((input_freqs_input_ppm, output_prediction))
    output_file = f"{Path(input_file).stem}_processed{Path(input_file).suffix}"
    np.savetxt(output_file, output_data)

    # Create Plotly figure
    fig = go.Figure()

    # Add Input Spectrum and Corrected Spectrum (always visible)
    normalization_value = input_spectrum.max()
    fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=input_spectrum/normalization_value, mode='lines', name='Input Spectrum', visible=True, line=dict(color='#EF553B'))) # red
    fig.add_trace(go.Scatter(x=input_freqs_input_ppm, y=output_prediction/normalization_value, mode='lines', name='Corrected Spectrum', visible=True, line=dict(color='#00cc96'))) # green

    if reference_spectrum is not None:
        reference_spectrum_freqs, reference_spectrum_intensity = np.loadtxt(reference_spectrum).T
        reference_spectrum_intensity /= reference_spectrum_intensity.max()
        n_zooms = 50
        # 2*n_zooms+1 log-spaced zoom factors, symmetric around 1.0
        zooms = np.geomspace(0.01, 100, 2 * n_zooms + 1)

        # Add Reference Data traces (initially invisible)
        for zoom in zooms:
            fig.add_trace(
                go.Scatter(
                    x=reference_spectrum_freqs,
                    y=reference_spectrum_intensity * zoom,
                    mode='lines',
                    name=f'Reference Data (Zoom: {zoom:.2f})',
                    visible=False,
                    line=dict(color='#636efa')
                )
            )
        # Make the middle zoom level (factor 1.0) visible by default;
        # index offset +2 skips the two always-visible traces
        fig.data[2 * n_zooms // 2 + 2].visible = True

        # NOTE(review): the slider section below must live inside this branch —
        # it references n_zooms, which is undefined without a reference spectrum.
        # Create and add slider
        steps = []
        for i in range(2, len(fig.data)): # Start from the reference data traces
            step = dict(
                method="update",
                args=[{"visible": [True, True] + [False] * (len(fig.data) - 2)}], # Keep first two traces visible
            )
            step["args"][0]["visible"][i] = True # Toggle i'th reference trace to "visible"
            steps.append(step)

        sliders = [dict(
            active=n_zooms,
            currentvalue={"prefix": "Reference zoom: "},
            pad={"t": 50},
            steps=steps
        )]

        fig.update_layout(
            sliders=sliders
        )

    fig.update_layout(
        title="Spectrum Visualization",
        xaxis_title="Frequency (ppm)",
        yaxis_title="Intensity"
    )

    return fig, output_file
|
| 126 |
+
|
| 127 |
+
# Gradio app
|
| 128 |
+
with gr.Blocks() as app:
|
| 129 |
+
gr.Markdown("# ShimNet Spectra Correction")
|
| 130 |
+
gr.Markdown("[ShimNet: A neural network for post-acquisition improvement of NMR spectra distorted by magnetic-field inhomogeneity](https://chemrxiv.org/engage/chemrxiv/article-details/67ef86686dde43c90860d315)")
|
| 131 |
+
gr.Markdown("Upload your input file, configuration, and weights to process the NMR spectrum.")
|
| 132 |
+
|
| 133 |
+
with gr.Row():
|
| 134 |
+
with gr.Column():
|
| 135 |
+
model_selection = gr.Radio(
|
| 136 |
+
label="Select Model",
|
| 137 |
+
choices=["600 MHz", "700 MHz", "Custom"],
|
| 138 |
+
value="600 MHz"
|
| 139 |
+
)
|
| 140 |
+
config_file = gr.File(label="Custom Config File (.yaml)", visible=False, height=120)
|
| 141 |
+
weights_file = gr.File(label="Custom Weights File (.pt)", visible=False, height=120)
|
| 142 |
+
|
| 143 |
+
with gr.Column():
|
| 144 |
+
input_file = gr.File(label="Input File (.txt | .csv)", height=120)
|
| 145 |
+
input_spectrometer_frequency = gr.Number(label="Input Spectrometer Frequency (MHz) (0 or empty if the same as in the loaded model)", value=None)
|
| 146 |
+
gr.Markdown("Upload reference spectrum files (optional). Reference spectrum will be plotted for comparison.")
|
| 147 |
+
reference_spectrum_file = gr.File(label="Reference Spectra File (.txt | .csv)", height=120)
|
| 148 |
+
|
| 149 |
+
process_button = gr.Button("Process File")
|
| 150 |
+
plot_output = gr.Plot(label="Spectrum Visualization")
|
| 151 |
+
download_button = gr.File(label="Download Processed File", interactive=False, height=120)
|
| 152 |
+
|
| 153 |
+
# Update visibility of config and weights fields based on model selection
|
| 154 |
+
def update_visibility(selected_model):
|
| 155 |
+
if selected_model == "Custom":
|
| 156 |
+
return gr.update(visible=True), gr.update(visible=True)
|
| 157 |
+
else:
|
| 158 |
+
return gr.update(visible=False), gr.update(visible=False)
|
| 159 |
+
|
| 160 |
+
model_selection.change(
|
| 161 |
+
update_visibility,
|
| 162 |
+
inputs=[model_selection],
|
| 163 |
+
outputs=[config_file, weights_file]
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# Process button click logic
|
| 167 |
+
def process_file_with_model(input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file):
|
| 168 |
+
if model_selection == "600 MHz":
|
| 169 |
+
config_file = "configs/shimnet_600.yaml"
|
| 170 |
+
weights_file = "weights/shimnet_600MHz.pt"
|
| 171 |
+
elif model_selection == "700 MHz":
|
| 172 |
+
config_file = "configs/shimnet_700.yaml"
|
| 173 |
+
weights_file = "weights/shimnet_700MHz.pt"
|
| 174 |
+
else:
|
| 175 |
+
config_file = config_file.name
|
| 176 |
+
weights_file = weights_file.name
|
| 177 |
+
|
| 178 |
+
return process_file(input_file.name, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file.name if reference_spectrum_file else None)
|
| 179 |
+
|
| 180 |
+
process_button.click(
|
| 181 |
+
process_file_with_model,
|
| 182 |
+
inputs=[input_file, model_selection, config_file, weights_file, input_spectrometer_frequency, reference_spectrum_file],
|
| 183 |
+
outputs=[plot_output, download_button]
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
app.launch(share=args.share, server_name=args.server_name)
|
| 187 |
+
|
| 188 |
+
# '#636efa',
|
| 189 |
+
# '#EF553B',
|
| 190 |
+
# '#00cc96',
|
| 191 |
+
# '#ab63fa',
|
| 192 |
+
# '#FFA15A',
|
| 193 |
+
# '#19d3f3',
|
| 194 |
+
# '#FF6692',
|
| 195 |
+
# '#B6E880',
|
| 196 |
+
# '#FF97FF',
|
| 197 |
+
# '#FECB52'
|
predict.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
torch.set_grad_enabled(False)
|
| 3 |
+
import numpy as np
|
| 4 |
+
import argparse
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys, os
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
|
| 9 |
+
from src.models import ShimNetWithSCRF, Predictor
|
| 10 |
+
|
| 11 |
+
# silent deprecation warnings
|
| 12 |
+
# https://github.com/pytorch/pytorch/issues/97207#issuecomment-1494781560
|
| 13 |
+
import warnings
|
| 14 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
| 15 |
+
|
| 16 |
+
class Defaults:
    """Module-wide default constants for spectrum preprocessing."""
    # Peak height the input spectrum is scaled to before inference;
    # predictions are divided by the same factor to restore the original scale.
    SCALE = 16.0
|
| 18 |
+
|
| 19 |
+
def parse_args():
    """Parse the command-line arguments of the prediction script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("input_files", nargs="+", help="Input files")
    arg_parser.add_argument("--config", help="config file .yaml")
    arg_parser.add_argument("--weights", help="model weights")
    arg_parser.add_argument("-o", "--output_dir", help="Output directory", default=".")
    arg_parser.add_argument(
        "--input_spectrometer_frequency",
        type=float,
        default=None,
        help="spectrometer frequency in MHz (input sample collection frequency). Empty if the same as in the training data",
    )
    return arg_parser.parse_args()
| 28 |
+
|
| 29 |
+
# functions
|
| 30 |
+
def resample_input_spectrum(input_freqs, input_spectrum, Mhz_per_point):
    """Resample the input spectrum onto the model's uniform frequency grid.

    Builds an evenly spaced axis covering [min, max) of `input_freqs` with the
    model's step (despite the parameter name, the step is in the model's ppm
    units) and linearly interpolates the spectrum onto it.

    Returns:
        (grid, resampled): the uniform axis and the interpolated intensities.
    """
    grid = np.arange(input_freqs.min(), input_freqs.max(), Mhz_per_point)
    resampled = np.interp(grid, input_freqs, input_spectrum)
    return grid, resampled
|
| 35 |
+
|
| 36 |
+
def resample_output_spectrum(input_freqs, freqs, prediction):
    """Interpolate the model-grid prediction back onto the caller's frequency axis."""
    return np.interp(input_freqs, freqs, prediction)
|
| 40 |
+
|
| 41 |
+
def initialize_predictor(config, weights_file):
    """Build a ShimNetWithSCRF network from the config's model kwargs and wrap
    it in a Predictor loaded with the given weights."""
    network = ShimNetWithSCRF(**config.model.kwargs)
    return Predictor(network, weights_file)
|
| 45 |
+
|
| 46 |
+
# run
|
| 47 |
+
if __name__ == "__main__":
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Model grid step in ppm = step in Hz / spectrometer frequency in MHz.
    config = OmegaConf.load(args.config)
    model_ppm_per_point = config.data.frq_step / config.metadata.spectrometer_frequency
    predictor = initialize_predictor(config, args.weights)

    for input_file in args.input_files:
        print(f"processing {input_file} ...")

        # load data: two-column text file (ppm, intensity)
        input_data = np.loadtxt(input_file)
        input_freqs_input_ppm, input_spectrum = input_data[:,0], input_data[:,1]

        # convert input frequencies to model's frequency - correct for zero filling and spectrometer frequency
        if args.input_spectrometer_frequency is not None:
            input_freqs_model_ppm = input_freqs_input_ppm * args.input_spectrometer_frequency / config.metadata.spectrometer_frequency
        else:
            input_freqs_model_ppm = input_freqs_input_ppm

        freqs, spectrum = resample_input_spectrum(input_freqs_model_ppm, input_spectrum, model_ppm_per_point)

        spectrum = torch.tensor(spectrum).float()
        # scale height of the spectrum so its peak equals Defaults.SCALE
        scaling_factor = Defaults.SCALE / spectrum.max()
        spectrum *= scaling_factor

        # correct spectrum
        prediction = predictor(spectrum).numpy()

        # rescale height back to the original intensity scale
        prediction /= scaling_factor

        # resample the output to match the input spectrum
        output_prediction = resample_output_spectrum(input_freqs_model_ppm, freqs, prediction)

        # save result
        output_file = output_dir / f"{Path(input_file).stem}_processed{Path(input_file).suffix}"

        np.savetxt(output_file, np.column_stack((input_freqs_input_ppm, output_prediction)))
        print(f"saved to {output_file}")
|
requirements-cpu.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.4.1+cpu
|
| 2 |
+
torchaudio==2.4.1+cpu
|
| 3 |
+
nmrglue==0.11
|
| 4 |
+
torchdata==0.9.0
|
| 5 |
+
numpy==2.0.2
|
| 6 |
+
matplotlib==3.9.3
|
| 7 |
+
pandas==2.2.3
|
| 8 |
+
tqdm==4.67.1
|
| 9 |
+
hydra-core==1.3.2
|
requirements-gpu.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.4.1
|
| 2 |
+
torchaudio==2.4.1
|
| 3 |
+
nmrglue==0.11
|
| 4 |
+
torchdata==0.9.0
|
| 5 |
+
numpy==2.0.2
|
| 6 |
+
matplotlib==3.9.3
|
| 7 |
+
pandas==2.2.3
|
| 8 |
+
tqdm==4.67.1
|
| 9 |
+
hydra-core==1.3.2
|
requirements-gui.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==5.23.2
|
| 2 |
+
plotly==6.0.1
|
sample_data/2ethylonaphthalene_bestshims_700MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/2ethylonaphthalene_up_1mm_700MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Azarone_20ul_700MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Azarone_X_supressed_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Azarone_Z1Z2Z3Z4_supressed_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Azarone_Z1Z2_supressed_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Azarone_besteshims_supressed_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Azarone_bestshims_700MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/CresolRed_after_styrene_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/CresolRed_after_styrene_cut_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/CresolRed_bestshims_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/CresolRed_cut_bestshims_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Geraniol_bestshims_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/Geraniol_up_1mm_600MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/SodiumButyrate_after_glucose_cut_700MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sample_data/SodiumButyrate_cut_bestshims_700MHz.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/generators.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torchdata
|
| 4 |
+
# from itertools import islice
|
| 5 |
+
|
| 6 |
+
def random_value(min_value, max_value):
    """Draw a single float uniformly from [min_value, max_value)."""
    span = max_value - min_value
    return (min_value + torch.rand(1) * span).item()
|
| 8 |
+
|
| 9 |
+
def random_loguniform(min_value, max_value):
    """Draw a single float log-uniformly from [min_value, max_value)."""
    log_lo = torch.log(torch.tensor(min_value))
    log_hi = torch.log(torch.tensor(max_value))
    return (min_value * torch.exp(torch.rand(1) * (log_hi - log_lo))).item()
|
| 11 |
+
|
| 12 |
+
def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq: torch.Tensor):
    """Render a set of pseudo-Voigt peaks onto the frequency grid.

    Each peak is a mixture of a Lorentzian and a Gaussian sharing the same
    position ("tff_lin"), width ("twf_lin") and total height ("thf_lin");
    "trf_lin" is the Gaussian fraction (0 = pure Lorentzian, 1 = pure Gaussian).

    Returns:
        Tensor of shape (1, len(frq_frq)): the sum of all peak contributions.
    """
    positions = peaks_parameters["tff_lin"][:, None]
    widths = peaks_parameters["twf_lin"][:, None]
    heights = peaks_parameters["thf_lin"][:, None]
    gauss_fractions = peaks_parameters["trf_lin"][:, None]

    offsets_sq = (frq_frq - positions) ** 2

    # Lorentzian contributions
    lorentz = widths ** 2 / (widths ** 2 + offsets_sq) * (heights * (1. - gauss_fractions))
    # Gaussian contributions: width converted to std dev via the sqrt(2 ln 2) factor
    gauss_sigma = widths / torch.tensor(2.).log().mul(2.).sqrt()
    gauss = torch.exp(-offsets_sq / gauss_sigma ** 2 / 2.) * (heights * gauss_fractions)

    # sum all peak contributions into one spectrum row
    return (lorentz + gauss).sum(0, keepdim=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Binomial intensity patterns for first-order multiplets (singlet .. octet).
pascal_triangle = [(1,), (1,1), (1,2,1), (1,3,3,1), (1,4,6,4,1), (1,5,10,10,5,1), (1,6,15,20,15,6,1), (1,7, 21,35,35,21,7,1)]
# Same rows normalized so each multiplet's intensities sum to 1.
normalized_pascal_triangle = [torch.tensor(row) / sum(row) for row in pascal_triangle]

def pascal_multiplicity(multiplicity):
    """Return (shifts, intensities) for a first-order multiplet.

    Shifts are expressed in units of the coupling constant and centered on 0;
    intensities follow the normalized Pascal-triangle row.
    """
    intensities = normalized_pascal_triangle[multiplicity - 1]
    count = len(intensities)
    center = (count - 1) / 2
    shifts = torch.arange(count) - center
    return shifts, intensities
|
| 42 |
+
|
| 43 |
+
def double_multiplicity(multiplicity1, multiplicity2, j1=1, j2=1):
    """Combine two first-order multiplets into a multiplet-of-multiplets.

    Each line of the first pattern (coupling j1) is split by the second
    pattern (coupling j2); shifts add and intensities multiply.

    Returns:
        (shifts, intensities) flattened over all line combinations.
    """
    outer_shifts, outer_intens = pascal_multiplicity(multiplicity1)
    inner_shifts, inner_intens = pascal_multiplicity(multiplicity2)

    combined_shifts = (j1 * outer_shifts[:, None] + j2 * inner_shifts[None, :]).flatten()
    combined_intens = (outer_intens[:, None] * inner_intens[None, :]).flatten()
    return combined_shifts, combined_intens
|
| 50 |
+
|
| 51 |
+
def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_lin, j1, j2):
    """Build the per-peak parameter dict for one multiplet.

    The double-multiplet line pattern is centered on position `tff_lin` and
    scaled by total height `thf_lin`; width and Gaussian fraction are shared
    by every line of the multiplet.
    """
    shifts, intensities = double_multiplicity(multiplicity[0], multiplicity[1], j1, j2)
    count = len(shifts)

    return {
        "tff_lin": tff_lin + shifts,
        "thf_lin": thf_lin * intensities,
        "twf_lin": torch.full((count,), twf_lin),
        "trf_lin": torch.full((count,), trf_lin),
    }
|
| 61 |
+
|
| 62 |
+
def value_to_index(values, table):
    """Map values on the axis described by `table` to fractional indices.

    Assumes `table` is evenly spaced; the result is not rounded, so callers
    get sub-sample positions.
    """
    axis_span = table[-1] - table[0]
    return (values - table[0]) / axis_span * (len(table) - 1)
|
| 66 |
+
|
| 67 |
+
def generate_theoretical_spectrum(
    number_of_signals_min, number_of_signals_max,
    spectrum_width_min, spectrum_width_max,
    relative_width_min, relative_width_max,
    tff_min, tff_max,
    thf_min, thf_max,
    trf_min, trf_max,
    relative_height_min, relative_height_max,
    multiplicity_j1_min, multiplicity_j1_max,
    multiplicity_j2_min, multiplicity_j2_max,
    atom_groups_data,
    frq_frq
):
    """Sample one synthetic NMR spectrum as a random sum of multiplets.

    Random atom groups are drawn from `atom_groups_data` (each entry is
    (relative_intensity, multiplicity1, multiplicity2)); per-spectrum width and
    height scales are drawn log-uniformly, then each signal adds per-signal
    randomness (position, couplings, relative width/height, Gaussian fraction).

    Returns:
        (theoretical_spectrum, peak_parameters_data): spectrum of shape
        (1, len(frq_frq)) and the list of per-signal parameter dicts.
        NOTE(review): if number_of_signals_min were 0 the spectrum could be
        returned as None — callers appear to rely on min >= 1.
    """
    # number of multiplets in this spectrum (inclusive upper bound)
    number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [])
    atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals])
    # per-spectrum scales shared by all signals
    width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max)
    height_spectrum = random_loguniform(thf_min, thf_max)

    peak_parameters_data = []
    theoretical_spectrum = None
    for atom_group_index in atom_group_indices:
        relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
        # per-signal randomness on top of the spectrum-wide scales
        position = random_value(tff_min, tff_max)
        j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
        j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
        width = width_spectrum*random_loguniform(relative_width_min, relative_width_max)
        height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max)
        gaussian_contribution = random_value(trf_min, trf_max)

        peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin= width, trf_lin= gaussian_contribution, j1=j1, j2=j2)
        # record peak positions also as fractional indices into frq_frq
        peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
        peak_parameters_data.append(peaks_parameters)
        spectrum_contribution = calculate_theoretical_spectrum(peaks_parameters, frq_frq)
        if theoretical_spectrum is None:
            theoretical_spectrum = spectrum_contribution
        else:
            theoretical_spectrum += spectrum_contribution
    return theoretical_spectrum, peak_parameters_data
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def theoretical_generator(
    atom_groups_data,
    pixels=2048, frq_step=11160.7142857 / 32768,
    number_of_signals_min=1, number_of_signals_max=8,
    spectrum_width_min=0.2, spectrum_width_max=1,
    relative_width_min=1, relative_width_max=2,
    relative_height_min=1, relative_height_max=1,
    relative_frequency_min=-0.4, relative_frequency_max=0.4,
    thf_min=1/16, thf_max=16,
    trf_min=0, trf_max=1,
    multiplicity_j1_min=0, multiplicity_j1_max=15,
    multiplicity_j2_min=0, multiplicity_j2_max=15,
):
    """Infinite generator of synthetic (spectrum, peak-parameter) pairs.

    Builds a symmetric frequency grid of `pixels` points spaced `frq_step`
    apart (the default looks like spectral-width-in-Hz / acquisition points —
    TODO confirm), restricts peak positions to the central
    [relative_frequency_min, relative_frequency_max] fraction of the grid, and
    then yields fresh random spectra from generate_theoretical_spectrum forever.
    """
    # peak-position limits in frequency units, as a fraction of the full grid
    tff_min = relative_frequency_min * pixels * frq_step
    tff_max = relative_frequency_max * pixels * frq_step
    frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step

    while True:
        yield generate_theoretical_spectrum(
            number_of_signals_min=number_of_signals_min,
            number_of_signals_max=number_of_signals_max,
            spectrum_width_min=spectrum_width_min,
            spectrum_width_max=spectrum_width_max,
            relative_width_min=relative_width_min,
            relative_width_max=relative_width_max,
            relative_height_min=relative_height_min,
            relative_height_max=relative_height_max,
            tff_min=tff_min, tff_max=tff_max,
            thf_min=thf_min, thf_max=thf_max,
            trf_min=trf_min, trf_max=trf_max,
            multiplicity_j1_min=multiplicity_j1_min,
            multiplicity_j1_max=multiplicity_j1_max,
            multiplicity_j2_min=multiplicity_j2_min,
            multiplicity_j2_max=multiplicity_j2_max,
            atom_groups_data=atom_groups_data,
            frq_frq=frq_frq
        )
|
| 144 |
+
|
| 145 |
+
class ResponseLibrary:
    """Read-only pool of SCRF kernels loaded from one or more .pt files.

    Each file holds a tensor of at least 4 dims; all leading dims except the
    last three are flattened so the library indexes individual kernels across
    every file. Note: the constructor parameter is spelled `reponse_files`
    (sic) in the public API.
    """

    def __init__(self, reponse_files, normalize=True):
        loaded = [torch.load(path, map_location='cpu', weights_only=True).flatten(0, -4) for path in reponse_files]
        if normalize:
            # scale each kernel so it sums to 1 along the last axis
            loaded = [tensor / torch.sum(tensor, dim=(-1,), keepdim=True) for tensor in loaded]
        self.data = loaded
        sizes = [len(tensor) for tensor in loaded]
        # start_indices[k] = global index of the first kernel of file k
        self.start_indices = torch.cumsum(torch.tensor([0] + sizes[:-1]), 0)
        self.total_length = sum(sizes)

    def __getitem__(self, idx):
        if idx >= self.total_length:
            raise ValueError(f'index {idx} out of range')
        # locate which file tensor the global index falls into
        source = torch.searchsorted(self.start_indices, idx, right=True) - 1
        return self.data[source][idx - self.start_indices[source]]

    def __len__(self):
        return self.total_length
|
| 162 |
+
|
| 163 |
+
def generator(
    theoretical_generator_params,
    response_function_library,
    response_function_stretch_min=0.5,
    response_function_stretch_max=2.0,
    response_function_noise=0.,
    spectrum_noise_min=0.,
    spectrum_noise_max=1/64,
    include_spectrum_data=False,
    include_peak_mask=False,
    include_response_function=False,
    flip_response_function=False
):
    """Yield training samples: a theoretical spectrum plus versions
    distorted by a randomly drawn, stretched (and optionally noised or
    flipped) response function, with additive Gaussian noise on top.

    Args:
        theoretical_generator_params: kwargs forwarded to ``theoretical_generator``.
        response_function_library: indexable collection of response-function
            tensors (e.g. ``ResponseLibrary``); one item is drawn at random
            per sample.
        response_function_stretch_min/max: log-uniform range for rescaling
            the response function's length.
        response_function_noise: std of Gaussian noise added to the response
            before renormalization.
        spectrum_noise_min/max: uniform range for the additive spectrum
            noise's standard deviation.
        include_spectrum_data / include_peak_mask / include_response_function:
            toggles for optional extra entries in the yielded dict.
        flip_response_function: if True, the response is reversed with
            probability 0.5.

    Yields:
        dict with 'theoretical_spectrum', 'disturbed_spectrum' and
        'noised_spectrum', plus the optional entries selected above.
    """
    for theoretical_spectrum, theoretical_spectrum_data in theoretical_generator(**theoretical_generator_params):
        # Draw one random response function from the library.
        response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
        # Stretch: resample to a random (log-uniform factor) new length,
        # kept odd (2*padding_size + 1) so conv1d with padding=padding_size
        # preserves the spectrum length.
        padding_size = (response_function.shape[-1] - 1)//2
        padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size)
        response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
        response_function /= response_function.sum()  # normalize sum of response function to 1
        # Perturb the response with Gaussian noise, then renormalize again.
        response_function += torch.randn(response_function.shape) * response_function_noise
        response_function /= response_function.sum()  # normalize sum of response function to 1
        if flip_response_function and (torch.rand(1).item() < 0.5):
            response_function = response_function.flip(-1)
        # Distort the clean spectrum with the response (conv1d is
        # cross-correlation, so the kernel is applied un-flipped).
        disturbed_spectrum = torch.nn.functional.conv1d(theoretical_spectrum, response_function, padding=padding_size)
        # Additive Gaussian noise with a randomly drawn standard deviation.
        noised_spectrum = disturbed_spectrum + torch.randn(disturbed_spectrum.shape) * random_value(spectrum_noise_min, spectrum_noise_max)

        out = {
            'theoretical_spectrum': theoretical_spectrum,
            'disturbed_spectrum': disturbed_spectrum,
            'noised_spectrum': noised_spectrum,
        }
        if include_response_function:
            out['response_function'] = response_function
        if include_spectrum_data:
            out["theoretical_spectrum_data"] = theoretical_spectrum_data
        if include_peak_mask:
            # Binary mask with 1s at the (rounded) relative peak positions.
            all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in theoretical_spectrum_data])
            peaks_indices = all_peaks_rel.round().type(torch.int64)
            out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0)

        yield out
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def collate_with_spectrum_data(batch):
    """Collate a list of sample dicts into one batch dict.

    All tensor-valued entries are stacked along a new batch dimension;
    the per-sample 'theoretical_spectrum_data' metadata (arbitrary Python
    objects) is collected into a plain list instead.
    """
    stackable_keys = [k for k in batch[0] if k != 'theoretical_spectrum_data']
    collated = {}
    for key in stackable_keys:
        collated[key] = torch.stack([sample[key] for sample in batch])
    collated["theoretical_spectrum_data"] = [sample["theoretical_spectrum_data"] for sample in batch]
    return collated
|
| 219 |
+
|
| 220 |
+
def get_datapipe(
    response_functions_files,
    atom_groups_data_file=None,
    batch_size=64,
    pixels=2048, frq_step=11160.7142857 / 32768,
    number_of_signals_min=1, number_of_signals_max=8,
    spectrum_width_min=0.2, spectrum_width_max=1,
    relative_width_min=1, relative_width_max=2,
    relative_height_min=1, relative_height_max=1,
    relative_frequency_min=-0.4, relative_frequency_max=0.4,
    thf_min=1/16, thf_max=16,
    trf_min=0, trf_max=1,
    multiplicity_j1_min=0, multiplicity_j1_max=15,
    multiplicity_j2_min=0, multiplicity_j2_max=15,
    response_function_stretch_min=0.5,
    response_function_stretch_max=2.0,
    response_function_noise=0.,
    spectrum_noise_min=0.,
    spectrum_noise_max=1/64,
    include_spectrum_data=False,
    include_peak_mask=False,
    include_response_function=False,
    flip_response_function=False
):
    """Build a batched torchdata pipe of synthetic spectrum samples.

    Wires the theoretical-spectrum parameters and the response-function
    augmentation parameters into ``generator``, wraps the resulting
    infinite stream in an ``IterableWrapper`` and batches it. When
    ``include_spectrum_data`` is set, the custom collate keeps per-sample
    metadata as lists instead of trying to stack it.
    """
    # Multiplicity table: default to singlets when no file is supplied.
    if atom_groups_data_file is None:
        atom_groups_data = np.ones((1, 3), dtype=int)
    else:
        atom_groups_data = np.loadtxt(atom_groups_data_file, usecols=(1, 2, 3), dtype=int)

    theoretical_params = dict(
        atom_groups_data=atom_groups_data,
        pixels=pixels,
        frq_step=frq_step,
        number_of_signals_min=number_of_signals_min,
        number_of_signals_max=number_of_signals_max,
        spectrum_width_min=spectrum_width_min,
        spectrum_width_max=spectrum_width_max,
        relative_width_min=relative_width_min,
        relative_width_max=relative_width_max,
        relative_height_min=relative_height_min,
        relative_height_max=relative_height_max,
        relative_frequency_min=relative_frequency_min,
        relative_frequency_max=relative_frequency_max,
        thf_min=thf_min, thf_max=thf_max,
        trf_min=trf_min, trf_max=trf_max,
        multiplicity_j1_min=multiplicity_j1_min,
        multiplicity_j1_max=multiplicity_j1_max,
        multiplicity_j2_min=multiplicity_j2_min,
        multiplicity_j2_max=multiplicity_j2_max,
    )

    sample_stream = generator(
        theoretical_generator_params=theoretical_params,
        response_function_library=ResponseLibrary(response_functions_files),
        response_function_stretch_min=response_function_stretch_min,
        response_function_stretch_max=response_function_stretch_max,
        response_function_noise=response_function_noise,
        spectrum_noise_min=spectrum_noise_min,
        spectrum_noise_max=spectrum_noise_max,
        include_spectrum_data=include_spectrum_data,
        include_peak_mask=include_peak_mask,
        include_response_function=include_response_function,
        flip_response_function=flip_response_function,
    )

    pipe = torchdata.datapipes.iter.IterableWrapper(sample_stream, deepcopy=False)
    pipe = pipe.batch(batch_size)
    collate_fn = collate_with_spectrum_data if include_spectrum_data else None
    return pipe.collate(collate_fn=collate_fn)
|
src/models.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
class ConvEncoder(torch.nn.Module):
    """Four-layer 1-D convolutional encoder.

    All convolutions are valid (unpadded), so each layer shortens the
    sequence by ``kernel_size - 1`` samples. Dropout follows every
    convolution; ReLU follows every layer except the last.
    """

    def __init__(self, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        output_dim = hidden_dim if output_dim is None else output_dim
        # Attribute names (conv4..conv1) are kept for state_dict
        # compatibility with the shipped checkpoints.
        self.conv4 = torch.nn.Conv1d(1, hidden_dim, kernel_size)
        self.conv3 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv2 = torch.nn.Conv1d(hidden_dim, hidden_dim, kernel_size)
        self.conv1 = torch.nn.Conv1d(hidden_dim, output_dim, kernel_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, feature):
        """Map (batch, 1, L) -> (batch, output_dim, L - 4*(kernel_size-1))."""
        for conv in (self.conv4, self.conv3, self.conv2):
            feature = self.dropout(conv(feature)).relu()
        return self.dropout(self.conv1(feature))
|
| 23 |
+
|
| 24 |
+
class ConvDecoder(torch.nn.Module):
    """Four-layer 1-D transposed-convolution decoder (mirror of ConvEncoder).

    Each transposed convolution lengthens the sequence by
    ``kernel_size - 1`` samples; ReLU follows every layer except the last.
    """

    def __init__(self, input_dim=None, hidden_dim=64, output_dim=None, dropout=0, kernel_size=7):
        super().__init__()
        # Fix: ``output_dim`` was previously accepted but silently ignored —
        # the last layer always emitted 1 channel. It now actually controls
        # the number of output channels; the default (None -> 1) preserves
        # the original behaviour for all existing callers.
        if output_dim is None:
            output_dim = 1
        # ``dropout`` is accepted for signature symmetry with ConvEncoder
        # but is not used here (as in the original implementation).
        self.convTranspose1 = torch.nn.ConvTranspose1d(input_dim, hidden_dim, kernel_size)
        self.convTranspose2 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose3 = torch.nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size)
        self.convTranspose4 = torch.nn.ConvTranspose1d(hidden_dim, output_dim, kernel_size)

    def forward(self, feature):
        """Map (batch, input_dim, L) -> (batch, output_dim, L + 4*(kernel_size-1))."""
        feature = self.convTranspose1(feature).relu()
        feature = self.convTranspose2(feature).relu()
        feature = self.convTranspose3(feature).relu()
        return self.convTranspose4(feature)
|
| 43 |
+
|
| 44 |
+
class ResponseHead(torch.nn.Module):
    """MLP head mapping a pooled feature vector to a response-function
    estimate of fixed length.

    Layer widths follow ``input_dim -> *hidden_dims -> output_length`` with
    GELU between consecutive linear layers (no activation on the output).
    """

    def __init__(self, input_dim, output_length, hidden_dims=(128,)):
        # Fix: default changed from the mutable ``[128]`` to an equivalent
        # immutable tuple; list arguments are still accepted.
        super().__init__()
        dims = [input_dim, *hidden_dims, output_length]
        layers = [torch.nn.Linear(dims[0], dims[1])]
        for dims_in, dims_out in zip(dims[1:-1], dims[2:]):
            layers.extend([
                torch.nn.GELU(),
                torch.nn.Linear(dims_in, dims_out)
            ])
        self.response_head = torch.nn.Sequential(*layers)

    def forward(self, feature):
        """Map (batch, input_dim) -> (batch, output_length)."""
        return self.response_head(feature)
|
| 58 |
+
|
| 59 |
+
class ShimNetWithSCRF(torch.nn.Module):
    """Network that jointly predicts a denoised spectrum and the
    spectrometer convolutional response function (SCRF).

    Pipeline: a conv encoder produces per-position features; a learned
    query attention-pools them into one global vector; that vector feeds
    the response head (SCRF estimate) and, broadcast back along the length
    axis and concatenated with the local features, the conv decoder
    (denoised spectrum).
    """

    def __init__(self,
                 encoder_hidden_dims=64,
                 encoder_dropout=0,
                 bottleneck_dim=64,
                 rensponse_length=61,
                 resnponse_head_dims=[128],
                 decoder_hidden_dims=64
                 ):
        # NOTE(review): "rensponse_length" / "resnponse_head_dims" are
        # misspellings kept intact — configs pass them by keyword, so
        # renaming them would break instantiation.
        super().__init__()
        self.encoder = ConvEncoder(hidden_dim=encoder_hidden_dims, output_dim=bottleneck_dim, dropout=encoder_dropout)
        # Learned attention query of shape (1, 1, bottleneck_dim).
        self.query = torch.nn.Parameter(torch.empty(1, 1, bottleneck_dim))
        torch.nn.init.xavier_normal_(self.query)

        # Decoder input is the local features concatenated with the
        # broadcast global vector, hence 2*bottleneck_dim channels.
        self.decoder = ConvDecoder(input_dim=2*bottleneck_dim, hidden_dim=decoder_hidden_dims)

        self.rensponse_length = rensponse_length
        self.response_head = ResponseHead(bottleneck_dim, rensponse_length, resnponse_head_dims)

    def forward(self, feature):  # (batch, 1, pixels)
        # Encoder output: (batch, bottleneck_dim, L'), where L' is the
        # input length minus what the valid convolutions trim off.
        feature = self.encoder(feature)
        # Attention pooling: the query scores each position ...
        energy = self.query @ feature                          # (batch, 1, L')
        weight = torch.nn.functional.softmax(energy, 2)        # (batch, 1, L')
        # ... and the features are averaged with those weights.
        global_features = feature @ weight.transpose(1, 2)     # (batch, bottleneck_dim, 1)

        # SCRF estimate from the pooled vector.
        response = self.response_head(global_features.squeeze(-1))

        # Broadcast the global vector along the length axis and concatenate
        # it channel-wise with the local features.
        feature, global_features = torch.broadcast_tensors(feature, global_features)
        feature = torch.cat([feature, global_features], 1)     # (batch, 2*bottleneck_dim, L')
        # Decoder restores the original length: (batch, 1, pixels).
        denoised_spectrum = self.decoder(feature)

        return {
            'denoised': denoised_spectrum,
            'response': response,
            'attention': weight.squeeze(1)
        }
|
| 95 |
+
|
| 96 |
+
class Predictor:
    """Thin inference wrapper: optionally load weights once, then denoise
    single 1-D spectra with gradients disabled."""

    def __init__(self, model=None, weights_file=None):
        self.model = model
        if weights_file is not None:
            state = torch.load(weights_file, map_location='cpu', weights_only=True)
            self.model.load_state_dict(state)

    def __call__(self, nsf_frq):
        # Add batch and channel axes, run the model, strip them again.
        with torch.no_grad():
            prediction = self.model(nsf_frq[None, None])
        return prediction["denoised"][0, 0]
|
train.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training script for ShimNet: single positional argument is the run
# directory, which must contain config.yaml and accumulates checkpoints,
# logs and plots.
import torch, torchaudio
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate
import datetime
import sys
import matplotlib.pyplot as plt


# Headless backend — plots are only ever written to files.
import matplotlib
matplotlib.use('Agg')

# silent deprecation_warning() from datapipes
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')

from src import models
from src.generators import get_datapipe

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if len(sys.argv) < 2:
    print("Please provide the run directory as an argument.")
    sys.exit(1)
run_dir = Path(sys.argv[1])

config = OmegaConf.load(run_dir / "config.yaml")

# Resume support: seed the best-loss tracker from the smoothed-loss column
# (index 2) of an existing training log, so restarting does not overwrite
# a better checkpoint with a worse one.
if (run_dir / "train.txt").is_file():
    minimum = np.min(np.loadtxt(run_dir / "train.txt")[:,2])
else:
    minimum = float("inf")

# initialization
# Model class is selected by name from src.models via Hydra instantiate;
# the "__main__.models" target works because src.models is imported above.
model = instantiate({"_target_": f"__main__.models.{config.model.name}", **config.model.kwargs}).to(device)
model_weights_file = run_dir / f'model.pt'
optimizer = torch.optim.Adam(model.parameters())
optimizer_weights_file = run_dir / f'optimizer.pt'
|
| 40 |
+
def evaluate_model(stage=0, epoch=0):
    """Snapshot the current model/optimizer and dump diagnostic plots.

    Creates ``run_dir/plots/{stage}_{epoch}/`` containing a checkpoint plus,
    for each of ``config.logging.num_plots`` freshly generated samples,
    overlays of clean vs. denoised spectra, noised vs. re-convolved spectra,
    true vs. predicted response functions and (when the model returns one)
    the attention weights. Uses the module-level model, optimizer, config
    and run_dir.
    """
    plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
    plot_dir.mkdir(exist_ok=True, parents=True)

    # Checkpoint alongside the plots so each evaluation is reproducible.
    torch.save(model.state_dict(), plot_dir / "model.pt")
    torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")

    num_plots = config.logging.num_plots
    # Fresh data pipe sized so one batch yields exactly num_plots samples.
    pipe = get_datapipe(
        **config.data,
        include_response_function=True,
        batch_size=num_plots
    )
    batch = next(iter(pipe))

    with torch.no_grad():
        out = model(batch['noised_spectrum'].to(device))
        # Reconstruct the noised input: denoised spectrum convolved with the
        # predicted response, flipped so true convolution matches the
        # cross-correlation (conv1d) used when the data was generated.
        noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same").cpu()

    for i in range(num_plots):
        plt.figure(figsize=(30,6))
        plt.plot(batch['theoretical_spectrum'].cpu().numpy()[i,0])
        plt.plot(out['denoised'].cpu().numpy()[i,0])
        plt.savefig(plot_dir / f"{i:03d}_spectrum_clean.png")

        plt.figure(figsize=(30,6))
        plt.plot(batch['noised_spectrum'].cpu().numpy()[i,0])
        plt.plot(noised_est.cpu().numpy()[i,0])
        plt.savefig(plot_dir / f"{i:03d}_spectrum_noise.png")

        plt.figure(figsize=(10,6))
        plt.plot(batch['response_function'].cpu().numpy()[i,0,0])
        plt.plot(out['response'].cpu().numpy()[i])
        plt.savefig(plot_dir / f"{i:03d}_response.png")

        if "attention" in out:
            plt.figure(figsize=(10, 6))
            plt.plot(out['attention'].cpu().numpy()[i])
            plt.savefig(plot_dir / f"{i:03d}_attention.png")

        # Free figures so they do not accumulate across samples.
        plt.close("all")
|
| 81 |
+
|
| 82 |
+
# Sequential training stages, each with its own learning rate, batch size
# and iteration budget as listed under ``training`` in config.yaml. Every
# stage resumes from the best checkpoint saved so far (if any).
for i_stage, training_stage in enumerate(config.training):
    if model_weights_file.is_file():
        model.load_state_dict(torch.load(model_weights_file, weights_only=True))

    if optimizer_weights_file.is_file():
        optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
    optimizer.param_groups[0]['lr'] = training_stage.learning_rate

    pipe = get_datapipe(
        **config.data,
        include_response_function=True,
        batch_size=training_stage.batch_size
    )

    # Sliding window over the clean-reconstruction loss; the limit is
    # chosen so the window always spans roughly 64*100 training samples
    # regardless of batch size.
    losses_history = []
    losses_history_limit = 64*100 // training_stage.batch_size

    last_evaluation = 0
    for epoch, batch in pipe.enumerate():

        # logging
        # Progress is counted in samples seen, not optimizer steps.
        iters_done = epoch*training_stage.batch_size
        if (iters_done - last_evaluation) > config.logging.step:
            evaluate_model(i_stage, epoch)
            last_evaluation = iters_done

        # Final evaluation for this stage, then move to the next one.
        if iters_done > training_stage.max_iters:
            evaluate_model(i_stage, epoch)
            break

        # run model
        out = model(batch['noised_spectrum'].to(device))
        # calculate losses
        # Three terms: predicted vs. true response, denoised vs. clean
        # spectrum, and a consistency term re-convolving the denoised
        # spectrum with the predicted response against the noised input.
        loss_response = torch.nn.functional.mse_loss(out['response'], batch['response_function'].squeeze(dim=(1,2)).to(device))
        loss_clean = torch.nn.functional.mse_loss(out['denoised'], batch['theoretical_spectrum'].to(device))
        noised_est = torchaudio.functional.convolve(out['denoised'], out['response'].flip(dims=(-1,)).unsqueeze(1), mode="same")
        loss_noised = torch.nn.functional.mse_loss(noised_est, batch['noised_spectrum'].to(device))
        loss = config.losses_weights.response*loss_response + config.losses_weights.clean*loss_clean + config.losses_weights.noised*loss_noised

        # logging
        losses_history.append(loss_clean.item())
        losses_history = losses_history[-losses_history_limit:]
        loss_avg = sum(losses_history)/len(losses_history)
        message = f"{epoch:7d} {loss:0.3e} {loss_avg:0.3e} {loss_clean:0.3e} {loss_response:0.3e} {loss_noised:0.3e}"
        with open(run_dir / f'train.txt', 'a') as f:
            f.write(message + '\n')
        print(message, flush = True)

        # save best — keyed on the smoothed clean loss, matching the value
        # restored from train.txt on resume.
        if loss_avg < minimum:
            minimum = loss_avg
            torch.save(model.state_dict(), model_weights_file)
            torch.save(optimizer.state_dict(),optimizer_weights_file)

        # update weights
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()
        optimizer.zero_grad()
|
weights/shimnet_600MHz.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bac4bbff78b4d898c07ac0daab30203ead33a896d5803060593746bbb15dd10
|
| 3 |
+
size 889765
|
weights/shimnet_700MHz.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b42c9a2dfd4cb3ed0ce3aa328ea09cce03b9b3713e92b728229f24fff3781835
|
| 3 |
+
size 880746
|