Marek Bukowicki committed on
Commit
a58e9bb
·
1 Parent(s): c1d3733

add peak data generator from singlets list

Browse files
configs/data_generator_from_peak_list.yaml CHANGED
@@ -3,10 +3,15 @@ data:
3
  input_normalization_height: 16.0
4
  clean_spectra_generator:
5
  _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
 
 
6
  peaks_parameter_generator:
7
  _target_: shimnet.generators.MultipletDataFromMultipletsLibrary
8
  number_of_signals_min: 2 # null to use all signals (both min and max must be null)
9
  number_of_signals_max: 10 # null to use all signals (both min and max must be null)
 
 
 
10
  multiplet_height_factor_min: 0.5
11
  multiplet_height_factor_max: 2
12
  multiplet_width_factor_min: 0.5
@@ -19,13 +24,7 @@ data:
19
  - data/multiplets_lists/mieszanina_poReakcji_mono-click.csv
20
  peak_data_parser:
21
  _target_: shimnet.generators.PeaksParametersParser
22
- use_original_peak_position: false
23
  seed: null
24
- pixels: 2048
25
- frq_step: ${metadata.frq_step}
26
- include_tff_relative: true
27
- relative_frequency_min: -0.4
28
- relative_frequency_max: 0.4
29
  response_generator:
30
  _target_: shimnet.generators.ResponseGenerator
31
  response_function_library:
@@ -40,4 +39,7 @@ data:
40
  include_peak_mask: false
41
  include_response_function: true
42
  seed: 44 # null means random seed
43
- batch_size: 64 # to be set in training script
 
 
 
 
3
  input_normalization_height: 16.0
4
  clean_spectra_generator:
5
  _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
6
+ pixels: 2048
7
+ frq_step: ${metadata.frq_step}
8
  peaks_parameter_generator:
9
  _target_: shimnet.generators.MultipletDataFromMultipletsLibrary
10
  number_of_signals_min: 2 # null to use all signals (both min and max must be null)
11
  number_of_signals_max: 10 # null to use all signals (both min and max must be null)
12
+ use_original_peak_position: false
13
+ relative_frequency_min: -0.4
14
+ relative_frequency_max: 0.4
15
  multiplet_height_factor_min: 0.5
16
  multiplet_height_factor_max: 2
17
  multiplet_width_factor_min: 0.5
 
24
  - data/multiplets_lists/mieszanina_poReakcji_mono-click.csv
25
  peak_data_parser:
26
  _target_: shimnet.generators.PeaksParametersParser
 
27
  seed: null
 
 
 
 
 
28
  response_generator:
29
  _target_: shimnet.generators.ResponseGenerator
30
  response_function_library:
 
39
  include_peak_mask: false
40
  include_response_function: true
41
  seed: 44 # null means random seed
42
+ batch_size: null # to be set in training script
43
+ metadata:
44
+ frq_step: 0.30048
45
+ spectrometer_frequency: 600.0
configs/from_peak_list/singlets_fixed_positions.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: ShimNetWithSCRF
3
+ kwargs:
4
+ rensponse_length: 81
5
+ resnponse_head_dims:
6
+ - 128
7
+ training:
8
+ - batch_size: 64
9
+ learning_rate: 0.001
10
+ max_iters: 1600000
11
+ - batch_size: 512
12
+ learning_rate: 0.001
13
+ max_iters: 25600000
14
+ - batch_size: 512
15
+ learning_rate: 0.0005
16
+ max_iters: 12800000
17
+ losses_weights:
18
+ clean: 1.0
19
+ noised: 1.0
20
+ response: 10.0 # increased due to input height 16
21
+ data:
22
+ _target_: shimnet.generators.Generator
23
+ input_normalization_height: 16.0
24
+ clean_spectra_generator:
25
+ _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
26
+ pixels: null # will be inferred from frequency range and step
27
+ frq_step: ${metadata.frq_step}
28
+ frequency_min: -50
29
+ frequency_max: 650
30
+ peaks_parameter_generator:
31
+ _target_: shimnet.generators.PeaksParametersFromSinglets
32
+ number_of_signals_min: 15
33
+ number_of_signals_max: 45
34
+ use_original_position: true
35
+ position_hz_change_min: -5.0
36
+ position_hz_change_max: 5.0
37
+ use_original_height: false
38
+ height_min: 0.02
39
+ height_max: 10.0
40
+ width_factor_min: 0.8
41
+ width_factor_max: 1.2
42
+ gaussian_fraction_change_min: -0.2
43
+ gaussian_fraction_change_max: 0.2
44
+ singlets_files:
45
+ - data/multiplets_lists/mieszanina_po_reakcji_2_squeezed-0.0-20.0Hz.csv
46
+ response_generator:
47
+ _target_: shimnet.generators.ResponseGenerator
48
+ response_function_library:
49
+ _target_: shimnet.generators.ResponseLibrary
50
+ response_files:
51
+ - data/scrf_81_600MHz.pt
52
+ noise_generator:
53
+ _target_: shimnet.generators.NoiseGenerator
54
+ spectrum_noise_min: 0.0
55
+ spectrum_noise_max: 0.01
56
+ include_spectrum_data: false
57
+ include_peak_mask: false
58
+ include_response_function: true
59
+ seed: 44 # null means random seed
60
+ batch_size: null # to be set in training script
61
+ logging:
62
+ step: 1000000
63
+ num_plots: 32
64
+ metadata:
65
+ frq_step: 0.30048
66
+ spectrometer_frequency: 600.0
configs/from_peak_list/singlets_random_positions.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: ShimNetWithSCRF
3
+ kwargs:
4
+ rensponse_length: 81
5
+ resnponse_head_dims:
6
+ - 128
7
+ training:
8
+ - batch_size: 64
9
+ learning_rate: 0.001
10
+ max_iters: 1600000
11
+ - batch_size: 512
12
+ learning_rate: 0.001
13
+ max_iters: 25600000
14
+ - batch_size: 512
15
+ learning_rate: 0.0005
16
+ max_iters: 12800000
17
+ losses_weights:
18
+ clean: 1.0
19
+ noised: 1.0
20
+ response: 10.0 # increased due to input height 16
21
+ data:
22
+ _target_: shimnet.generators.Generator
23
+ input_normalization_height: 16.0
24
+ clean_spectra_generator:
25
+ _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
26
+ pixels: 2048
27
+ frq_step: ${metadata.frq_step}
28
+ relative_frequency_min: -0.4
29
+ relative_frequency_max: 0.4
30
+ peaks_parameter_generator:
31
+ _target_: shimnet.generators.PeaksParametersFromSinglets
32
+ number_of_signals_min: 15
33
+ number_of_signals_max: 30
34
+ use_original_position: false
35
+ use_original_height: false
36
+ height_min: 0.02
37
+ height_max: 10.0
38
+ width_factor_min: 0.8
39
+ width_factor_max: 1.2
40
+ gaussian_fraction_change_min: -0.2
41
+ gaussian_fraction_change_max: 0.2
42
+ singlets_files:
43
+ - data/multiplets_lists/azydekbenzylu_sub1_mono-click.csv
44
+ response_generator:
45
+ _target_: shimnet.generators.ResponseGenerator
46
+ response_function_library:
47
+ _target_: shimnet.generators.ResponseLibrary
48
+ response_files:
49
+ - data/scrf_81_600MHz.pt
50
+ noise_generator:
51
+ _target_: shimnet.generators.NoiseGenerator
52
+ spectrum_noise_min: 0.0
53
+ spectrum_noise_max: 0.01
54
+ include_spectrum_data: false
55
+ include_peak_mask: false
56
+ include_response_function: true
57
+ seed: 44 # null means random seed
58
+ batch_size: null # to be set in training script
59
+ logging:
60
+ step: 1000000
61
+ num_plots: 32
62
+ metadata:
63
+ frq_step: 0.30048
64
+ spectrometer_frequency: 600.0
configs/shimnet_600_modular.yaml CHANGED
@@ -7,13 +7,13 @@ model:
7
  training:
