Spaces:
Sleeping
Sleeping
Marek Bukowicki
committed on
Commit
·
c1d3733
1
Parent(s):
3de469d
working modular pipeline to work with peak lists
Browse files- configs/data_generator_from_peak_list.yaml +43 -0
- shimnet/generators.py +246 -7
configs/data_generator_from_peak_list.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
_target_: shimnet.generators.Generator
|
| 3 |
+
input_normalization_height: 16.0
|
| 4 |
+
clean_spectra_generator:
|
| 5 |
+
_target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
|
| 6 |
+
peaks_parameter_generator:
|
| 7 |
+
_target_: shimnet.generators.MultipletDataFromMultipletsLibrary
|
| 8 |
+
number_of_signals_min: 2 # null to use all signals (both min and max must be null)
|
| 9 |
+
number_of_signals_max: 10 # null to use all signals (both min and max must be null)
|
| 10 |
+
multiplet_height_factor_min: 0.5
|
| 11 |
+
multiplet_height_factor_max: 2
|
| 12 |
+
multiplet_width_factor_min: 0.5
|
| 13 |
+
multiplet_width_factor_max: 2
|
| 14 |
+
multiplets_library:
|
| 15 |
+
_target_: shimnet.generators.MultipletsLibrary
|
| 16 |
+
csv_files_paths:
|
| 17 |
+
- data/multiplets_lists/azydekbenzylu_sub1_mono-click.csv
|
| 18 |
+
- data/multiplets_lists/fenyloacetylen_sub2_mono-click.csv
|
| 19 |
+
- data/multiplets_lists/mieszanina_poReakcji_mono-click.csv
|
| 20 |
+
peak_data_parser:
|
| 21 |
+
_target_: shimnet.generators.PeaksParametersParser
|
| 22 |
+
use_original_peak_position: false
|
| 23 |
+
seed: null
|
| 24 |
+
pixels: 2048
|
| 25 |
+
frq_step: ${metadata.frq_step}
|
| 26 |
+
include_tff_relative: true
|
| 27 |
+
relative_frequency_min: -0.4
|
| 28 |
+
relative_frequency_max: 0.4
|
| 29 |
+
response_generator:
|
| 30 |
+
_target_: shimnet.generators.ResponseGenerator
|
| 31 |
+
response_function_library:
|
| 32 |
+
_target_: shimnet.generators.ResponseLibrary
|
| 33 |
+
response_files:
|
| 34 |
+
- data/scrf_81_600MHz.pt
|
| 35 |
+
noise_generator:
|
| 36 |
+
_target_: shimnet.generators.NoiseGenerator
|
| 37 |
+
spectrum_noise_min: 0.0
|
| 38 |
+
spectrum_noise_max: 0.1
|
| 39 |
+
include_spectrum_data: false
|
| 40 |
+
include_peak_mask: false
|
| 41 |
+
include_response_function: true
|
| 42 |
+
seed: 44 # null means random seed
|
| 43 |
+
batch_size: 64 # to be set in training script
|
shimnet/generators.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
|
|
|
| 2 |
import torch
|
| 3 |
import torchdata
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from abc import ABC, abstractmethod
|
| 6 |
|
| 7 |
-
# from itertools import islice
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
def random_value(min_value, max_value, generator=None):
    """Draw one uniform random float from the interval [min_value, max_value).

    Args:
        min_value: Lower bound of the interval.
        max_value: Upper bound of the interval.
        generator: Optional ``torch.Generator`` used for the draw; the global
            torch RNG is used when it is None.

    Returns:
        The sampled value as a Python float.
    """
    span = max_value - min_value
    unit_sample = torch.rand(1, generator=generator)
    # Keep the arithmetic on the tensor so the result is computed in the
    # tensor's dtype before being converted to a Python float.
    return (min_value + unit_sample * span).item()
