Marek Bukowicki commited on
Commit
c1d3733
·
1 Parent(s): 3de469d

working modular pipeline to work with peak lists

Browse files
configs/data_generator_from_peak_list.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ _target_: shimnet.generators.Generator
3
+ input_normalization_height: 16.0
4
+ clean_spectra_generator:
5
+ _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
6
+ peaks_parameter_generator:
7
+ _target_: shimnet.generators.MultipletDataFromMultipletsLibrary
8
+ number_of_signals_min: 2 # null to use all signals (both min and max must be null)
9
+ number_of_signals_max: 10 # null to use all signals (both min and max must be null)
10
+ multiplet_height_factor_min: 0.5
11
+ multiplet_height_factor_max: 2
12
+ multiplet_width_factor_min: 0.5
13
+ multiplet_width_factor_max: 2
14
+ multiplets_library:
15
+ _target_: shimnet.generators.MultipletsLibrary
16
+ csv_files_paths:
17
+ - data/multiplets_lists/azydekbenzylu_sub1_mono-click.csv
18
+ - data/multiplets_lists/fenyloacetylen_sub2_mono-click.csv
19
+ - data/multiplets_lists/mieszanina_poReakcji_mono-click.csv
20
+ peak_data_parser:
21
+ _target_: shimnet.generators.PeaksParametersParser
22
+ use_original_peak_position: false
23
+ seed: null
24
+ pixels: 2048
25
+ frq_step: ${metadata.frq_step}
26
+ include_tff_relative: true
27
+ relative_frequency_min: -0.4
28
+ relative_frequency_max: 0.4
29
+ response_generator:
30
+ _target_: shimnet.generators.ResponseGenerator
31
+ response_function_library:
32
+ _target_: shimnet.generators.ResponseLibrary
33
+ response_files:
34
+ - data/scrf_81_600MHz.pt
35
+ noise_generator:
36
+ _target_: shimnet.generators.NoiseGenerator
37
+ spectrum_noise_min: 0.0
38
+ spectrum_noise_max: 0.1
39
+ include_spectrum_data: false
40
+ include_peak_mask: false
41
+ include_response_function: true
42
+ seed: 44 # null means random seed
43
+ batch_size: 64 # to be set in training script
shimnet/generators.py CHANGED
@@ -1,15 +1,13 @@
 
 
 
1
  import numpy as np
 
2
  import torch
3
  import torchdata
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from abc import ABC, abstractmethod
6
 
7
- # from itertools import islice
8
-
9
-
10
-
11
-
12
-
13
  def random_value(min_value, max_value, generator=None):
14
  return (min_value + torch.rand(1, generator=generator) * (max_value - min_value)).item()
15
 
@@ -529,6 +527,238 @@ class TheoreticalMultipletSpectraGenerator:
529
 
530
  return spectrum, {"spectrum_data": peaks_parameters_data, "frq_frq": self.frq_frq}
531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  class ResponseGenerator:
533
  def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
534
  response_function_noise=0.0, flip_response_function=False, seed=42):
@@ -674,7 +904,7 @@ class BaseGeneratorMultithread(ABC):
674
 
675
  class Generator(BaseGenerator):
676
  def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64,
677
- include_spectrum_data=False, include_peak_mask=False, include_response_function=False, seed=None):
678
  super().__init__(batch_size=batch_size, seed=seed)
679
  self.clean_spectra_generator = clean_spectra_generator
680
  self.response_generator = response_generator
@@ -682,6 +912,7 @@ class Generator(BaseGenerator):
682
  self.include_spectrum_data = include_spectrum_data
683
  self.include_peak_mask = include_peak_mask
684
  self.include_response_function = include_response_function
 
685
 
686
  def _generate_element(self, seed):
687
  # Generate different seeds for each generator from the provided seed
@@ -700,7 +931,15 @@ class Generator(BaseGenerator):
700
  response_function = self.response_generator(seed=response_seed)
701
  padding_size = (response_function.shape[-1] - 1)//2
702
  disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size)
 
 
 
 
 
 
 
703
  noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed)
 