8
  - batch_size: 64
9
  learning_rate: 0.001
10
- max_iters: 1600000
11
  - batch_size: 512
12
  learning_rate: 0.001
13
- max_iters: 25600000
14
  - batch_size: 512
15
  learning_rate: 0.0005
16
- max_iters: 12800000
17
  losses_weights:
18
  clean: 1.0
19
  noised: 1.0
@@ -27,13 +27,13 @@ data:
27
  _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
28
  pixels: 2048
29
  frq_step: ${metadata.frq_step}
30
- relative_frequency_min: -0.4
31
- relative_frequency_max: 0.4
32
  peaks_parameter_generator:
33
  _target_: shimnet.generators.PeaksParameterDataGenerator
34
  atom_groups_data_file: data/multiplets_10000_parsed.txt
35
  number_of_signals_min: 2
36
  number_of_signals_max: 5
 
 
37
  spectrum_width_min: 0.2
38
  spectrum_width_max: 1.0
39
  relative_width_min: 1.0
 
7
  training:
8
  - batch_size: 64
9
  learning_rate: 0.001
10
+ max_iters: 16000
11
  - batch_size: 512
12
  learning_rate: 0.001
13
+ max_iters: 256000
14
  - batch_size: 512
15
  learning_rate: 0.0005
16
+ max_iters: 128000
17
  losses_weights:
18
  clean: 1.0
19
  noised: 1.0
 
27
  _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
28
  pixels: 2048
29
  frq_step: ${metadata.frq_step}
 
 
30
  peaks_parameter_generator:
31
  _target_: shimnet.generators.PeaksParameterDataGenerator
32
  atom_groups_data_file: data/multiplets_10000_parsed.txt
33
  number_of_signals_min: 2
34
  number_of_signals_max: 5
35
+ relative_frequency_min: -0.4
36
+ relative_frequency_max: 0.4
37
  spectrum_width_min: 0.2
38
  spectrum_width_max: 1.0
39
  relative_width_min: 1.0
shimnet/generators.py CHANGED
@@ -1,5 +1,6 @@
1
  from enum import Enum
2
  from copy import deepcopy
 
3
  # from pathlib import Path
4
  import numpy as np
5
  import pandas as pd
@@ -8,12 +9,19 @@ import torchdata
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from abc import ABC, abstractmethod
10
 
11
- def random_value(min_value, max_value, generator=None):
12
  return (min_value + torch.rand(1, generator=generator) * (max_value - min_value)).item()
 
13
 
14
  def random_loguniform(min_value, max_value, generator=None):
