Spaces:
Sleeping
Sleeping
Marek Bukowicki commited on
Commit ·
2495192
1
Parent(s): 7544717
rewrite datapipe as modular
Browse files- configs/shimnet_600_modular.yaml +68 -0
- shimnet/generators.py +330 -19
- train.py +23 -15
configs/shimnet_600_modular.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
name: ShimNetWithSCRF
|
| 3 |
+
kwargs:
|
| 4 |
+
rensponse_length: 81
|
| 5 |
+
resnponse_head_dims:
|
| 6 |
+
- 128
|
| 7 |
+
training:
|
| 8 |
+
- batch_size: 64
|
| 9 |
+
learning_rate: 0.001
|
| 10 |
+
max_iters: 1600000
|
| 11 |
+
- batch_size: 512
|
| 12 |
+
learning_rate: 0.001
|
| 13 |
+
max_iters: 25600000
|
| 14 |
+
- batch_size: 512
|
| 15 |
+
learning_rate: 0.0005
|
| 16 |
+
max_iters: 12800000
|
| 17 |
+
losses_weights:
|
| 18 |
+
clean: 1.0
|
| 19 |
+
noised: 1.0
|
| 20 |
+
response: 1.0
|
| 21 |
+
data:
|
| 22 |
+
_target_: shimnet.generators.Generator
|
| 23 |
+
include_response_function: true
|
| 24 |
+
seed: null # null means random seed
|
| 25 |
+
batch_size: null # to be set in training script
|
| 26 |
+
clean_spectra_generator:
|
| 27 |
+
_target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
|
| 28 |
+
atom_groups_data_file: data/multiplets_10000_parsed.txt
|
| 29 |
+
pixels: 2048
|
| 30 |
+
frq_step: ${metadata.frq_step}
|
| 31 |
+
number_of_signals_min: 2
|
| 32 |
+
number_of_signals_max: 5
|
| 33 |
+
spectrum_width_min: 0.2
|
| 34 |
+
spectrum_width_max: 1.0
|
| 35 |
+
relative_width_min: 1.0
|
| 36 |
+
relative_width_max: 2.0
|
| 37 |
+
relative_height_min: 0.5
|
| 38 |
+
relative_height_max: 4
|
| 39 |
+
relative_frequency_min: -0.4
|
| 40 |
+
relative_frequency_max: 0.4
|
| 41 |
+
thf_min: 0.5
|
| 42 |
+
thf_max: 2
|
| 43 |
+
trf_min: 0.0
|
| 44 |
+
trf_max: 1.0
|
| 45 |
+
multiplicity_j1_min: 0.0
|
| 46 |
+
multiplicity_j1_max: 15
|
| 47 |
+
multiplicity_j2_min: 0.0
|
| 48 |
+
multiplicity_j2_max: 15
|
| 49 |
+
response_generator:
|
| 50 |
+
_target_: shimnet.generators.ResponseGenerator
|
| 51 |
+
response_function_library:
|
| 52 |
+
_target_: shimnet.generators.ResponseLibrary
|
| 53 |
+
response_files:
|
| 54 |
+
- data/scrf_81_600MHz.pt
|
| 55 |
+
response_function_stretch_min: 1.0
|
| 56 |
+
response_function_stretch_max: 1.0
|
| 57 |
+
response_function_noise: 0.0
|
| 58 |
+
flip_response_function: false
|
| 59 |
+
noise_generator:
|
| 60 |
+
_target_: shimnet.generators.NoiseGenerator
|
| 61 |
+
spectrum_noise_min: 0.0
|
| 62 |
+
spectrum_noise_max: 0.015625
|
| 63 |
+
logging:
|
| 64 |
+
step: 1000000
|
| 65 |
+
num_plots: 32
|
| 66 |
+
metadata:
|
| 67 |
+
frq_step: 0.30048
|
| 68 |
+
spectrometer_frequency: 600.0
|
shimnet/generators.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torchdata
|
|
|
|
|
|
|
|
|
|
| 4 |
# from itertools import islice
|
| 5 |
|
| 6 |
-
def random_value(min_value, max_value):
|
| 7 |
-
return (min_value + torch.rand(1) * (max_value - min_value)).item()
|
| 8 |
|
| 9 |
-
def random_loguniform(min_value, max_value):
|
| 10 |
-
return (min_value * torch.exp(torch.rand(1) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
|
| 11 |
|
| 12 |
def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq:torch.Tensor):
|
| 13 |
# extract parameters
|
|
@@ -75,23 +78,24 @@ def generate_theoretical_spectrum(
|
|
| 75 |
multiplicity_j1_min, multiplicity_j1_max,
|
| 76 |
multiplicity_j2_min, multiplicity_j2_max,
|
| 77 |
atom_groups_data,
|
| 78 |
-
frq_frq
|
|
|
|
| 79 |
):
|
| 80 |
-
number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [])
|
| 81 |
-
atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals])
|
| 82 |
-
width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max)
|
| 83 |
-
height_spectrum = random_loguniform(thf_min, thf_max)
|
| 84 |
|
| 85 |
peak_parameters_data = []
|
| 86 |
theoretical_spectrum = None
|
| 87 |
for atom_group_index in atom_group_indices:
|
| 88 |
relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
|
| 89 |
-
position = random_value(tff_min, tff_max)
|
| 90 |
-
j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
|
| 91 |
-
j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
|
| 92 |
-
width = width_spectrum*random_loguniform(relative_width_min, relative_width_max)
|
| 93 |
-
height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max)
|
| 94 |
-
gaussian_contribution = random_value(trf_min, trf_max)
|
| 95 |
|
| 96 |
peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin= width, trf_lin= gaussian_contribution, j1=j1, j2=j2)
|
| 97 |
peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
|
|
@@ -143,8 +147,8 @@ def theoretical_generator(
|
|
| 143 |
)
|
| 144 |
|
| 145 |
class ResponseLibrary:
|
| 146 |
-
def __init__(self,
|
| 147 |
-
self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in
|
| 148 |
if normalize:
|
| 149 |
self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
|
| 150 |
lengths = [len(data) for data in self.data]
|
|
@@ -159,6 +163,10 @@ class ResponseLibrary:
|
|
| 159 |
|
| 160 |
def __len__(self):
|
| 161 |
return self.total_length
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
def generator(
|
| 164 |
theoretical_generator_params,
|
|
@@ -179,7 +187,7 @@ def generator(
|
|
| 179 |
response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
|
| 180 |
# stretch response function
|
| 181 |
padding_size = (response_function.shape[-1] - 1)//2
|
| 182 |
-
padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size) #torch.randint(round(padding_size*response_function_stretch_min), round(
|
| 183 |
response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
|
| 184 |
response_function /= response_function.sum() # normalize sum of response function to 1
|
| 185 |
# add noise to response function
|
|
@@ -277,4 +285,307 @@ def get_datapipe(
|
|
| 277 |
pipe = pipe.batch(batch_size)
|
| 278 |
pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)
|
| 279 |
|
| 280 |
-
return pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torchdata
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
|
| 7 |
# from itertools import islice
|
| 8 |
|
| 9 |
+
def random_value(min_value, max_value, generator=None):
|
| 10 |
+
return (min_value + torch.rand(1, generator=generator) * (max_value - min_value)).item()
|
| 11 |
|
| 12 |
+
def random_loguniform(min_value, max_value, generator=None):
|
| 13 |
+
return (min_value * torch.exp(torch.rand(1, generator=generator) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
|
| 14 |
|
| 15 |
def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq:torch.Tensor):
|
| 16 |
# extract parameters
|
|
|
|
| 78 |
multiplicity_j1_min, multiplicity_j1_max,
|
| 79 |
multiplicity_j2_min, multiplicity_j2_max,
|
| 80 |
atom_groups_data,
|
| 81 |
+
frq_frq,
|
| 82 |
+
generator=None
|
| 83 |
):
|
| 84 |
+
number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [], generator=generator)
|
| 85 |
+
atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals], generator=generator)
|
| 86 |
+
width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max, generator=generator)
|
| 87 |
+
height_spectrum = random_loguniform(thf_min, thf_max, generator=generator)
|
| 88 |
|
| 89 |
peak_parameters_data = []
|
| 90 |
theoretical_spectrum = None
|
| 91 |
for atom_group_index in atom_group_indices:
|
| 92 |
relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
|
| 93 |
+
position = random_value(tff_min, tff_max, generator=generator)
|
| 94 |
+
j1 = random_value(multiplicity_j1_min, multiplicity_j1_max, generator=generator)
|
| 95 |
+
j2 = random_value(multiplicity_j2_min, multiplicity_j2_max, generator=generator)
|
| 96 |
+
width = width_spectrum*random_loguniform(relative_width_min, relative_width_max, generator=generator)
|
| 97 |
+
height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max, generator=generator)
|
| 98 |
+
gaussian_contribution = random_value(trf_min, trf_max, generator=generator)
|
| 99 |
|
| 100 |
peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin= width, trf_lin= gaussian_contribution, j1=j1, j2=j2)
|
| 101 |
peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
|
|
|
|
| 147 |
)
|
| 148 |
|
| 149 |
class ResponseLibrary:
|
| 150 |
+
def __init__(self, response_files, normalize=True):
|
| 151 |
+
self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in response_files]
|
| 152 |
if normalize:
|
| 153 |
self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
|
| 154 |
lengths = [len(data) for data in self.data]
|
|
|
|
| 163 |
|
| 164 |
def __len__(self):
|
| 165 |
return self.total_length
|
| 166 |
+
|
| 167 |
+
@property
|
| 168 |
+
def max_response_length(self):
|
| 169 |
+
return max([data.shape[-1] for data in self.data])
|
| 170 |
|
| 171 |
def generator(
|
| 172 |
theoretical_generator_params,
|
|
|
|
| 187 |
response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
|
| 188 |
# stretch response function
|
| 189 |
padding_size = (response_function.shape[-1] - 1)//2
|
| 190 |
+
padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size) #torch.randint(round(padding_size*response_function_stretch_min), round(paddingSize*response_function_stretch_max), [1]).item()
|
| 191 |
response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
|
| 192 |
response_function /= response_function.sum() # normalize sum of response function to 1
|
| 193 |
# add noise to response function
|
|
|
|
| 285 |
pipe = pipe.batch(batch_size)
|
| 286 |
pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)
|
| 287 |
|
| 288 |
+
return pipe
|
| 289 |
+
|
| 290 |
+
# response_functions_files,
|
| 291 |
+
# atom_groups_data_file=None,
|
| 292 |
+
# batch_size=64,
|
| 293 |
+
# pixels=2048, frq_step=11160.7142857 / 32768,
|
| 294 |
+
# number_of_signals_min=1, number_of_signals_max=8,
|
| 295 |
+
# spectrum_width_min=0.2, spectrum_width_max=1,
|
| 296 |
+
# relative_width_min=1, relative_width_max=2,
|
| 297 |
+
# relative_height_min=1, relative_height_max=1,
|
| 298 |
+
# relative_frequency_min=-0.4, relative_frequency_max=0.4,
|
| 299 |
+
# thf_min=1/16, thf_max=16,
|
| 300 |
+
# trf_min=0, trf_max=1,
|
| 301 |
+
# multiplicity_j1_min=0, multiplicity_j1_max=15,
|
| 302 |
+
# multiplicity_j2_min=0, multiplicity_j2_max=15,
|
| 303 |
+
# response_function_stretch_min=0.5,
|
| 304 |
+
# response_function_stretch_max=2.0,
|
| 305 |
+
# response_function_noise=0.,
|
| 306 |
+
# spectrum_noise_min=0.,
|
| 307 |
+
# spectrum_noise_max=1/64,
|
| 308 |
+
# include_spectrum_data=False,
|
| 309 |
+
# include_peak_mask=False,
|
| 310 |
+
# include_response_function=False,
|
| 311 |
+
# flip_response_function=False
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class RngGetter:
|
| 315 |
+
def __init__(self, seed=42):
|
| 316 |
+
self.rng = torch.Generator()
|
| 317 |
+
if seed is not None:
|
| 318 |
+
self.rng.manual_seed(seed)
|
| 319 |
+
else:
|
| 320 |
+
self.rng.seed()
|
| 321 |
+
|
| 322 |
+
def get_rng(self, seed=None):
|
| 323 |
+
# Use provided seed or fall back to instance RNG
|
| 324 |
+
if seed is not None:
|
| 325 |
+
rng = torch.Generator()
|
| 326 |
+
rng.manual_seed(seed)
|
| 327 |
+
else:
|
| 328 |
+
rng = self.rng
|
| 329 |
+
return rng
|
| 330 |
+
|
| 331 |
+
class TheoreticalMultipletSpectraGenerator:
|
| 332 |
+
def __init__(self, atom_groups_data_file=None, pixels=2048, frq_step=11160.7142857 / 32768,
|
| 333 |
+
number_of_signals_min=1, number_of_signals_max=8,
|
| 334 |
+
spectrum_width_min=0.2, spectrum_width_max=1, relative_width_min=1, relative_width_max=2,
|
| 335 |
+
relative_height_min=1, relative_height_max=1, relative_frequency_min=-0.4, relative_frequency_max=0.4,
|
| 336 |
+
thf_min=1/16, thf_max=16, trf_min=0, trf_max=1, multiplicity_j1_min=0, multiplicity_j1_max=15,
|
| 337 |
+
multiplicity_j2_min=0, multiplicity_j2_max=15, seed=42, **kwargs):
|
| 338 |
+
# Read atom_groups_data from file
|
| 339 |
+
if atom_groups_data_file is None:
|
| 340 |
+
self.atom_groups_data = np.ones((1,3), dtype=int)
|
| 341 |
+
else:
|
| 342 |
+
self.atom_groups_data = np.atleast_2d(np.loadtxt(atom_groups_data_file, usecols=(1,2,3), dtype=int))
|
| 343 |
+
self.pixels = pixels
|
| 344 |
+
self.frq_step = frq_step
|
| 345 |
+
self.number_of_signals_min = number_of_signals_min
|
| 346 |
+
self.number_of_signals_max = number_of_signals_max
|
| 347 |
+
self.spectrum_width_min = spectrum_width_min
|
| 348 |
+
self.spectrum_width_max = spectrum_width_max
|
| 349 |
+
self.relative_width_min = relative_width_min
|
| 350 |
+
self.relative_width_max = relative_width_max
|
| 351 |
+
self.relative_height_min = relative_height_min
|
| 352 |
+
self.relative_height_max = relative_height_max
|
| 353 |
+
self.relative_frequency_min = relative_frequency_min
|
| 354 |
+
self.relative_frequency_max = relative_frequency_max
|
| 355 |
+
self.thf_min = thf_min
|
| 356 |
+
self.thf_max = thf_max
|
| 357 |
+
self.trf_min = trf_min
|
| 358 |
+
self.trf_max = trf_max
|
| 359 |
+
self.multiplicity_j1_min = multiplicity_j1_min
|
| 360 |
+
self.multiplicity_j1_max = multiplicity_j1_max
|
| 361 |
+
self.multiplicity_j2_min = multiplicity_j2_min
|
| 362 |
+
self.multiplicity_j2_max = multiplicity_j2_max
|
| 363 |
+
self.frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
|
| 364 |
+
self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
|
| 365 |
+
|
| 366 |
+
def __call__(self, seed=None):
|
| 367 |
+
rng = self.rng_getter.get_rng(seed=seed)
|
| 368 |
+
|
| 369 |
+
spectrum, spectrum_data = generate_theoretical_spectrum(
|
| 370 |
+
number_of_signals_min=self.number_of_signals_min,
|
| 371 |
+
number_of_signals_max=self.number_of_signals_max,
|
| 372 |
+
spectrum_width_min=self.spectrum_width_min,
|
| 373 |
+
spectrum_width_max=self.spectrum_width_max,
|
| 374 |
+
relative_width_min=self.relative_width_min,
|
| 375 |
+
relative_width_max=self.relative_width_max,
|
| 376 |
+
tff_min=self.relative_frequency_min * self.pixels * self.frq_step,
|
| 377 |
+
tff_max=self.relative_frequency_max * self.pixels * self.frq_step,
|
| 378 |
+
thf_min=self.thf_min,
|
| 379 |
+
thf_max=self.thf_max,
|
| 380 |
+
trf_min=self.trf_min,
|
| 381 |
+
trf_max=self.trf_max,
|
| 382 |
+
relative_height_min=self.relative_height_min,
|
| 383 |
+
relative_height_max=self.relative_height_max,
|
| 384 |
+
multiplicity_j1_min=self.multiplicity_j1_min,
|
| 385 |
+
multiplicity_j1_max=self.multiplicity_j1_max,
|
| 386 |
+
multiplicity_j2_min=self.multiplicity_j2_min,
|
| 387 |
+
multiplicity_j2_max=self.multiplicity_j2_max,
|
| 388 |
+
atom_groups_data=self.atom_groups_data,
|
| 389 |
+
frq_frq=self.frq_frq,
|
| 390 |
+
generator=rng
|
| 391 |
+
)
|
| 392 |
+
return spectrum, {"spectrum_data": spectrum_data, "frq_frq": self.frq_frq}
|
| 393 |
+
|
| 394 |
+
class ResponseGenerator:
|
| 395 |
+
def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
|
| 396 |
+
response_function_noise=0.0, flip_response_function=False, seed=42):
|
| 397 |
+
self.response_function_library = response_function_library
|
| 398 |
+
self.response_function_stretch_min = response_function_stretch_min
|
| 399 |
+
self.response_function_stretch_max = response_function_stretch_max
|
| 400 |
+
self.pad_to = pad_to
|
| 401 |
+
self.response_function_noise = response_function_noise
|
| 402 |
+
self.flip_response_function = flip_response_function
|
| 403 |
+
self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
|
| 404 |
+
|
| 405 |
+
def __call__(self, seed=None):
|
| 406 |
+
rng = self.rng_getter.get_rng(seed=seed)
|
| 407 |
+
|
| 408 |
+
response_function = self.response_function_library[torch.randint(0, len(self.response_function_library), [1], generator=rng)][0]
|
| 409 |
+
padding_size = (response_function.shape[-1] - 1)//2
|
| 410 |
+
padding_size = round(random_loguniform(self.response_function_stretch_min, self.response_function_stretch_max, generator=rng)*padding_size)
|
| 411 |
+
response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
|
| 412 |
+
response_function /= response_function.sum()
|
| 413 |
+
response_function += torch.randn(response_function.shape, generator=rng) * self.response_function_noise
|
| 414 |
+
response_function /= response_function.sum()
|
| 415 |
+
if self.flip_response_function and (torch.rand(1, generator=rng).item() < 0.5):
|
| 416 |
+
response_function = response_function.flip(-1)
|
| 417 |
+
if self.pad_to is not None:
|
| 418 |
+
pad_size_left = (self.pad_to - response_function.shape[-1]) // 2
|
| 419 |
+
pad_size_right = self.pad_to - response_function.shape[-1] - pad_size_left
|
| 420 |
+
response_function = torch.nn.functional.pad(response_function, (pad_size_left, pad_size_right))
|
| 421 |
+
return response_function
|
| 422 |
+
|
| 423 |
+
class NoiseGenerator:
|
| 424 |
+
def __init__(self, spectrum_noise_min=0., spectrum_noise_max=1/64, seed=42):
|
| 425 |
+
self.spectrum_noise_min = spectrum_noise_min
|
| 426 |
+
self.spectrum_noise_max = spectrum_noise_max
|
| 427 |
+
self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
|
| 428 |
+
|
| 429 |
+
def __call__(self, disturbed_spectrum, seed=None):
|
| 430 |
+
rng = self.rng_getter.get_rng(seed=seed)
|
| 431 |
+
return disturbed_spectrum + torch.randn(disturbed_spectrum.shape, generator=rng) * random_value(self.spectrum_noise_min, self.spectrum_noise_max, generator=rng)
|
| 432 |
+
|
| 433 |
+
class BaseGenerator(ABC):
|
| 434 |
+
"""
|
| 435 |
+
Single-threaded base generator.
|
| 436 |
+
|
| 437 |
+
For this workload, single-threaded execution is typically faster because:
|
| 438 |
+
- Thread creation/synchronization overhead > computation time
|
| 439 |
+
- Python GIL contention during object creation
|
| 440 |
+
- Memory allocator contention when multiple threads allocate tensors
|
| 441 |
+
- CPU cache thrashing across cores
|
| 442 |
+
- Small per-thread workload doesn't amortize thread overhead
|
| 443 |
+
"""
|
| 444 |
+
def __init__(self, batch_size=64, seed=None):
|
| 445 |
+
self.batch_size = batch_size
|
| 446 |
+
self.seed = seed
|
| 447 |
+
|
| 448 |
+
def set_seed(self, seed):
|
| 449 |
+
self.seed = seed
|
| 450 |
+
|
| 451 |
+
@abstractmethod
|
| 452 |
+
def _generate_element(self, seed):
|
| 453 |
+
pass
|
| 454 |
+
|
| 455 |
+
def __iter__(self):
|
| 456 |
+
rng = torch.Generator()
|
| 457 |
+
if self.seed is not None:
|
| 458 |
+
rng.manual_seed(self.seed)
|
| 459 |
+
else:
|
| 460 |
+
rng.seed()
|
| 461 |
+
|
| 462 |
+
while True:
|
| 463 |
+
batch = []
|
| 464 |
+
# Generate unique seeds for each element in the batch
|
| 465 |
+
if self.seed is not None:
|
| 466 |
+
element_seeds = [torch.randint(0, 2**31, (1,), generator=rng).item() for _ in range(self.batch_size)]
|
| 467 |
+
else:
|
| 468 |
+
element_seeds = [None] * self.batch_size
|
| 469 |
+
|
| 470 |
+
# Single-threaded sequential generation
|
| 471 |
+
for i in range(self.batch_size):
|
| 472 |
+
batch.append(self._generate_element(element_seeds[i]))
|
| 473 |
+
|
| 474 |
+
yield self.collate_fn(batch)
|
| 475 |
+
|
| 476 |
+
@abstractmethod
|
| 477 |
+
def collate_fn(self, batch):
|
| 478 |
+
pass
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class BaseGeneratorMultithread(ABC):
|
| 482 |
+
"""
|
| 483 |
+
Multithreaded base generator (backup option).
|
| 484 |
+
|
| 485 |
+
Use only if profiling shows benefit for your specific use case
|
| 486 |
+
(e.g., very large/slow generation functions, I/O-bound operations).
|
| 487 |
+
"""
|
| 488 |
+
def __init__(self, batch_size=64, num_workers=4, seed=None, ordered_batch=False):
|
| 489 |
+
self.batch_size = batch_size
|
| 490 |
+
self.num_workers = num_workers
|
| 491 |
+
self.seed = seed
|
| 492 |
+
self.ordered_batch = ordered_batch
|
| 493 |
+
|
| 494 |
+
def set_seed(self, seed):
|
| 495 |
+
self.seed = seed
|
| 496 |
+
|
| 497 |
+
def set_ordered_batch(self, ordered_batch):
|
| 498 |
+
self.ordered_batch = ordered_batch
|
| 499 |
+
|
| 500 |
+
@abstractmethod
|
| 501 |
+
def _generate_element(self, seed):
|
| 502 |
+
pass
|
| 503 |
+
|
| 504 |
+
def __iter__(self):
|
| 505 |
+
rng = torch.Generator()
|
| 506 |
+
if self.seed is not None:
|
| 507 |
+
rng.manual_seed(self.seed)
|
| 508 |
+
else:
|
| 509 |
+
rng.seed()
|
| 510 |
+
|
| 511 |
+
while True:
|
| 512 |
+
batch = []
|
| 513 |
+
# Generate unique seeds for each element in the batch
|
| 514 |
+
if self.seed is not None:
|
| 515 |
+
element_seeds = [torch.randint(0, 2**31, (1,), generator=rng).item() for _ in range(self.batch_size)]
|
| 516 |
+
else:
|
| 517 |
+
element_seeds = [None] * self.batch_size
|
| 518 |
+
|
| 519 |
+
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
|
| 520 |
+
futures = [executor.submit(self._generate_element, element_seeds[i]) for i in range(self.batch_size)]
|
| 521 |
+
|
| 522 |
+
if self.ordered_batch:
|
| 523 |
+
# Maintain order: iterate futures in submission order
|
| 524 |
+
for f in futures:
|
| 525 |
+
batch.append(f.result())
|
| 526 |
+
else:
|
| 527 |
+
# Faster: process as completed (order may vary)
|
| 528 |
+
for f in as_completed(futures):
|
| 529 |
+
batch.append(f.result())
|
| 530 |
+
|
| 531 |
+
yield self.collate_fn(batch)
|
| 532 |
+
|
| 533 |
+
@abstractmethod
|
| 534 |
+
def collate_fn(self, batch):
|
| 535 |
+
pass
|
| 536 |
+
|
| 537 |
+
class Generator(BaseGenerator):
|
| 538 |
+
def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64,
|
| 539 |
+
include_spectrum_data=False, include_peak_mask=False, include_response_function=False, seed=None):
|
| 540 |
+
super().__init__(batch_size=batch_size, seed=seed)
|
| 541 |
+
self.clean_spectra_generator = clean_spectra_generator
|
| 542 |
+
self.response_generator = response_generator
|
| 543 |
+
self.noise_generator = noise_generator
|
| 544 |
+
self.include_spectrum_data = include_spectrum_data
|
| 545 |
+
self.include_peak_mask = include_peak_mask
|
| 546 |
+
self.include_response_function = include_response_function
|
| 547 |
+
|
| 548 |
+
def _generate_element(self, seed):
|
| 549 |
+
# Generate different seeds for each generator from the provided seed
|
| 550 |
+
if seed is not None:
|
| 551 |
+
rng = torch.Generator()
|
| 552 |
+
rng.manual_seed(seed)
|
| 553 |
+
clean_seed = torch.randint(0, 2**31, (1,), generator=rng).item()
|
| 554 |
+
response_seed = torch.randint(0, 2**31, (1,), generator=rng).item()
|
| 555 |
+
noise_seed = torch.randint(0, 2**31, (1,), generator=rng).item()
|
| 556 |
+
else:
|
| 557 |
+
clean_seed = None
|
| 558 |
+
response_seed = None
|
| 559 |
+
noise_seed = None
|
| 560 |
+
|
| 561 |
+
clean_spectrum, extra_clean_data = self.clean_spectra_generator(seed=clean_seed)
|
| 562 |
+
response_function = self.response_generator(seed=response_seed)
|
| 563 |
+
padding_size = (response_function.shape[-1] - 1)//2
|
| 564 |
+
disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size)
|
| 565 |
+
noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed)
|
| 566 |
+
out = {
|
| 567 |
+
'theoretical_spectrum': clean_spectrum,
|
| 568 |
+
'disturbed_spectrum': disturbed_spectrum,
|
| 569 |
+
'noised_spectrum': noised_spectrum,
|
| 570 |
+
}
|
| 571 |
+
if self.include_spectrum_data:
|
| 572 |
+
out['theoretical_spectrum_data'] = extra_clean_data['spectrum_data']
|
| 573 |
+
out['frq_frq'] = extra_clean_data['frq_frq']
|
| 574 |
+
if self.include_peak_mask and extra_clean_data is not None:
|
| 575 |
+
all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in extra_clean_data['spectrum_data']])
|
| 576 |
+
peaks_indices = all_peaks_rel.round().type(torch.int64)
|
| 577 |
+
out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0)
|
| 578 |
+
if self.include_response_function:
|
| 579 |
+
out['response_function'] = response_function
|
| 580 |
+
return out
|
| 581 |
+
|
| 582 |
+
def collate_fn(self, batch):
|
| 583 |
+
tensor_keys = set(batch[0].keys())
|
| 584 |
+
for k in ['theoretical_spectrum_data', 'frq_frq']:
|
| 585 |
+
tensor_keys.discard(k)
|
| 586 |
+
out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys}
|
| 587 |
+
if 'theoretical_spectrum_data' in batch[0]:
|
| 588 |
+
out['theoretical_spectrum_data'] = [item['theoretical_spectrum_data'] for item in batch]
|
| 589 |
+
if 'frq_frq' in batch[0]:
|
| 590 |
+
out['frq_frq'] = [item['frq_frq'] for item in batch]
|
| 591 |
+
return out
|
train.py
CHANGED
|
@@ -6,7 +6,7 @@ from hydra.utils import instantiate
|
|
| 6 |
import datetime
|
| 7 |
import sys
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
-
|
| 10 |
|
| 11 |
import matplotlib
|
| 12 |
matplotlib.use('Agg')
|
|
@@ -15,8 +15,6 @@ matplotlib.use('Agg')
|
|
| 15 |
import warnings
|
| 16 |
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
|
| 17 |
|
| 18 |
-
# from shiment import models
|
| 19 |
-
from shimnet.generators import get_datapipe
|
| 20 |
from shimnet.predict_utils import Defaults as PredictDefaults
|
| 21 |
|
| 22 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
@@ -64,6 +62,19 @@ model_weights_file = run_dir / f'model.pt'
|
|
| 64 |
optimizer = torch.optim.Adam(model.parameters())
|
| 65 |
optimizer_weights_file = run_dir / f'optimizer.pt'
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def evaluate_model(stage=0, epoch=0):
|
| 68 |
plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
|
| 69 |
plot_dir.mkdir(exist_ok=True, parents=True)
|
|
@@ -72,11 +83,12 @@ def evaluate_model(stage=0, epoch=0):
|
|
| 72 |
torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
|
| 73 |
|
| 74 |
num_plots = config.logging.num_plots
|
| 75 |
-
pipe = get_datapipe(
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
| 80 |
batch = next(iter(pipe))
|
| 81 |
|
| 82 |
with torch.no_grad():
|
|
@@ -154,18 +166,14 @@ for i_stage, training_stage in enumerate(config.training):
|
|
| 154 |
if optimizer_weights_file.is_file():
|
| 155 |
optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
|
| 156 |
optimizer.param_groups[0]['lr'] = training_stage.learning_rate
|
| 157 |
-
|
| 158 |
-
pipe = get_datapipe(
|
| 159 |
-
**config.data,
|
| 160 |
-
include_response_function=True,
|
| 161 |
-
batch_size=training_stage.batch_size
|
| 162 |
-
)
|
| 163 |
|
| 164 |
losses_history = []
|
| 165 |
losses_history_limit = 64*100 // training_stage.batch_size
|
| 166 |
|
| 167 |
last_evaluation = 0
|
| 168 |
-
for epoch, batch in
|
| 169 |
|
| 170 |
# logging
|
| 171 |
iters_done = epoch*training_stage.batch_size
|
|
|
|
| 6 |
import datetime
|
| 7 |
import sys
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
+
from copy import deepcopy
|
| 10 |
|
| 11 |
import matplotlib
|
| 12 |
matplotlib.use('Agg')
|
|
|
|
| 15 |
import warnings
|
| 16 |
warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
|
| 17 |
|
|
|
|
|
|
|
| 18 |
from shimnet.predict_utils import Defaults as PredictDefaults
|
| 19 |
|
| 20 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 62 |
optimizer = torch.optim.Adam(model.parameters())
|
| 63 |
optimizer_weights_file = run_dir / f'optimizer.pt'
|
| 64 |
|
| 65 |
+
def get_datapipe(config_data, batch_size, alter_seed_by=None):
|
| 66 |
+
data_config = deepcopy(config_data)
|
| 67 |
+
data_config.batch_size = batch_size
|
| 68 |
+
|
| 69 |
+
# we may change the seed for different stages
|
| 70 |
+
if alter_seed_by is not None:
|
| 71 |
+
if "seed" in data_config:
|
| 72 |
+
if data_config.seed is None:
|
| 73 |
+
data_config.seed = alter_seed_by
|
| 74 |
+
else:
|
| 75 |
+
data_config.seed = config_data.seed + alter_seed_by
|
| 76 |
+
return instantiate(data_config)
|
| 77 |
+
|
| 78 |
def evaluate_model(stage=0, epoch=0):
|
| 79 |
plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
|
| 80 |
plot_dir.mkdir(exist_ok=True, parents=True)
|
|
|
|
| 83 |
torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
|
| 84 |
|
| 85 |
num_plots = config.logging.num_plots
|
| 86 |
+
pipe = get_datapipe(config.data, batch_size=num_plots)
|
| 87 |
+
# if possible, set seed and ordered batch for reproducibility
|
| 88 |
+
if hasattr(pipe, 'set_seed'):
|
| 89 |
+
pipe.set_seed(42)
|
| 90 |
+
if hasattr(pipe, 'set_ordered_batch'):
|
| 91 |
+
pipe.set_ordered_batch(True)
|
| 92 |
batch = next(iter(pipe))
|
| 93 |
|
| 94 |
with torch.no_grad():
|
|
|
|
| 166 |
if optimizer_weights_file.is_file():
|
| 167 |
optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
|
| 168 |
optimizer.param_groups[0]['lr'] = training_stage.learning_rate
|
| 169 |
+
|
| 170 |
+
pipe = get_datapipe(config.data, batch_size=training_stage.batch_size, alter_seed_by=i_stage)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
losses_history = []
|
| 173 |
losses_history_limit = 64*100 // training_stage.batch_size
|
| 174 |
|
| 175 |
last_evaluation = 0
|
| 176 |
+
for epoch, batch in enumerate(pipe):
|
| 177 |
|
| 178 |
# logging
|
| 179 |
iters_done = epoch*training_stage.batch_size
|