Marek Bukowicki commited on
Commit
02f3d93
·
1 Parent(s): a58e9bb

fix data generation from peaks data

Browse files
Files changed (1) hide show
  1. shimnet/generators.py +20 -11
shimnet/generators.py CHANGED
@@ -411,8 +411,10 @@ class PeaksParameterDataGenerator:
411
  self.rng_getter = RngGetter(seed=seed)
412
 
413
  def set_frq_range(self, frq_min, frq_max):
414
- self.tff_min = frq_min * self.relative_frequency_min
415
- self.tff_max = frq_max * self.relative_frequency_max
 
 
416
 
417
  def __call__(self, seed=None):
418
  """
@@ -533,7 +535,6 @@ class TheoreticalMultipletSpectraGenerator:
533
  frq_frq = torch.arange(0, pixels) * frq_step + frequency_min
534
  else:
535
  raise ValueError("Insufficient parameters to determine frequency axis.")
536
-
537
  return frq_frq, frq_frq[0], frq_frq[-1]
538
 
539
 
@@ -565,7 +566,7 @@ class PeaksParametersNames(Enum):
565
  """Enum for standardized peak parameter names."""
566
  position_hz ="tff_lin"
567
  height = "thf_lin"
568
- width_hz = "twf_lin"
569
  gaussian_fraction = "trf_lin"
570
 
571
  @classmethod
@@ -587,6 +588,7 @@ class PeaksParametersParser:
587
  default_height = None,
588
  default_width_hz = None,
589
  default_gaussian_fraction = 0.,
 
590
  ):
591
  self.alias_position_hz = alias_position_hz if alias_position_hz is not None else "position_hz"
592
  self.alias_height = alias_height if alias_height is not None else "height"
@@ -596,12 +598,13 @@ class PeaksParametersParser:
596
  self.default_height = default_height
597
  self.default_width_hz = default_width_hz
598
  self.default_gaussian_fraction = default_gaussian_fraction
 
599
 
600
  def transform_single_peak(self, peak: dict) -> dict:
601
  parsed_peak = {
602
  PeaksParametersNames.position_hz.value: peak.get(self.alias_position_hz, self.default_position_hz),
603
  PeaksParametersNames.height.value: peak.get(self.alias_height, self.default_height),
604
- PeaksParametersNames.width_hz.value: peak.get(self.alias_width_hz, self.default_width_hz),
605
  PeaksParametersNames.gaussian_fraction.value: peak.get(self.alias_gaussian_fraction, self.default_gaussian_fraction),
606
  }
607
  # Validate and convert other peak parameters
@@ -732,8 +735,10 @@ class MultipletDataFromMultipletsLibrary:
732
  self.gaussian_fraction_change_additive_max = gaussian_fraction_change_additive_max
733
 
734
  def set_frq_range(self, frq_min, frq_max):
735
- self.tff_min = frq_min * self.relative_frequency_min
736
- self.tff_max = frq_max * self.relative_frequency_max
 
 
737
 
738
 
739
  def __call__(self, seed=None):
@@ -1049,6 +1054,7 @@ class PeaksParametersFromSinglets:
1049
  width_factor_max: float = 1.0,
1050
  width_hz_change_min: float = 0.0,
1051
  width_hz_change_max: float = 0.0,
 
1052
  use_original_height: bool = True,
1053
  height_min: float = 0.1,
1054
  height_max: float = 10.0,
@@ -1084,6 +1090,7 @@ class PeaksParametersFromSinglets:
1084
  self.width_factor_max = width_factor_max
1085
  self.width_hz_change_min = width_hz_change_min
1086
  self.width_hz_change_max = width_hz_change_max
 
1087
  # height
1088
  self.use_original_height = use_original_height
1089
  self.height_min = height_min
@@ -1102,8 +1109,10 @@ class PeaksParametersFromSinglets:
1102
  self.rng_getter = RngGetter(seed=seed)
1103
 
1104
  def set_frq_range(self, frq_min, frq_max):
1105
- self.position_hz_min = frq_min * self.relative_frequency_min
1106
- self.position_hz_max = frq_max * self.relative_frequency_max
 
 
1107
 
1108
  def __call__(self, seed=None) -> list[dict]:
1109
  rng = self.rng_getter.get_rng(seed=seed)
@@ -1124,9 +1133,9 @@ class PeaksParametersFromSinglets:
1124
  multiplet_data[PeaksParametersNames.position_hz.value] = random_uniform_vector(self.position_hz_min, self.position_hz_max, size=len(selected_peaks))
1125
  # width
1126
  if self.use_original_width:
1127
- multiplet_data[PeaksParametersNames.width_hz.value] = torch.tensor(selected_peaks["width_hz"].values, dtype=torch.float32) * random_uniform_vector(self.width_factor_min, self.width_factor_max, size=len(selected_peaks)) + random_uniform_vector(self.width_hz_change_min, self.width_hz_change_max, size=len(selected_peaks))
1128
  else:
1129
- multiplet_data[PeaksParametersNames.width_hz.value] = random_loguniform_vector(self.width_hz_min, self.width_hz_max, size=len(selected_peaks))
1130
  # height
1131
  if self.use_original_height:
1132
  multiplet_data[PeaksParametersNames.height.value] = torch.tensor(selected_peaks["height"].values, dtype=torch.float32) * random_uniform_vector(self.height_factor_min, self.height_factor_max, size=len(selected_peaks)) + random_uniform_vector(self.height_change_min, self.height_change_max, size=len(selected_peaks))
 
411
  self.rng_getter = RngGetter(seed=seed)
412
 
413
  def set_frq_range(self, frq_min, frq_max):
414
+ frq_amplitude = frq_max - frq_min
415
+ frq_center = (frq_max + frq_min) / 2
416
+ self.tff_min = frq_center + frq_amplitude * self.relative_frequency_min
417
+ self.tff_max = frq_center + frq_amplitude * self.relative_frequency_max
418
 
419
  def __call__(self, seed=None):
420
  """
 