15
  return (min_value * torch.exp(torch.rand(1, generator=generator) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
16
 
 
 
 
 
 
 
17
  def spectrum_from_peaks_data(peaks_parameters: dict | list, frq_frq:torch.Tensor, relative_frequency=False):
18
 
19
  if isinstance(peaks_parameters, dict):
@@ -354,6 +362,8 @@ class PeaksParameterDataGenerator:
354
  atom_groups_data_file=None,
355
  number_of_signals_min=1,
356
  number_of_signals_max=8,
 
 
357
  spectrum_width_min=0.2,
358
  spectrum_width_max=1,
359
  relative_width_min=1,
@@ -380,6 +390,9 @@ class PeaksParameterDataGenerator:
380
  self.tff_max = tff_max
381
  self.number_of_signals_min = number_of_signals_min
382
  self.number_of_signals_max = number_of_signals_max
 
 
 
383
  self.spectrum_width_min = spectrum_width_min
384
  self.spectrum_width_max = spectrum_width_max
385
  self.relative_width_min = relative_width_min
@@ -397,9 +410,9 @@ class PeaksParameterDataGenerator:
397
 
398
  self.rng_getter = RngGetter(seed=seed)
399
 
400
- def set_tff_range(self, tff_min, tff_max):
401
- self.tff_min = tff_min
402
- self.tff_max = tff_max
403
 
404
  def __call__(self, seed=None):
405
  """
@@ -483,6 +496,8 @@ class TheoreticalMultipletSpectraGenerator:
483
  frq_step=11160.7142857 / 32768,
484
  relative_frequency_min=-0.4,
485
  relative_frequency_max=0.4,
 
 
486
  include_tff_relative=False,
487
  seed=42
488
  ):
@@ -493,15 +508,34 @@ class TheoreticalMultipletSpectraGenerator:
493
  self.relative_frequency_min = relative_frequency_min
494
  self.relative_frequency_max = relative_frequency_max
495
  self.include_tff_relative = include_tff_relative
496
- self.frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
 
497
 
498
  self.peaks_parameter_generator = peaks_parameter_generator
499
- self.peaks_parameter_generator.set_tff_range(
500
- tff_min=relative_frequency_min * pixels * frq_step,
501
- tff_max=relative_frequency_max * pixels * frq_step
502
- )
503
 
504
  # self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
  def __call__(self, seed=None):
507
  """
@@ -515,7 +549,6 @@ class TheoreticalMultipletSpectraGenerator:
515
  """
516
  # Generate peak parameters (peaks_parameter_generator has its own RngGetter)
517
  peaks_parameters_data = self.peaks_parameter_generator(seed=seed)
518
-
519
 
520
  # Add tff_relative if requested
521
  if self.include_tff_relative:
@@ -530,10 +563,10 @@ class TheoreticalMultipletSpectraGenerator:
530
 
531
  class PeaksParametersNames(Enum):
532
  """Enum for standardized peak parameter names."""
533
- tff_lin = "position_hz"
534
- thf_lin = "height"
535
- twf_lin = "width_hz"
536
- trf_lin = "gaussian_fraction"
537
 
538
  @classmethod
539
  def keys(cls):
@@ -544,6 +577,7 @@ class PeaksParametersNames(Enum):
544
  return [member.name for member in cls]
545
 
546
  class PeaksParametersParser:
 
547
  def __init__(self,
548
  alias_position_hz = None,
549
  alias_height = None,
@@ -565,16 +599,16 @@ class PeaksParametersParser:
565
 
566
  def transform_single_peak(self, peak: dict) -> dict:
567
  parsed_peak = {
568
- PeaksParametersNames("position_hz").name: peak.get(self.alias_position_hz, self.default_position_hz),
569
- PeaksParametersNames("height").name: peak.get(self.alias_height, self.default_height),
570
- PeaksParametersNames("width_hz").name: peak.get(self.alias_width_hz, self.default_width_hz),
571
- PeaksParametersNames("gaussian_fraction").name: peak.get(self.alias_gaussian_fraction, self.default_gaussian_fraction),
572
  }
573
  # Validate and convert other peak parameters
574
  for k, v in parsed_peak.items():
575
  if v is None:
576
  raise ValueError(f"Peak parameter '{k}' is None.")
577
- parsed_peak[k] = torch.atleast_1d(torch.tensor(v, dtype=torch.float32))
578
  return parsed_peak
579
 
580
  def transform(self, spectrum_peaks: list[dict]) -> list[dict]:
@@ -643,18 +677,26 @@ class MultipletDataFromMultipletsLibrary:
643
  use_original_peak_position=True,
644
  number_of_signals_min=None,
645
  number_of_signals_max=None,
 
 
646
  spectrum_width_factor_min=1,
647
  spectrum_width_factor_max=1,
648
  multiplet_width_factor_min=1,
649
  multiplet_width_factor_max=1,
 
 
650
  spectrum_height_factor_min=1,
651
  spectrum_height_factor_max=1,
652
  multiplet_height_factor_min=1,
653
  multiplet_height_factor_max=1,
 
 
654
  position_shift_min=0,
655
  position_shift_max=0,
656
  gaussian_fraction_change_min=None,
657
  gaussian_fraction_change_max=None,
 
 
658
  seed=42
659
  ):
660
 
@@ -665,6 +707,8 @@ class MultipletDataFromMultipletsLibrary:
665
  self.rng_getter = RngGetter(seed=seed)
666
  self.tff_min = tff_min
667
  self.tff_max = tff_max
 
 
668
  self.use_original_peak_position = use_original_peak_position
669
  self.number_of_signals_min = number_of_signals_min
670
  self.number_of_signals_max = number_of_signals_max
@@ -672,18 +716,25 @@ class MultipletDataFromMultipletsLibrary:
672
  self.spectrum_width_factor_max = spectrum_width_factor_max
673
  self.multiplet_width_factor_min = multiplet_width_factor_min
674
  self.multiplet_width_factor_max = multiplet_width_factor_max
 
 
675
  self.spectrum_height_factor_min = spectrum_height_factor_min
676
  self.spectrum_height_factor_max = spectrum_height_factor_max
677
  self.multiplet_height_factor_min = multiplet_height_factor_min
678
  self.multiplet_height_factor_max = multiplet_height_factor_max
 
 
679
  self.position_shift_min = position_shift_min
680
  self.position_shift_max = position_shift_max
681
  self.gaussian_fraction_change_min = gaussian_fraction_change_min
682
  self.gaussian_fraction_change_max = gaussian_fraction_change_max
 
 
 
 
 
 
683
 
684
- def set_tff_range(self, tff_min, tff_max):
685
- self.tff_min = tff_min
686
- self.tff_max = tff_max
687
 
688
  def __call__(self, seed=None):
689
  if (not self.use_original_peak_position) and (self.tff_min is None or self.tff_max is None):
@@ -741,7 +792,12 @@ class MultipletDataFromMultipletsLibrary:
741
  self.multiplet_width_factor_max,
742
  generator=rng
743
  )
744
- peak_parameters["twf_lin"] = peak_parameters["twf_lin"] * spectrum_width_factor * multiplet_width_factor
 
 
 
 
 
745
 
746
  # height
747
  multiplet_height_factor = random_loguniform(
@@ -749,11 +805,18 @@ class MultipletDataFromMultipletsLibrary:
749
  self.multiplet_height_factor_max,
750
  generator=rng
751
  )
752
- peak_parameters["thf_lin"] = peak_parameters["thf_lin"] * spectrum_height_factor * multiplet_height_factor
 
 
 
 
 
753
 
754
  # gaussian contribution
755
  if self.gaussian_fraction_change_min is not None:
756
  gaussian_contribution_shift = random_value(self.gaussian_fraction_change_min, self.gaussian_fraction_change_max, generator=rng)
 
 
757
  peak_parameters["trf_lin"] = torch.clip(peak_parameters["trf_lin"] + gaussian_contribution_shift, 0., 1.)
758
 
759
  return peaks_parameters_data
@@ -965,4 +1028,114 @@ class Generator(BaseGenerator):
965
  out['theoretical_spectrum_data'] = [item['theoretical_spectrum_data'] for item in batch]
966
  if 'frq_frq' in batch[0]:
967
  out['frq_frq'] = [item['frq_frq'] for item in batch]
968
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from enum import Enum
2
  from copy import deepcopy
3
+ from typing import Optional
4
  # from pathlib import Path
5
  import numpy as np
6
  import pandas as pd
 
9
  from concurrent.futures import ThreadPoolExecutor, as_completed
10
  from abc import ABC, abstractmethod
11
 
12
def random_uniform(min_value, max_value, generator=None):
    """Draw a single uniformly distributed value between two bounds.

    One sample from ``torch.rand`` is scaled linearly onto the interval that
    starts at ``min_value`` and spans ``max_value - min_value``.

    Args:
        min_value: Lower bound of the interval.
        max_value: Upper bound of the interval.
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``.

    Returns:
        The sampled value as a plain Python number (via ``Tensor.item()``).
    """
    unit_sample = torch.rand(1, generator=generator)
    span = max_value - min_value
    return (min_value + unit_sample * span).item()


# Alias keeping the previous function name available to existing call sites.
random_value = random_uniform
15
 
16
def random_loguniform(min_value, max_value, generator=None):
    """Draw a single value whose logarithm is uniformly distributed.

    A ``torch.rand`` sample interpolates between ``log(min_value)`` and
    ``log(max_value)``; the result is exponentiated and scaled by
    ``min_value``.

    Args:
        min_value: Lower bound of the interval.
        max_value: Upper bound of the interval.
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``.

    Returns:
        The sampled value as a plain Python number (via ``Tensor.item()``).
    """
    unit_sample = torch.rand(1, generator=generator)
    log_upper = torch.log(torch.tensor(max_value))
    log_lower = torch.log(torch.tensor(min_value))
    scaled = min_value * torch.exp(unit_sample * (log_upper - log_lower))
    return scaled.item()
18
 
19
def random_uniform_vector(min_value, max_value, size, generator=None):
    """Draw a tensor of uniformly distributed values between two bounds.

    Args:
        min_value: Lower bound of the interval.
        max_value: Upper bound of the interval.
        size: Shape argument forwarded to ``torch.rand``.
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``.

    Returns:
        A tensor produced by ``torch.rand(size)`` rescaled onto the interval.
    """
    unit_samples = torch.rand(size, generator=generator)
    span = max_value - min_value
    return min_value + unit_samples * span
21
+
22
def random_loguniform_vector(min_value, max_value, size, generator=None):
    """Draw a tensor of values whose logarithms are uniformly distributed.

    Args:
        min_value: Lower bound of the interval.
        max_value: Upper bound of the interval.
        size: Shape argument forwarded to ``torch.rand``.
        generator: Optional ``torch.Generator`` forwarded to ``torch.rand``.

    Returns:
        A tensor of ``min_value * exp(u * (log(max_value) - log(min_value)))``
        with ``u`` sampled by ``torch.rand(size)``.
    """
    unit_samples = torch.rand(size, generator=generator)
    log_upper = torch.log(torch.tensor(max_value))
    log_lower = torch.log(torch.tensor(min_value))
    return min_value * torch.exp(unit_samples * (log_upper - log_lower))
24
+
25
  def spectrum_from_peaks_data(peaks_parameters: dict | list, frq_frq:torch.Tensor, relative_frequency=False):
26
 
27
  if isinstance(peaks_parameters, dict):
 
362
  atom_groups_data_file=None,
363
  number_of_signals_min=1,
364
  number_of_signals_max=8,
365
+ relative_frequency_min=-0.4,
366
+ relative_frequency_max=0.4,
367
  spectrum_width_min=0.2,
368
  spectrum_width_max=1,
369
  relative_width_min=1,
 
390
  self.tff_max = tff_max
391
  self.number_of_signals_min = number_of_signals_min
392
  self.number_of_signals_max = number_of_signals_max
393
+ self.relative_frequency_min = relative_frequency_min
394
+ self.relative_frequency_max = relative_frequency_max
395
+
396
  self.spectrum_width_min = spectrum_width_min
397
  self.spectrum_width_max = spectrum_width_max
398
  self.relative_width_min = relative_width_min
 
410
 
411
  self.rng_getter = RngGetter(seed=seed)
412
 
413
def set_frq_range(self, frq_min, frq_max):
    """Derive the absolute peak-position range from the frequency axis limits.

    ``relative_frequency_min`` / ``relative_frequency_max`` are interpreted as
    offsets from the centre of the axis, expressed as a fraction of the axis
    width (so -0.5 / +0.5 correspond to the two ends of the axis).

    Multiplying each axis end by its relative bound, as was done before,
    yields a positive lower limit for an axis centred on zero (negative times
    negative), which collapses or inverts the range.

    Args:
        frq_min: First value of the frequency axis (number or 0-dim tensor).
        frq_max: Last value of the frequency axis (number or 0-dim tensor).
    """
    # float() accepts Python numbers and 0-dim tensors alike, so the stored
    # limits are plain floats regardless of what the caller passes in.
    lowest, highest = float(frq_min), float(frq_max)
    centre = 0.5 * (lowest + highest)
    width = highest - lowest
    self.tff_min = centre + self.relative_frequency_min * width
    self.tff_max = centre + self.relative_frequency_max * width
416
 
417
  def __call__(self, seed=None):
418
  """
 
496
  frq_step=11160.7142857 / 32768,
497
  relative_frequency_min=-0.4,
498
  relative_frequency_max=0.4,
499
+ frequency_min=None, #if None, the 0 will be in the center of spectrum
500
+ frequency_max=None,
501
  include_tff_relative=False,
502
  seed=42
503
  ):
 
508
  self.relative_frequency_min = relative_frequency_min
509
  self.relative_frequency_max = relative_frequency_max
510
  self.include_tff_relative = include_tff_relative
511
+ # Frequency axis
512
+ self.frq_frq, frq_min, frq_max = self._frequency_axis_from_parameters(frq_step, pixels, frequency_min, frequency_max)
513
 
514
  self.peaks_parameter_generator = peaks_parameter_generator
515
+ self.peaks_parameter_generator.set_frq_range(frq_min, frq_max)
 
 
 
516
 
517
  # self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
518
+
519
+ def _frequency_axis_from_parameters(self, frq_step, pixels, frequency_min, frequency_max):
520
+ """frq_step is never None, pixels, frequency_min or frequency_max can be None
521
+ """
522
+ # Option 1: from pixels and frq_step
523
+ if pixels is not None:
524
+ assert (frequency_min is None) or (frequency_max is None)
525
+ if (frequency_min is None) and (frequency_max is None): # if both are None, center at 0
526
+ frequency_min = -(pixels // 2) * frq_step
527
+ elif frequency_min is None: # frequency_max is not None, use it to calculate frequency_min
528
+ frequency_min = frequency_max - pixels * frq_step
529
+ frq_frq = torch.arange(0, pixels) * frq_step + frequency_min
530
+ # Option 2: from frequency_min and frequency_max
531
+ elif (frequency_min is not None) and (frequency_max is not None):
532
+ pixels = round((frequency_max - frequency_min) / frq_step)
533
+ frq_frq = torch.arange(0, pixels) * frq_step + frequency_min
534
+ else:
535
+ raise ValueError("Insufficient parameters to determine frequency axis.")
536
+
537
+ return frq_frq, frq_frq[0], frq_frq[-1]
538
+
539
 
540
  def __call__(self, seed=None):
541
  """
 
549
  """
550
  # Generate peak parameters (peaks_parameter_generator has its own RngGetter)
551
  peaks_parameters_data = self.peaks_parameter_generator(seed=seed)
 
552
 
553
  # Add tff_relative if requested
554
  if self.include_tff_relative:
 
563
 
564
  class PeaksParametersNames(Enum):
565
  """Enum for standardized peak parameter names."""
566
+ position_hz ="tff_lin"
567
+ height = "thf_lin"
568
+ width_hz = "twf_lin"
569
+ gaussian_fraction = "trf_lin"
570
 
571
  @classmethod
572
  def keys(cls):
 
577
  return [member.name for member in cls]
578
 
579
  class PeaksParametersParser:
580
+ """class to convert peaks parameters from `{"width_hz": [...], "height": ..., ...}` format to `{"twf_lin": torch.tensor([...]), "thf_lin": ..., ...}` format."""
581
  def __init__(self,
582
  alias_position_hz = None,
583
  alias_height = None,
 
599
 
600
def transform_single_peak(self, peak: dict) -> dict:
    """Convert one peak description into the internal tensor-valued format.

    Each parameter is looked up in ``peak`` under its configured alias, with
    the configured default used when the alias is absent. The result is keyed
    by the values of ``PeaksParametersNames`` and every entry is a float32
    tensor with at least one dimension.

    Args:
        peak: Mapping holding the raw peak parameters.

    Returns:
        Dict with one tensor per standard peak parameter.

    Raises:
        ValueError: If a parameter resolves to None.
    """
    names = PeaksParametersNames
    lookups = (
        (names.position_hz.value, self.alias_position_hz, self.default_position_hz),
        (names.height.value, self.alias_height, self.default_height),
        (names.width_hz.value, self.alias_width_hz, self.default_width_hz),
        (names.gaussian_fraction.value, self.alias_gaussian_fraction, self.default_gaussian_fraction),
    )

    parsed_peak = {}
    for key, alias, fallback in lookups:
        raw = peak.get(alias, fallback)
        if raw is None:
            raise ValueError(f"Peak parameter '{key}' is None.")
        # Existing tensors are cast; anything else is wrapped in a new tensor.
        if isinstance(raw, torch.Tensor):
            as_tensor = raw.float()
        else:
            as_tensor = torch.tensor(raw, dtype=torch.float32)
        parsed_peak[key] = torch.atleast_1d(as_tensor)
    return parsed_peak
613
 
614
  def transform(self, spectrum_peaks: list[dict]) -> list[dict]:
 
677
  use_original_peak_position=True,
678
  number_of_signals_min=None,
679
  number_of_signals_max=None,
680
+ relative_frequency_min=None,
681
+ relative_frequency_max=None,
682
  spectrum_width_factor_min=1,
683
  spectrum_width_factor_max=1,
684
  multiplet_width_factor_min=1,
685
  multiplet_width_factor_max=1,
686
+ multiplet_width_additive_min=0,
687
+ multiplet_width_additive_max=0,
688
  spectrum_height_factor_min=1,
689
  spectrum_height_factor_max=1,
690
  multiplet_height_factor_min=1,
691
  multiplet_height_factor_max=1,
692
+ multiplet_height_additive_min=0,
693
+ multiplet_height_additive_max=0,
694
  position_shift_min=0,
695
  position_shift_max=0,
696
  gaussian_fraction_change_min=None,
697
  gaussian_fraction_change_max=None,
698
+ gaussian_fraction_change_additive_min=0.,
699
+ gaussian_fraction_change_additive_max=0.,
700
  seed=42
701
  ):
702
 
 
707
  self.rng_getter = RngGetter(seed=seed)
708
  self.tff_min = tff_min
709
  self.tff_max = tff_max
710
+ self.relative_frequency_min = relative_frequency_min
711
+ self.relative_frequency_max = relative_frequency_max
712
  self.use_original_peak_position = use_original_peak_position
713
  self.number_of_signals_min = number_of_signals_min
714
  self.number_of_signals_max = number_of_signals_max
 
716
  self.spectrum_width_factor_max = spectrum_width_factor_max
717
  self.multiplet_width_factor_min = multiplet_width_factor_min
718
  self.multiplet_width_factor_max = multiplet_width_factor_max
719
+ self.multiplet_width_additive_min = multiplet_width_additive_min
720
+ self.multiplet_width_additive_max = multiplet_width_additive_max
721
  self.spectrum_height_factor_min = spectrum_height_factor_min
722
  self.spectrum_height_factor_max = spectrum_height_factor_max
723
  self.multiplet_height_factor_min = multiplet_height_factor_min
724
  self.multiplet_height_factor_max = multiplet_height_factor_max
725
+ self.multiplet_height_additive_min = multiplet_height_additive_min
726
+ self.multiplet_height_additive_max = multiplet_height_additive_max
727
  self.position_shift_min = position_shift_min
728
  self.position_shift_max = position_shift_max
729
  self.gaussian_fraction_change_min = gaussian_fraction_change_min
730
  self.gaussian_fraction_change_max = gaussian_fraction_change_max
731
+ self.gaussian_fraction_change_additive_min = gaussian_fraction_change_additive_min
732
+ self.gaussian_fraction_change_additive_max = gaussian_fraction_change_additive_max
733
+
734
def set_frq_range(self, frq_min, frq_max):
    """Derive the absolute peak-position range from the frequency axis limits.

    ``relative_frequency_min`` / ``relative_frequency_max`` are interpreted as
    offsets from the centre of the axis, expressed as a fraction of the axis
    width. A bound left as None (the constructor default) keeps the
    corresponding ``tff_min`` / ``tff_max`` value untouched instead of
    failing on ``number * None``.

    Multiplying each axis end by its relative bound, as was done before,
    yields a positive lower limit for an axis centred on zero (negative times
    negative), which collapses or inverts the range.

    Args:
        frq_min: First value of the frequency axis (number or 0-dim tensor).
        frq_max: Last value of the frequency axis (number or 0-dim tensor).
    """
    # float() accepts Python numbers and 0-dim tensors alike.
    lowest, highest = float(frq_min), float(frq_max)
    centre = 0.5 * (lowest + highest)
    width = highest - lowest
    if self.relative_frequency_min is not None:
        self.tff_min = centre + self.relative_frequency_min * width
    if self.relative_frequency_max is not None:
        self.tff_max = centre + self.relative_frequency_max * width
737
 
 
 
 
738
 
739
  def __call__(self, seed=None):
740
  if (not self.use_original_peak_position) and (self.tff_min is None or self.tff_max is None):
 
792
  self.multiplet_width_factor_max,
793
  generator=rng
794
  )
795
+ multiplet_width_additive = random_uniform(
796
+ self.multiplet_width_additive_min,
797
+ self.multiplet_width_additive_max,
798
+ generator=rng
799
+ )
800
+ peak_parameters["twf_lin"] = peak_parameters["twf_lin"] * spectrum_width_factor * multiplet_width_factor + multiplet_width_additive
801
 
802
  # height
803
  multiplet_height_factor = random_loguniform(
 
805
  self.multiplet_height_factor_max,
806
  generator=rng
807
  )
808
+ multiplet_height_additive = random_uniform(
809
+ self.multiplet_height_additive_min,
810
+ self.multiplet_height_additive_max,
811
+ generator=rng
812
+ )
813
+ peak_parameters["thf_lin"] = peak_parameters["thf_lin"] * spectrum_height_factor * multiplet_height_factor + multiplet_height_additive
814
 
815
  # gaussian contribution
816
  if self.gaussian_fraction_change_min is not None:
817
  gaussian_contribution_shift = random_value(self.gaussian_fraction_change_min, self.gaussian_fraction_change_max, generator=rng)
818
+ gaussian_contribution_additive = random_value(self.gaussian_fraction_change_additive_min, self.gaussian_fraction_change_additive_max, generator=rng)
819
+ gaussian_contribution_shift += gaussian_contribution_additive
820
  peak_parameters["trf_lin"] = torch.clip(peak_parameters["trf_lin"] + gaussian_contribution_shift, 0., 1.)
821
 
822
  return peaks_parameters_data
 
1028
  out['theoretical_spectrum_data'] = [item['theoretical_spectrum_data'] for item in batch]
1029
  if 'frq_frq' in batch[0]:
1030
  out['frq_frq'] = [item['frq_frq'] for item in batch]
1031
+ return out
1032
+
1033
class PeaksParametersFromSinglets:
    """Generate peak parameters for one spectrum from tables of singlet peaks.

    The CSV files are concatenated into one table. Every call draws a random
    number of rows and builds, for each of the four peak parameters
    (position, width, height, gaussian fraction), either a perturbed copy of
    the corresponding CSV column (``use_original_* = True``) or freshly
    sampled values (``use_original_* = False``). The columns ``position_hz``,
    ``width_hz``, ``height`` and ``gaussian_fraction`` are read only when the
    matching ``use_original_*`` flag is set.

    All random draws of a call use the generator returned by the instance's
    ``RngGetter``, so the row selection and every sampled vector follow the
    configured seed.
    """

    def __init__(self,
                 singlets_files: list[str],
                 number_of_signals_min: int = 5,
                 number_of_signals_max: int = 20,
                 use_original_position: bool = True,
                 position_hz_min: Optional[float] = None,
                 position_hz_max: Optional[float] = None,
                 position_hz_change_min: float = 0.0,
                 position_hz_change_max: float = 0.0,
                 relative_frequency_min: float = -0.4,  # used only if position_hz_min/max are None
                 relative_frequency_max: float = 0.4,
                 use_original_width: bool = True,
                 width_hz_min: float = 0.2,
                 width_hz_max: float = 2.0,
                 width_factor_min: float = 1.0,
                 width_factor_max: float = 1.0,
                 width_hz_change_min: float = 0.0,
                 width_hz_change_max: float = 0.0,
                 use_original_height: bool = True,
                 height_min: float = 0.1,
                 height_max: float = 10.0,
                 height_factor_min: float = 1.0,
                 height_factor_max: float = 1.0,
                 height_change_min: float = 0.0,
                 height_change_max: float = 0.0,
                 use_original_gaussian_fraction: bool = True,
                 gaussian_fraction_min: float = 0.0,
                 gaussian_fraction_max: float = 1.0,
                 gaussian_fraction_change_min: float = 0.0,
                 gaussian_fraction_change_max: float = 0.0,
                 seed=42
                 ):
        """Load the singlet tables and store the sampling configuration.

        Args:
            singlets_files: Paths of CSV files readable by ``pd.read_csv``.
            number_of_signals_min: Smallest number of peaks per call.
            number_of_signals_max: Largest number of peaks per call
                (inclusive, capped at the number of available rows).
            use_original_position: Use the ``position_hz`` column plus a
                uniform shift instead of sampling positions.
            position_hz_min, position_hz_max: Range for sampled positions.
            position_hz_change_min, position_hz_change_max: Range of the
                additive shift applied to original positions.
            relative_frequency_min, relative_frequency_max: Fractions of the
                frequency axis used by ``set_frq_range`` to fill a missing
                ``position_hz_min`` / ``position_hz_max``.
            use_original_width: Use the ``width_hz`` column (scaled and
                shifted) instead of sampling widths.
            width_hz_min, width_hz_max: Range for log-uniformly sampled widths.
            width_factor_min, width_factor_max: Range of the multiplicative
                factor applied to original widths.
            width_hz_change_min, width_hz_change_max: Range of the additive
                change applied to original widths.
            use_original_height: Use the ``height`` column (scaled and
                shifted) instead of sampling heights.
            height_min, height_max: Range for log-uniformly sampled heights.
            height_factor_min, height_factor_max: Range of the multiplicative
                factor applied to original heights.
            height_change_min, height_change_max: Range of the additive
                change applied to original heights.
            use_original_gaussian_fraction: Use the ``gaussian_fraction``
                column plus a shift (clamped to [0, 1]) instead of sampling.
            gaussian_fraction_min, gaussian_fraction_max: Range for sampled
                gaussian fractions.
            gaussian_fraction_change_min, gaussian_fraction_change_max: Range
                of the additive change applied to original fractions.
            seed: Seed handed to ``RngGetter``.
        """
        self.peaks_rows = pd.concat([pd.read_csv(f) for f in singlets_files], ignore_index=True)

        # number of signals
        self.number_of_signals_min = number_of_signals_min
        self.number_of_signals_max = number_of_signals_max
        # position
        self.use_original_position = use_original_position
        self.position_hz_min = position_hz_min
        self.position_hz_max = position_hz_max
        self.position_hz_change_min = position_hz_change_min
        self.position_hz_change_max = position_hz_change_max
        self.relative_frequency_min = relative_frequency_min
        self.relative_frequency_max = relative_frequency_max
        # width
        self.use_original_width = use_original_width
        self.width_hz_min = width_hz_min
        self.width_hz_max = width_hz_max
        self.width_factor_min = width_factor_min
        self.width_factor_max = width_factor_max
        self.width_hz_change_min = width_hz_change_min
        self.width_hz_change_max = width_hz_change_max
        # height
        self.use_original_height = use_original_height
        self.height_min = height_min
        self.height_max = height_max
        self.height_factor_min = height_factor_min
        self.height_factor_max = height_factor_max
        self.height_change_min = height_change_min
        self.height_change_max = height_change_max
        # gaussian fraction
        self.use_original_gaussian_fraction = use_original_gaussian_fraction
        self.gaussian_fraction_min = gaussian_fraction_min
        self.gaussian_fraction_max = gaussian_fraction_max
        self.gaussian_fraction_change_min = gaussian_fraction_change_min
        self.gaussian_fraction_change_max = gaussian_fraction_change_max

        self.rng_getter = RngGetter(seed=seed)

    def set_frq_range(self, frq_min, frq_max):
        """Fill missing position limits from the frequency axis limits.

        ``relative_frequency_min`` / ``relative_frequency_max`` are taken as
        offsets from the centre of the axis in units of the axis width.
        Explicitly configured ``position_hz_min`` / ``position_hz_max`` are
        left untouched, matching the constructor comment that the relative
        bounds are used only when those are None.

        Args:
            frq_min: First value of the frequency axis (number or 0-dim tensor).
            frq_max: Last value of the frequency axis (number or 0-dim tensor).
        """
        # float() accepts Python numbers and 0-dim tensors alike.
        lowest, highest = float(frq_min), float(frq_max)
        centre = 0.5 * (lowest + highest)
        width = highest - lowest
        if self.position_hz_min is None:
            self.position_hz_min = centre + self.relative_frequency_min * width
        if self.position_hz_max is None:
            self.position_hz_max = centre + self.relative_frequency_max * width

    def __call__(self, seed=None) -> list[dict]:
        """Sample the peak parameters of one spectrum.

        Args:
            seed: Optional seed forwarded to ``RngGetter.get_rng``.

        Returns:
            A one-element list holding a dict keyed by the values of
            ``PeaksParametersNames``; each entry is a 1-D float tensor with
            one value per selected peak.

        Raises:
            ValueError: If fewer rows are available than
                ``number_of_signals_min``, or if positions must be sampled
                but no position range is known.
        """
        rng = self.rng_getter.get_rng(seed=seed)

        available = len(self.peaks_rows)
        most = min(self.number_of_signals_max, available)
        if self.number_of_signals_min > most:
            raise ValueError(
                f"Cannot draw at least {self.number_of_signals_min} signals: "
                f"number_of_signals_max={self.number_of_signals_max}, available rows={available}."
            )
        # torch.randint excludes `high`, hence +1 to make the maximum reachable.
        count = torch.randint(
            low=self.number_of_signals_min,
            high=most + 1,
            size=[],
            generator=rng
        ).item()
        # Rows are drawn without replacement from the same generator as all
        # other random values, so one seed controls the whole sample.
        row_ids = torch.randperm(available, generator=rng)[:count].tolist()
        selected_peaks = self.peaks_rows.iloc[row_ids]

        def column(name):
            return torch.tensor(selected_peaks[name].values, dtype=torch.float32)

        def uniform(low, high):
            return random_uniform_vector(low, high, size=count, generator=rng)

        def loguniform(low, high):
            return random_loguniform_vector(low, high, size=count, generator=rng)

        names = PeaksParametersNames
        multiplet_data = {}

        # position
        if self.use_original_position:
            position = column("position_hz") + uniform(self.position_hz_change_min, self.position_hz_change_max)
        else:
            if self.position_hz_min is None or self.position_hz_max is None:
                raise ValueError(
                    "position_hz_min and position_hz_max must be set (directly or via "
                    "set_frq_range) when use_original_position is False."
                )
            position = uniform(self.position_hz_min, self.position_hz_max)
        multiplet_data[names.position_hz.value] = position

        # width
        if self.use_original_width:
            width = (
                column("width_hz") * uniform(self.width_factor_min, self.width_factor_max)
                + uniform(self.width_hz_change_min, self.width_hz_change_max)
            )
        else:
            width = loguniform(self.width_hz_min, self.width_hz_max)
        multiplet_data[names.width_hz.value] = width

        # height
        if self.use_original_height:
            height = (
                column("height") * uniform(self.height_factor_min, self.height_factor_max)
                + uniform(self.height_change_min, self.height_change_max)
            )
        else:
            height = loguniform(self.height_min, self.height_max)
        multiplet_data[names.height.value] = height

        # gaussian fraction
        if self.use_original_gaussian_fraction:
            shifted = column("gaussian_fraction") + uniform(
                self.gaussian_fraction_change_min, self.gaussian_fraction_change_max
            )
            gaussian_fraction = torch.clamp(shifted, 0.0, 1.0)
        else:
            gaussian_fraction = uniform(self.gaussian_fraction_min, self.gaussian_fraction_max)
        multiplet_data[names.gaussian_fraction.value] = gaussian_fraction

        return [multiplet_data]