704
  out = {
705
  'theoretical_spectrum': clean_spectrum,
706
  'disturbed_spectrum': disturbed_spectrum,
 
1
+ from enum import Enum
2
+ from copy import deepcopy
3
+ # from pathlib import Path
4
  import numpy as np
5
+ import pandas as pd
6
  import torch
7
  import torchdata
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from abc import ABC, abstractmethod
10
 
 
 
 
 
 
 
11
  def random_value(min_value, max_value, generator=None):
12
  return (min_value + torch.rand(1, generator=generator) * (max_value - min_value)).item()
13
 
 
527
 
528
  return spectrum, {"spectrum_data": peaks_parameters_data, "frq_frq": self.frq_frq}
529
 
530
+
531
+ class PeaksParametersNames(Enum):
532
+ """Enum for standardized peak parameter names."""
533
+ tff_lin = "position_hz"
534
+ thf_lin = "height"
535
+ twf_lin = "width_hz"
536
+ trf_lin = "gaussian_fraction"
537
+
538
+ @classmethod
539
+ def keys(cls):
540
+ return [member.value for member in cls]
541
+
542
+ @classmethod
543
+ def values(cls):
544
+ return [member.name for member in cls]
545
+
546
+ class PeaksParametersParser:
547
+ def __init__(self,
548
+ alias_position_hz = None,
549
+ alias_height = None,
550
+ alias_width_hz = None,
551
+ alias_gaussian_fraction = None,
552
+ default_position_hz = None,
553
+ default_height = None,
554
+ default_width_hz = None,
555
+ default_gaussian_fraction = 0.,
556
+ ):
557
+ self.alias_position_hz = alias_position_hz if alias_position_hz is not None else "position_hz"
558
+ self.alias_height = alias_height if alias_height is not None else "height"
559
+ self.alias_width_hz = alias_width_hz if alias_width_hz is not None else "width_hz"
560
+ self.alias_gaussian_fraction = alias_gaussian_fraction if alias_gaussian_fraction is not None else "gaussian_fraction"
561
+ self.default_position_hz = default_position_hz
562
+ self.default_height = default_height
563
+ self.default_width_hz = default_width_hz
564
+ self.default_gaussian_fraction = default_gaussian_fraction
565
+
566
+ def transform_single_peak(self, peak: dict) -> dict:
567
+ parsed_peak = {
568
+ PeaksParametersNames("position_hz").name: peak.get(self.alias_position_hz, self.default_position_hz),
569
+ PeaksParametersNames("height").name: peak.get(self.alias_height, self.default_height),
570
+ PeaksParametersNames("width_hz").name: peak.get(self.alias_width_hz, self.default_width_hz),
571
+ PeaksParametersNames("gaussian_fraction").name: peak.get(self.alias_gaussian_fraction, self.default_gaussian_fraction),
572
+ }
573
+ # Validate and convert other peak parameters
574
+ for k, v in parsed_peak.items():
575
+ if v is None:
576
+ raise ValueError(f"Peak parameter '{k}' is None.")
577
+ parsed_peak[k] = torch.atleast_1d(torch.tensor(v, dtype=torch.float32))
578
+ return parsed_peak
579
+
580
+ def transform(self, spectrum_peaks: list[dict]) -> list[dict]:
581
+ parsed_peaks = []
582
+ for peak in spectrum_peaks:
583
+ parsed_peaks.append(self.transform_single_peak(peak))
584
+ return parsed_peaks
585
+
586
+ def csv_file_to_multiplets_dict(file_path: str) -> list[dict]:
587
+ peaks_data = pd.read_csv(file_path)
588
+ multiplets = {k: v.drop(columns="multiplet_name").to_dict(orient='list') for k, v in peaks_data.groupby("multiplet_name")}
589
+ return multiplets
590
+
591
+ def combine_multiplets(multiplets_list: list[dict]) -> dict:
592
+ composed_multiplets = {}
593
+ for multiplets in multiplets_list:
594
+ for k, v in multiplets.items():
595
+ if not k in composed_multiplets:
596
+ composed_multiplets[k] = v
597
+ else:
598
+ composed_multiplets[k].extend(v)
599
+ return composed_multiplets
600
+
601
+ class MultipletsLibrary:
602
+ def __init__(self, csv_files_paths: list[str], peak_data_parser: PeaksParametersParser = None, return_name=False):
603
+ self.csv_files_paths = csv_files_paths
604
+ self.multiplets_data = {}
605
+ self.peak_data_parser = peak_data_parser
606
+ for file_path in csv_files_paths:
607
+ self.multiplets_data.update(self._get_multiplet_data_from_file(file_path))
608
+
609
+ self.names = sorted(self.multiplets_data.keys())
610
+ self.return_name = return_name
611
+
612
+ def _get_multiplet_data_from_file(self, file_path: str) -> dict:
613
+ multiplets = csv_file_to_multiplets_dict(file_path) # dict[dict]
614
+ multiplets_out = {}
615
+ for k, v in multiplets.items():
616
+ multiplets_out[f"{file_path}/{k}"] = self.peak_data_parser.transform([v])[0] if self.peak_data_parser else v
617
+ return multiplets_out
618
+
619
+ def get_by_name(self, name: str) -> dict:
620
+ return self.multiplets_data.get(name, None)
621
+
622
+ def __getitem__(self, idx: int) -> dict:
623
+ name = self.names[idx]
624
+ multiplet_data = deepcopy(self.multiplets_data[name])
625
+ if self.return_name:
626
+ return name, multiplet_data
627
+ return multiplet_data
628
+
629
+ def __len__(self):
630
+ return len(self.multiplets_data)
631
+
632
+ class SectraLibrary(MultipletsLibrary):
633
+ def _get_multiplet_data_from_file(self, file_path: str) -> dict:
634
+ multiplets = csv_file_to_multiplets_dict(file_path) # dict[dict]
635
+ combined_multiplet = combine_multiplets(multiplets.values()) # dict
636
+ return {f"{file_path}": self.peak_data_parser.transform([combined_multiplet])[0]}
637
+
638
+ class MultipletDataFromMultipletsLibrary:
639
+ def __init__(self,
640
+ multiplets_library,
641
+ tff_min=None, #may be assigned after initialization if the original peak positions are not used
642
+ tff_max=None, #may be assigned after initialization if the original peak positions are not used
643
+ use_original_peak_position=True,
644
+ number_of_signals_min=None,
645
+ number_of_signals_max=None,
646
+ spectrum_width_factor_min=1,
647
+ spectrum_width_factor_max=1,
648
+ multiplet_width_factor_min=1,
649
+ multiplet_width_factor_max=1,
650
+ spectrum_height_factor_min=1,
651
+ spectrum_height_factor_max=1,
652
+ multiplet_height_factor_min=1,
653
+ multiplet_height_factor_max=1,
654
+ position_shift_min=0,
655
+ position_shift_max=0,
656
+ gaussian_fraction_change_min=None,
657
+ gaussian_fraction_change_max=None,
658
+ seed=42
659
+ ):
660
+
661
+ if (number_of_signals_min is None) != (number_of_signals_max is None):
662
+ raise ValueError("Both number_of_signals_min and number_of_signals_max should be provided or both should be None.")
663
+
664
+ self.multiplets_library = multiplets_library
665
+ self.rng_getter = RngGetter(seed=seed)
666
+ self.tff_min = tff_min
667
+ self.tff_max = tff_max
668
+ self.use_original_peak_position = use_original_peak_position
669
+ self.number_of_signals_min = number_of_signals_min
670
+ self.number_of_signals_max = number_of_signals_max
671
+ self.spectrum_width_factor_min = spectrum_width_factor_min
672
+ self.spectrum_width_factor_max = spectrum_width_factor_max
673
+ self.multiplet_width_factor_min = multiplet_width_factor_min
674
+ self.multiplet_width_factor_max = multiplet_width_factor_max
675
+ self.spectrum_height_factor_min = spectrum_height_factor_min
676
+ self.spectrum_height_factor_max = spectrum_height_factor_max
677
+ self.multiplet_height_factor_min = multiplet_height_factor_min
678
+ self.multiplet_height_factor_max = multiplet_height_factor_max
679
+ self.position_shift_min = position_shift_min
680
+ self.position_shift_max = position_shift_max
681
+ self.gaussian_fraction_change_min = gaussian_fraction_change_min
682
+ self.gaussian_fraction_change_max = gaussian_fraction_change_max
683
+
684
+ def set_tff_range(self, tff_min, tff_max):
685
+ self.tff_min = tff_min
686
+ self.tff_max = tff_max
687
+
688
+ def __call__(self, seed=None):
689
+ if (not self.use_original_peak_position) and (self.tff_min is None or self.tff_max is None):
690
+ raise ValueError("for use_original_peak_position=False, tff_min and tff_max must be set before calling the generator.")
691
+
692
+ rng = self.rng_getter.get_rng(seed=seed)
693
+
694
+ # select number of signals and their indices
695
+ if self.number_of_signals_min is None:
696
+ number_of_signals = len(self.multiplets_library)
697
+ multiplets_indices = list(range(len(self.multiplets_library)))
698
+ else:
699
+ number_of_signals = torch.randint(
700
+ self.number_of_signals_min,
701
+ self.number_of_signals_max + 1,
702
+ [],
703
+ generator=rng
704
+ )
705
+
706
+ multiplets_indices = torch.randint(
707
+ 0,
708
+ len(self.multiplets_library),
709
+ [number_of_signals],
710
+ generator=rng
711
+ )
712
+
713
+ # spectrum width and height factors
714
+ spectrum_width_factor = random_loguniform(
715
+ self.spectrum_width_factor_min,
716
+ self.spectrum_width_factor_max,
717
+ generator=rng
718
+ )
719
+
720
+ spectrum_height_factor = random_loguniform(
721
+ self.spectrum_height_factor_min,
722
+ self.spectrum_height_factor_max,
723
+ generator=rng
724
+ )
725
+
726
+ # get and modify peaks parameters data
727
+ peaks_parameters_data = [self.multiplets_library[idx] for idx in multiplets_indices]
728
+ for peak_parameters in peaks_parameters_data:
729
+
730
+ # position
731
+ if not self.use_original_peak_position:
732
+ new_position_center = random_value(self.tff_min, self.tff_max, generator=rng)
733
+ peak_parameters["tff_lin"] += new_position_center - torch.mean(peak_parameters["tff_lin"])
734
+ else:
735
+ position_shift = random_value(self.position_shift_min, self.position_shift_max, generator=rng)
736
+ peak_parameters["tff_lin"] += position_shift
737
+
738
+ # width
739
+ multiplet_width_factor = random_loguniform(
740
+ self.multiplet_width_factor_min,
741
+ self.multiplet_width_factor_max,
742
+ generator=rng
743
+ )
744
+ peak_parameters["twf_lin"] = peak_parameters["twf_lin"] * spectrum_width_factor * multiplet_width_factor
745
+
746
+ # height
747
+ multiplet_height_factor = random_loguniform(
748
+ self.multiplet_height_factor_min,
749
+ self.multiplet_height_factor_max,
750
+ generator=rng
751
+ )
752
+ peak_parameters["thf_lin"] = peak_parameters["thf_lin"] * spectrum_height_factor * multiplet_height_factor
753
+
754
+ # gaussian contribution
755
+ if self.gaussian_fraction_change_min is not None:
756
+ gaussian_contribution_shift = random_value(self.gaussian_fraction_change_min, self.gaussian_fraction_change_max, generator=rng)
757
+ peak_parameters["trf_lin"] = torch.clip(peak_parameters["trf_lin"] + gaussian_contribution_shift, 0., 1.)
758
+
759
+ return peaks_parameters_data
760
+
761
+
762
  class ResponseGenerator:
763
  def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
764
  response_function_noise=0.0, flip_response_function=False, seed=42):
 