535
  frq_frq = torch.arange(0, pixels) * frq_step + frequency_min
536
  else:
537
  raise ValueError("Insufficient parameters to determine frequency axis.")
 
538
  return frq_frq, frq_frq[0], frq_frq[-1]
539
 
540
 
 
566
  """Enum for standardized peak parameter names."""
567
  position_hz ="tff_lin"
568
  height = "thf_lin"
569
+ halfwidth_hz = "twf_lin"
570
  gaussian_fraction = "trf_lin"
571
 
572
  @classmethod
 
588
  default_height = None,
589
  default_width_hz = None,
590
  default_gaussian_fraction = 0.,
591
+ convert_width_to_halfwidth = True
592
  ):
593
  self.alias_position_hz = alias_position_hz if alias_position_hz is not None else "position_hz"
594
  self.alias_height = alias_height if alias_height is not None else "height"
 
598
  self.default_height = default_height
599
  self.default_width_hz = default_width_hz
600
  self.default_gaussian_fraction = default_gaussian_fraction
601
+ self.convert_width_to_halfwidth = convert_width_to_halfwidth
602
 
603
  def transform_single_peak(self, peak: dict) -> dict:
604
  parsed_peak = {
605
  PeaksParametersNames.position_hz.value: peak.get(self.alias_position_hz, self.default_position_hz),
606
  PeaksParametersNames.height.value: peak.get(self.alias_height, self.default_height),
607
+ PeaksParametersNames.halfwidth_hz.value: (0.5 if self.convert_width_to_halfwidth else 1.) * peak.get(self.alias_width_hz, self.default_width_hz),
608
  PeaksParametersNames.gaussian_fraction.value: peak.get(self.alias_gaussian_fraction, self.default_gaussian_fraction),
609
  }
610
  # Validate and convert other peak parameters
 
735
  self.gaussian_fraction_change_additive_max = gaussian_fraction_change_additive_max
736
 
737
  def set_frq_range(self, frq_min, frq_max):
738
+ frq_amplitude = frq_max - frq_min
739
+ frq_center = (frq_max + frq_min) / 2
740
+ self.tff_min = frq_center + frq_amplitude * self.relative_frequency_min
741
+ self.tff_max = frq_center + frq_amplitude * self.relative_frequency_max
742
 
743
 
744
  def __call__(self, seed=None):
 
1054
  width_factor_max: float = 1.0,
1055
  width_hz_change_min: float = 0.0,
1056
  width_hz_change_max: float = 0.0,
1057
+ convert_width_to_halfwidth: bool = True,
1058
  use_original_height: bool = True,
1059
  height_min: float = 0.1,
1060
  height_max: float = 10.0,
 
1090
  self.width_factor_max = width_factor_max
1091
  self.width_hz_change_min = width_hz_change_min
1092
  self.width_hz_change_max = width_hz_change_max
1093
+ self.convert_width_to_halfwidth = convert_width_to_halfwidth # if True, the original widths will be divided by 2
1094
  # height
1095
  self.use_original_height = use_original_height
1096
  self.height_min = height_min
 
1109
  self.rng_getter = RngGetter(seed=seed)
1110
 
1111
  def set_frq_range(self, frq_min, frq_max):
1112
+ frq_amplitude = frq_max - frq_min
1113
+ frq_center = (frq_max + frq_min) / 2
1114
+ self.position_hz_min = frq_center + frq_amplitude * self.relative_frequency_min
1115
+ self.position_hz_max = frq_center + frq_amplitude * self.relative_frequency_max
1116
 
1117
  def __call__(self, seed=None) -> list[dict]:
1118
  rng = self.rng_getter.get_rng(seed=seed)
 
1133
  multiplet_data[PeaksParametersNames.position_hz.value] = random_uniform_vector(self.position_hz_min, self.position_hz_max, size=len(selected_peaks))
1134
  # width
1135
  if self.use_original_width:
1136
+ multiplet_data[PeaksParametersNames.halfwidth_hz.value] = (0.5 if self.convert_width_to_halfwidth else 1.)*torch.tensor(selected_peaks["width_hz"].values, dtype=torch.float32) * random_uniform_vector(self.width_factor_min, self.width_factor_max, size=len(selected_peaks)) + random_uniform_vector(self.width_hz_change_min, self.width_hz_change_max, size=len(selected_peaks))
1137
  else:
1138
+ multiplet_data[PeaksParametersNames.halfwidth_hz.value] = random_loguniform_vector(self.width_hz_min, self.width_hz_max, size=len(selected_peaks))
1139
  # height
1140
  if self.use_original_height:
1141
  multiplet_data[PeaksParametersNames.height.value] = torch.tensor(selected_peaks["height"].values, dtype=torch.float32) * random_uniform_vector(self.height_factor_min, self.height_factor_max, size=len(selected_peaks)) + random_uniform_vector(self.height_change_min, self.height_change_max, size=len(selected_peaks))