|
| 15 |
|
|
@@ -529,6 +527,238 @@ class TheoreticalMultipletSpectraGenerator:
|
|
| 529 |
|
| 530 |
return spectrum, {"spectrum_data": peaks_parameters_data, "frq_frq": self.frq_frq}
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
class ResponseGenerator:
|
| 533 |
def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
|
| 534 |
response_function_noise=0.0, flip_response_function=False, seed=42):
|
|
@@ -674,7 +904,7 @@ class BaseGeneratorMultithread(ABC):
|
|
| 674 |
|
| 675 |
class Generator(BaseGenerator):
|
| 676 |
def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64,
|
| 677 |
-
include_spectrum_data=False, include_peak_mask=False, include_response_function=False, seed=None):
|
| 678 |
super().__init__(batch_size=batch_size, seed=seed)
|
| 679 |
self.clean_spectra_generator = clean_spectra_generator
|
| 680 |
self.response_generator = response_generator
|
|
@@ -682,6 +912,7 @@ class Generator(BaseGenerator):
|
|
| 682 |
self.include_spectrum_data = include_spectrum_data
|
| 683 |
self.include_peak_mask = include_peak_mask
|
| 684 |
self.include_response_function = include_response_function
|
|
|
|
| 685 |
|
| 686 |
def _generate_element(self, seed):
|
| 687 |
# Generate different seeds for each generator from the provided seed
|
|
@@ -700,7 +931,15 @@ class Generator(BaseGenerator):
|
|
| 700 |
response_function = self.response_generator(seed=response_seed)
|
| 701 |
padding_size = (response_function.shape[-1] - 1)//2
|
| 702 |
disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed)
|
|
|
|
| 704 |
out = {
|
| 705 |
'theoretical_spectrum': clean_spectrum,
|
| 706 |
'disturbed_spectrum': disturbed_spectrum,
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
# from pathlib import Path
|
| 4 |
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
import torch
|
| 7 |
import torchdata
|
| 8 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 9 |
from abc import ABC, abstractmethod
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def random_value(min_value, max_value, generator=None):
    """Draw one uniform random float from the interval [min_value, max_value).

    Args:
        min_value: Lower bound of the interval.
        max_value: Upper bound of the interval.
        generator: Optional ``torch.Generator`` used for the draw; the global
            torch RNG is used when it is None.

    Returns:
        The sampled value as a Python float.
    """
    span = max_value - min_value
    unit_sample = torch.rand(1, generator=generator)
    # Keep the arithmetic on the tensor so the result is computed in the
    # tensor's dtype before being converted to a Python float.
    return (min_value + unit_sample * span).item()
|
| 13 |
|
|
|
|
| 527 |
|
| 528 |
return spectrum, {"spectrum_data": peaks_parameters_data, "frq_frq": self.frq_frq}
|
| 529 |
|
| 530 |
+
|
| 531 |
+
class PeaksParametersNames(Enum):
    """Mapping between readable peak parameter names and internal short names.

    Each member's *name* is the internal key (e.g. ``tff_lin``) and its
    *value* is the readable column name (e.g. ``position_hz``), so
    ``PeaksParametersNames("position_hz").name`` yields ``"tff_lin"``.
    """

    tff_lin = "position_hz"
    thf_lin = "height"
    twf_lin = "width_hz"
    trf_lin = "gaussian_fraction"

    @classmethod
    def keys(cls):
        """Return the readable names (the member values) in definition order."""
        return [entry.value for entry in cls.__members__.values()]

    @classmethod
    def values(cls):
        """Return the internal short names (the member names) in definition order."""
        return list(cls.__members__)
|
| 545 |
+
|
| 546 |
+
class PeaksParametersParser:
    """Convert raw peak dictionaries into the standardized tensor representation.

    Every peak dictionary is looked up under configurable column aliases and
    re-keyed with the internal names from ``PeaksParametersNames``; each value
    is converted to a float32 tensor with at least one dimension.
    """

    def __init__(self,
                 alias_position_hz = None,
                 alias_height = None,
                 alias_width_hz = None,
                 alias_gaussian_fraction = None,
                 default_position_hz = None,
                 default_height = None,
                 default_width_hz = None,
                 default_gaussian_fraction = 0.,
                 ):
        """Store the column aliases and the fallback values.

        An alias left as None falls back to the standard column name. A
        default left as None makes the corresponding parameter mandatory.
        """
        # Column names under which each parameter is searched in the input.
        self.alias_position_hz = "position_hz" if alias_position_hz is None else alias_position_hz
        self.alias_height = "height" if alias_height is None else alias_height
        self.alias_width_hz = "width_hz" if alias_width_hz is None else alias_width_hz
        self.alias_gaussian_fraction = "gaussian_fraction" if alias_gaussian_fraction is None else alias_gaussian_fraction

        # Values used when the input has no entry for a parameter.
        self.default_position_hz = default_position_hz
        self.default_height = default_height
        self.default_width_hz = default_width_hz
        self.default_gaussian_fraction = default_gaussian_fraction

    def transform_single_peak(self, peak: dict) -> dict:
        """Re-key one peak dictionary and convert its values to tensors.

        Raises:
            ValueError: If a parameter is missing from ``peak`` and has no
                default (or is explicitly None).
        """
        lookups = (
            ("position_hz", self.alias_position_hz, self.default_position_hz),
            ("height", self.alias_height, self.default_height),
            ("width_hz", self.alias_width_hz, self.default_width_hz),
            ("gaussian_fraction", self.alias_gaussian_fraction, self.default_gaussian_fraction),
        )
        # First collect the raw values under their internal names ...
        raw_values = {
            PeaksParametersNames(readable_name).name: peak.get(alias, fallback)
            for readable_name, alias, fallback in lookups
        }
        # ... then validate and convert them one by one, in the same order.
        parsed_peak = {}
        for internal_name, raw in raw_values.items():
            if raw is None:
                raise ValueError(f"Peak parameter '{internal_name}' is None.")
            as_tensor = torch.tensor(raw, dtype=torch.float32)
            parsed_peak[internal_name] = torch.atleast_1d(as_tensor)
        return parsed_peak

    def transform(self, spectrum_peaks: list[dict]) -> list[dict]:
        """Apply ``transform_single_peak`` to every peak dictionary in order."""
        return [self.transform_single_peak(single_peak) for single_peak in spectrum_peaks]
|
| 585 |
+
|
| 586 |
+
def csv_file_to_multiplets_dict(file_path: str) -> dict:
    """Read a peak-list CSV and group its rows by multiplet.

    The CSV must contain a ``multiplet_name`` column; all remaining columns
    are treated as peak parameters.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A dict mapping each multiplet name to a dict of
        ``{column_name: list_of_values}`` holding the rows of that multiplet.
        (The previous annotation claimed ``list[dict]``, but a dict has
        always been returned.)
    """
    peaks_data = pd.read_csv(file_path)
    return {
        name: group.drop(columns="multiplet_name").to_dict(orient='list')
        for name, group in peaks_data.groupby("multiplet_name")
    }
|
| 590 |
+
|
| 591 |
+
def combine_multiplets(multiplets_list: list[dict]) -> dict:
    """Concatenate the per-key value lists of several multiplet dicts.

    Args:
        multiplets_list: Iterable of dicts mapping a key (e.g. a column name)
            to a list of values.

    Returns:
        A new dict in which each key maps to the concatenation, in input
        order, of the lists found under that key.

    The result always holds freshly created lists. Previously the first list
    seen for a key was stored by reference and later extended, which silently
    modified the caller's input.
    """
    composed_multiplets = {}
    for multiplets in multiplets_list:
        for k, v in multiplets.items():
            composed_multiplets.setdefault(k, []).extend(v)
    return composed_multiplets
|
| 600 |
+
|
| 601 |
+
class MultipletsLibrary:
    """Indexable collection of multiplets loaded from peak-list CSV files.

    Every multiplet is stored under the key ``"<file_path>/<multiplet_name>"``.
    Integer indexing follows the alphabetical order of these keys.
    """

    def __init__(self, csv_files_paths: list[str], peak_data_parser: PeaksParametersParser = None, return_name=False):
        """Load all multiplets from the given CSV files.

        Args:
            csv_files_paths: Paths of the CSV files to read.
            peak_data_parser: Optional parser applied to every multiplet;
                when falsy the raw column dicts are stored unchanged.
            return_name: When true, indexing returns ``(name, data)`` tuples
                and not only the data.
        """
        self.csv_files_paths = csv_files_paths
        # The parser must be in place before any file is loaded.
        self.peak_data_parser = peak_data_parser
        self.return_name = return_name

        self.multiplets_data = {}
        for csv_path in csv_files_paths:
            loaded = self._get_multiplet_data_from_file(csv_path)
            # Entries with an identical key from a later file replace earlier ones.
            self.multiplets_data.update(loaded)

        self.names = sorted(self.multiplets_data)

    def _get_multiplet_data_from_file(self, file_path: str) -> dict:
        """Read one CSV file and return its multiplets keyed by file path and name."""
        raw_multiplets = csv_file_to_multiplets_dict(file_path)  # dict[dict]
        parser = self.peak_data_parser
        keyed_multiplets = {}
        for multiplet_name, columns in raw_multiplets.items():
            if parser:
                entry = parser.transform([columns])[0]
            else:
                entry = columns
            keyed_multiplets[f"{file_path}/{multiplet_name}"] = entry
        return keyed_multiplets

    def get_by_name(self, name: str) -> dict:
        """Return the stored multiplet for ``name`` (not a copy), or None if unknown."""
        return self.multiplets_data.get(name)

    def __getitem__(self, idx: int) -> dict:
        """Return a deep copy of the ``idx``-th multiplet, optionally with its name."""
        name = self.names[idx]
        # Hand out a copy so callers may modify the data without touching the library.
        multiplet_copy = deepcopy(self.multiplets_data[name])
        if not self.return_name:
            return multiplet_copy
        return name, multiplet_copy

    def __len__(self):
        """Return the number of stored multiplets."""
        return len(self.multiplets_data)
|
| 631 |
+
|
| 632 |
+
# NOTE(review): the class name looks like a typo of "SpectraLibrary"; it is
# kept unchanged because configs and callers may reference it by this name.
class SectraLibrary(MultipletsLibrary):
    """Library variant that stores one entry per CSV file.

    All multiplets of a file are merged into a single peak set, stored under
    the file path as key.
    """

    def _get_multiplet_data_from_file(self, file_path: str) -> dict:
        """Read one CSV file and return all its peaks merged into a single entry.

        The parser is applied only when one is configured, matching the
        behavior of ``MultipletsLibrary`` where ``peak_data_parser`` is
        optional. Previously a missing parser raised ``AttributeError`` here.
        """
        multiplets = csv_file_to_multiplets_dict(file_path)  # dict[dict]
        combined_multiplet = combine_multiplets(multiplets.values())  # dict
        if self.peak_data_parser:
            combined_multiplet = self.peak_data_parser.transform([combined_multiplet])[0]
        return {f"{file_path}": combined_multiplet}
|
| 637 |
+
|
| 638 |
+
class MultipletDataFromMultipletsLibrary:
    """Peak-parameter generator that samples multiplets from a library.

    Each call selects multiplets from ``multiplets_library`` and randomly
    modifies their position, width, height and (optionally) Gaussian fraction.
    Assumes library items are dicts of tensors keyed by ``tff_lin``,
    ``twf_lin``, ``thf_lin`` and ``trf_lin`` (the internal names in
    ``PeaksParametersNames``) -- TODO confirm against the library configuration.
    """
    def __init__(self,
                 multiplets_library,
                 tff_min=None, # may be assigned after initialization (see set_tff_range); required when the original peak positions are not used
                 tff_max=None, # may be assigned after initialization (see set_tff_range); required when the original peak positions are not used
                 use_original_peak_position=True,
                 number_of_signals_min=None,
                 number_of_signals_max=None,
                 spectrum_width_factor_min=1,
                 spectrum_width_factor_max=1,
                 multiplet_width_factor_min=1,
                 multiplet_width_factor_max=1,
                 spectrum_height_factor_min=1,
                 spectrum_height_factor_max=1,
                 multiplet_height_factor_min=1,
                 multiplet_height_factor_max=1,
                 position_shift_min=0,
                 position_shift_max=0,
                 gaussian_fraction_change_min=None,
                 gaussian_fraction_change_max=None,
                 seed=42
                 ):
        """Store the sampling ranges.

        Args:
            multiplets_library: Indexable collection supporting ``len()`` and
                integer indexing.
            tff_min, tff_max: Range for the new multiplet center; used only
                when ``use_original_peak_position`` is False.
            use_original_peak_position: Keep the library positions (plus a
                random shift) instead of re-centering each multiplet.
            number_of_signals_min, number_of_signals_max: Inclusive range for
                the number of multiplets per call; both None means every
                library entry is used once.
            spectrum_*_factor_min/max: Range of the factor drawn once per call.
            multiplet_*_factor_min/max: Range of the factor drawn per multiplet.
            position_shift_min, position_shift_max: Range of the per-multiplet
                shift applied when the original positions are kept.
            gaussian_fraction_change_min, gaussian_fraction_change_max: Range
                of the additive change of the Gaussian fraction; the change is
                skipped when the minimum is None.
            seed: Passed to ``RngGetter`` (defined elsewhere in the file).

        Raises:
            ValueError: If exactly one of ``number_of_signals_min`` and
                ``number_of_signals_max`` is None.
        """

        if (number_of_signals_min is None) != (number_of_signals_max is None):
            raise ValueError("Both number_of_signals_min and number_of_signals_max should be provided or both should be None.")

        self.multiplets_library = multiplets_library
        self.rng_getter = RngGetter(seed=seed)
        self.tff_min = tff_min
        self.tff_max = tff_max
        self.use_original_peak_position = use_original_peak_position
        self.number_of_signals_min = number_of_signals_min
        self.number_of_signals_max = number_of_signals_max
        self.spectrum_width_factor_min = spectrum_width_factor_min
        self.spectrum_width_factor_max = spectrum_width_factor_max
        self.multiplet_width_factor_min = multiplet_width_factor_min
        self.multiplet_width_factor_max = multiplet_width_factor_max
        self.spectrum_height_factor_min = spectrum_height_factor_min
        self.spectrum_height_factor_max = spectrum_height_factor_max
        self.multiplet_height_factor_min = multiplet_height_factor_min
        self.multiplet_height_factor_max = multiplet_height_factor_max
        self.position_shift_min = position_shift_min
        self.position_shift_max = position_shift_max
        # NOTE(review): unlike number_of_signals_*, this pair is not validated;
        # a None maximum with a set minimum would reach random_value as None.
        self.gaussian_fraction_change_min = gaussian_fraction_change_min
        self.gaussian_fraction_change_max = gaussian_fraction_change_max

    def set_tff_range(self, tff_min, tff_max):
        """Set the range from which new multiplet centers are drawn."""
        self.tff_min = tff_min
        self.tff_max = tff_max

    def __call__(self, seed=None):
        """Sample multiplets and return their randomly modified parameters.

        Args:
            seed: Forwarded to ``self.rng_getter.get_rng`` to obtain the
                generator used for every random draw in this call.

        Returns:
            A list with one parameter dict per selected multiplet.

        Raises:
            ValueError: If ``use_original_peak_position`` is False and
                ``tff_min`` or ``tff_max`` has not been set.
        """
        if (not self.use_original_peak_position) and (self.tff_min is None or self.tff_max is None):
            raise ValueError("for use_original_peak_position=False, tff_min and tff_max must be set before calling the generator.")

        rng = self.rng_getter.get_rng(seed=seed)

        # select number of signals and their indices
        if self.number_of_signals_min is None:
            # no range given: take every library entry exactly once, in order
            number_of_signals = len(self.multiplets_library)
            multiplets_indices = list(range(len(self.multiplets_library)))
        else:
            # randint excludes the upper bound, hence "+ 1" to make the maximum inclusive
            number_of_signals = torch.randint(
                self.number_of_signals_min,
                self.number_of_signals_max + 1,
                [],
                generator=rng
            )

            # indices are drawn independently, so the same multiplet may be selected more than once
            multiplets_indices = torch.randint(
                0,
                len(self.multiplets_library),
                [number_of_signals],
                generator=rng
            )

        # spectrum width and height factors (drawn once and shared by all multiplets of this call)
        # random_loguniform is defined elsewhere in the file; presumably a log-uniform draw -- TODO confirm
        spectrum_width_factor = random_loguniform(
            self.spectrum_width_factor_min,
            self.spectrum_width_factor_max,
            generator=rng
        )

        spectrum_height_factor = random_loguniform(
            self.spectrum_height_factor_min,
            self.spectrum_height_factor_max,
            generator=rng
        )

        # get and modify peaks parameters data
        # NOTE(review): the in-place "+=" below modifies the object returned by the library;
        # MultipletsLibrary.__getitem__ returns a deep copy, other libraries should be checked.
        # assumes the library returns bare dicts (return_name=False) -- TODO confirm
        peaks_parameters_data = [self.multiplets_library[idx] for idx in multiplets_indices]
        for peak_parameters in peaks_parameters_data:

            # position
            if not self.use_original_peak_position:
                # shift all peaks so that the mean position of the multiplet equals the new center
                new_position_center = random_value(self.tff_min, self.tff_max, generator=rng)
                peak_parameters["tff_lin"] += new_position_center - torch.mean(peak_parameters["tff_lin"])
            else:
                # keep the original positions, moved by one common random shift
                position_shift = random_value(self.position_shift_min, self.position_shift_max, generator=rng)
                peak_parameters["tff_lin"] += position_shift

            # width (per-call factor times per-multiplet factor)
            multiplet_width_factor = random_loguniform(
                self.multiplet_width_factor_min,
                self.multiplet_width_factor_max,
                generator=rng
            )
            peak_parameters["twf_lin"] = peak_parameters["twf_lin"] * spectrum_width_factor * multiplet_width_factor

            # height (per-call factor times per-multiplet factor)
            multiplet_height_factor = random_loguniform(
                self.multiplet_height_factor_min,
                self.multiplet_height_factor_max,
                generator=rng
            )
            peak_parameters["thf_lin"] = peak_parameters["thf_lin"] * spectrum_height_factor * multiplet_height_factor

            # gaussian contribution (additive change, result clipped to [0, 1])
            if self.gaussian_fraction_change_min is not None:
                gaussian_contribution_shift = random_value(self.gaussian_fraction_change_min, self.gaussian_fraction_change_max, generator=rng)
                peak_parameters["trf_lin"] = torch.clip(peak_parameters["trf_lin"] + gaussian_contribution_shift, 0., 1.)

        return peaks_parameters_data
|
| 760 |
+
|
| 761 |
+
|
| 762 |
class ResponseGenerator:
|
| 763 |
def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
|
| 764 |
response_function_noise=0.0, flip_response_function=False, seed=42):
|
|
|
|
| 904 |
|
| 905 |
class Generator(BaseGenerator):
|
| 906 |
def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64,
|
| 907 |
+
include_spectrum_data=False, include_peak_mask=False, include_response_function=False, input_normalization_height=None, seed=None):
|
| 908 |
super().__init__(batch_size=batch_size, seed=seed)
|
| 909 |
self.clean_spectra_generator = clean_spectra_generator
|
| 910 |
self.response_generator = response_generator
|
|
|
|
| 912 |
self.include_spectrum_data = include_spectrum_data
|
| 913 |
self.include_peak_mask = include_peak_mask
|
| 914 |
self.include_response_function = include_response_function
|
| 915 |
+
self.input_normalization_height = input_normalization_height
|
| 916 |
|
| 917 |
def _generate_element(self, seed):
|
| 918 |
# Generate different seeds for each generator from the provided seed
|
|
|
|
| 931 |
response_function = self.response_generator(seed=response_seed)
|
| 932 |
padding_size = (response_function.shape[-1] - 1)//2
|
| 933 |
disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size)
|
| 934 |
+
|
| 935 |
+
if self.input_normalization_height is not None:
|
| 936 |
+
max_val = torch.max(disturbed_spectrum)
|
| 937 |
+
clean_spectrum = clean_spectrum / max_val * self.input_normalization_height
|
| 938 |
+
disturbed_spectrum = disturbed_spectrum / max_val * self.input_normalization_height
|
| 939 |
+
|
| 940 |
+
# noise after normalization to better control noise level
|
| 941 |
noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed)
|
| 942 |
+
|
| 943 |
out = {
|
| 944 |
'theoretical_spectrum': clean_spectrum,
|
| 945 |
'disturbed_spectrum': disturbed_spectrum,
|