904
 
905
  class Generator(BaseGenerator):
906
  def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64,
907
+ include_spectrum_data=False, include_peak_mask=False, include_response_function=False, input_normalization_height=None, seed=None):
908
  super().__init__(batch_size=batch_size, seed=seed)
909
  self.clean_spectra_generator = clean_spectra_generator
910
  self.response_generator = response_generator
 
912
  self.include_spectrum_data = include_spectrum_data
913
  self.include_peak_mask = include_peak_mask
914
  self.include_response_function = include_response_function
915
+ self.input_normalization_height = input_normalization_height
916
 
917
  def _generate_element(self, seed):
918
  # Generate different seeds for each generator from the provided seed
 
931
  response_function = self.response_generator(seed=response_seed)
932
  padding_size = (response_function.shape[-1] - 1)//2
933
  disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size)
934
+
935
+ if self.input_normalization_height is not None:
936
+ max_val = torch.max(disturbed_spectrum)
937
+ clean_spectrum = clean_spectrum / max_val * self.input_normalization_height
938
+ disturbed_spectrum = disturbed_spectrum / max_val * self.input_normalization_height
939
+
940
+ # noise after normalization to better control noise level
941
  noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed)
942
+
943
  out = {
944
  'theoretical_spectrum': clean_spectrum,
945
  'disturbed_spectrum': disturbed_spectrum,