from enum import Enum from copy import deepcopy from typing import Optional import numpy as np import pandas as pd import torch from concurrent.futures import ThreadPoolExecutor, as_completed from abc import ABC, abstractmethod def random_uniform(min_value, max_value, generator=None): return (min_value + torch.rand(1, generator=generator) * (max_value - min_value)).item() random_value = random_uniform def random_loguniform(min_value, max_value, generator=None): return (min_value * torch.exp(torch.rand(1, generator=generator) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item() def random_uniform_vector(min_value, max_value, size, generator=None): return min_value + torch.rand(size, generator=generator) * (max_value - min_value) def random_loguniform_vector(min_value, max_value, size, generator=None): return min_value * torch.exp(torch.rand(size, generator=generator) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value)))) def spectrum_from_peaks_data(peaks_parameters: dict | list, frq_frq:torch.Tensor, relative_frequency=False): if isinstance(peaks_parameters, dict): peaks_parameters = [peaks_parameters] spectrum = torch.zeros((1, frq_frq.shape[0])) for peak_params in peaks_parameters: # extract parameters if relative_frequency: tff_lin = frq_frq[0] + peak_params["tff_relative"]*(frq_frq[1]-frq_frq[0]) else: tff_lin = peak_params["tff_lin"] twf_lin = peak_params["twf_lin"] thf_lin = peak_params["thf_lin"] trf_lin = peak_params["trf_lin"] lwf_lin = twf_lin lhf_lin = thf_lin * (1. - trf_lin) gwf_lin = twf_lin gdf_lin = gwf_lin / torch.tensor(2.).log().mul(2.).sqrt() ghf_lin = thf_lin * trf_lin # calculate Lorenz peaks contriubutions lsf_linfrq = lwf_lin[:, None] ** 2 / (lwf_lin[:, None] ** 2 + (frq_frq - tff_lin[:, None]) ** 2) * lhf_lin[:, None] # calculate Gaussian peaks contriubutions gsf_linfrq = torch.exp(-(frq_frq - tff_lin[:, None]) ** 2 / gdf_lin[:, None] ** 2 / 2.) * ghf_lin[:, None] tsf_linfrq = lsf_linfrq + gsf_linfrq # sum peaks contriubutions spectrum += tsf_linfrq.sum(0, keepdim = True) return spectrum calculate_theoretical_spectrum = spectrum_from_peaks_data # Alias for backward compatibility pascal_triangle = [(1,), (1,1), (1,2,1), (1,3,3,1), (1,4,6,4,1), (1,5,10,10,5,1), (1,6,15,20,15,6,1), (1,7, 21,35,35,21,7,1)] normalized_pascal_triangle = [torch.tensor(x)/sum(x) for x in pascal_triangle] def pascal_multiplicity(multiplicity): intensities = normalized_pascal_triangle[multiplicity-1] n_peaks = len(intensities) shifts = torch.arange(n_peaks)-((n_peaks-1)/2) return shifts, intensities def double_multiplicity(multiplicity1, multiplicity2, j1=1, j2=1): shifts1, intensities1 = pascal_multiplicity(multiplicity1) shifts2, intensities2 = pascal_multiplicity(multiplicity2) shifts = (j1*shifts1.reshape(-1,1) + j2*shifts2.reshape(1,-1)).flatten() intensities = (intensities1.reshape(-1,1) * intensities2.reshape(1,-1)).flatten() return shifts, intensities def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_lin, j1, j2): shifts, intensities = double_multiplicity(multiplicity[0], multiplicity[1], j1, j2) n_peaks = len(shifts) return { "tff_lin": shifts + tff_lin, "thf_lin": intensities * thf_lin, "twf_lin": torch.full((n_peaks,), twf_lin), "trf_lin": torch.full((n_peaks,), trf_lin), } def value_to_index(values, table): span = table[-1] - table[0] indices = ((values - table[0])/span * (len(table)-1)) #.round().type(torch.int64) return indices def generate_theoretical_spectrum( number_of_signals_min, number_of_signals_max, spectrum_width_min, spectrum_width_max, relative_width_min, relative_width_max, tff_min, tff_max, thf_min, thf_max, trf_min, trf_max, relative_height_min, relative_height_max, multiplicity_j1_min, multiplicity_j1_max, multiplicity_j2_min, multiplicity_j2_max, atom_groups_data, frq_frq, generator=None ): number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [], generator=generator) atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals], generator=generator) width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max, generator=generator) height_spectrum = random_loguniform(thf_min, thf_max, generator=generator) peak_parameters_data = [] theoretical_spectrum = None for atom_group_index in atom_group_indices: relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index] position = random_value(tff_min, tff_max, generator=generator) j1 = random_value(multiplicity_j1_min, multiplicity_j1_max, generator=generator) j2 = random_value(multiplicity_j2_min, multiplicity_j2_max, generator=generator) width = width_spectrum*random_loguniform(relative_width_min, relative_width_max, generator=generator) height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max, generator=generator) gaussian_contribution = random_value(trf_min, trf_max, generator=generator) peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin= width, trf_lin= gaussian_contribution, j1=j1, j2=j2) peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq) peak_parameters_data.append(peaks_parameters) spectrum_contribution = calculate_theoretical_spectrum(peaks_parameters, frq_frq) if theoretical_spectrum is None: theoretical_spectrum = spectrum_contribution else: theoretical_spectrum += spectrum_contribution return theoretical_spectrum, peak_parameters_data def theoretical_generator( atom_groups_data, pixels=2048, frq_step=11160.7142857 / 32768, number_of_signals_min=1, number_of_signals_max=8, spectrum_width_min=0.2, spectrum_width_max=1, relative_width_min=1, relative_width_max=2, relative_height_min=1, relative_height_max=1, relative_frequency_min=-0.4, relative_frequency_max=0.4, thf_min=1/16, thf_max=16, trf_min=0, trf_max=1, multiplicity_j1_min=0, multiplicity_j1_max=15, multiplicity_j2_min=0, multiplicity_j2_max=15, ): tff_min = relative_frequency_min * pixels * frq_step tff_max = relative_frequency_max * pixels * frq_step frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step while True: yield generate_theoretical_spectrum( number_of_signals_min=number_of_signals_min, number_of_signals_max=number_of_signals_max, spectrum_width_min=spectrum_width_min, spectrum_width_max=spectrum_width_max, relative_width_min=relative_width_min, relative_width_max=relative_width_max, relative_height_min=relative_height_min, relative_height_max=relative_height_max, tff_min=tff_min, tff_max=tff_max, thf_min=thf_min, thf_max=thf_max, trf_min=trf_min, trf_max=trf_max, multiplicity_j1_min=multiplicity_j1_min, multiplicity_j1_max=multiplicity_j1_max, multiplicity_j2_min=multiplicity_j2_min, multiplicity_j2_max=multiplicity_j2_max, atom_groups_data=atom_groups_data, frq_frq=frq_frq ) class ResponseLibrary: def __init__(self, response_files, normalize=True): self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in response_files] if normalize: self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data] lengths = [len(data) for data in self.data] self.start_indices = torch.cumsum(torch.tensor([0] + lengths[:-1]), 0) self.total_length = sum(lengths) def __getitem__(self, idx): if idx >= self.total_length: raise ValueError(f'index {idx} out of range') tensor_index = torch.searchsorted(self.start_indices, idx, right=True) - 1 return self.data[tensor_index][idx - self.start_indices[tensor_index]] def __len__(self): return self.total_length @property def max_response_length(self): return max([data.shape[-1] for data in self.data]) def generator( theoretical_generator_params, response_function_library, response_function_stretch_min=0.5, response_function_stretch_max=2.0, response_function_noise=0., spectrum_noise_min=0., spectrum_noise_max=1/64, include_spectrum_data=False, include_peak_mask=False, include_response_function=False, flip_response_function=False ): for theoretical_spectrum, theoretical_spectrum_data in theoretical_generator(**theoretical_generator_params): # get response function response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0] # stretch response function padding_size = (response_function.shape[-1] - 1)//2 padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size) #torch.randint(round(padding_size*response_function_stretch_min), round(paddingSize*response_function_stretch_max), [1]).item() response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear') response_function /= response_function.sum() # normalize sum of response function to 1 # add noise to response function response_function += torch.randn(response_function.shape) * response_function_noise response_function /= response_function.sum() # normalize sum of response function to 1 if flip_response_function and (torch.rand(1).item() < 0.5): response_function = response_function.flip(-1) # disturbed spectrum disturbed_spectrum = torch.nn.functional.conv1d(theoretical_spectrum, response_function, padding=padding_size) # add noise noised_spectrum = disturbed_spectrum + torch.randn(disturbed_spectrum.shape) * random_value(spectrum_noise_min, spectrum_noise_max) out = { # 'response_function': response_function, 'theoretical_spectrum': theoretical_spectrum, 'disturbed_spectrum': disturbed_spectrum, 'noised_spectrum': noised_spectrum, } if include_response_function: out['response_function'] = response_function if include_spectrum_data: out["theoretical_spectrum_data"] = theoretical_spectrum_data if include_peak_mask: all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in theoretical_spectrum_data]) peaks_indices = all_peaks_rel.round().type(torch.int64) out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0) yield out def collate_with_spectrum_data(batch): tensor_keys = set(batch[0].keys()) tensor_keys.remove('theoretical_spectrum_data') out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys} out["theoretical_spectrum_data"] = [item["theoretical_spectrum_data"] for item in batch] return out class RngGetter: def __init__(self, seed=42): self.rng = torch.Generator() if seed is not None: self.rng.manual_seed(seed) else: self.rng.seed() def get_rng(self, seed=None): # Use provided seed or fall back to instance RNG if seed is not None: rng = torch.Generator() rng.manual_seed(seed) else: rng = self.rng return rng class PeaksParameterDataGenerator: """ Generates peak parameter data for NMR multiplets. This class is responsible for generating the parameters that describe individual peaks in an NMR spectrum (frequencies, heights, widths, Gaussian/Lorentzian ratio). """ def __init__(self, tff_min=None, #may be assigned after initialization tff_max=None, #may be assigned after initialization atom_groups_data_file=None, number_of_signals_min=1, number_of_signals_max=8, relative_frequency_min=-0.4, relative_frequency_max=0.4, spectrum_width_min=0.2, spectrum_width_max=1, relative_width_min=1, relative_width_max=2, relative_height_min=1, relative_height_max=1, thf_min=1/16, thf_max=16, trf_min=0, trf_max=1, multiplicity_j1_min=0, multiplicity_j1_max=15, multiplicity_j2_min=0, multiplicity_j2_max=15, seed=42 ): # Read atom_groups_data from file if atom_groups_data_file is None: self.atom_groups_data = np.ones((1,3), dtype=int) else: self.atom_groups_data = np.atleast_2d(np.loadtxt(atom_groups_data_file, usecols=(1,2,3), dtype=int)) self.tff_min = tff_min self.tff_max = tff_max self.number_of_signals_min = number_of_signals_min self.number_of_signals_max = number_of_signals_max self.relative_frequency_min = relative_frequency_min self.relative_frequency_max = relative_frequency_max self.spectrum_width_min = spectrum_width_min self.spectrum_width_max = spectrum_width_max self.relative_width_min = relative_width_min self.relative_width_max = relative_width_max self.relative_height_min = relative_height_min self.relative_height_max = relative_height_max self.thf_min = thf_min self.thf_max = thf_max self.trf_min = trf_min self.trf_max = trf_max self.multiplicity_j1_min = multiplicity_j1_min self.multiplicity_j1_max = multiplicity_j1_max self.multiplicity_j2_min = multiplicity_j2_min self.multiplicity_j2_max = multiplicity_j2_max self.rng_getter = RngGetter(seed=seed) def set_frq_range(self, frq_min, frq_max): frq_amplitude = frq_max - frq_min frq_center = (frq_max + frq_min) / 2 self.tff_min = frq_center + frq_amplitude * self.relative_frequency_min self.tff_max = frq_center + frq_amplitude * self.relative_frequency_max def __call__(self, seed=None): """ Generate peak parameters data. Args: seed: Optional seed for reproducibility Returns: List of dicts containing peak parameters (without tff_relative) """ if self.tff_min is None or self.tff_max is None: raise ValueError("tff_min and tff_max must be set before calling the generator.") rng = self.rng_getter.get_rng(seed=seed) number_of_signals = torch.randint( self.number_of_signals_min, self.number_of_signals_max + 1, [], generator=rng ) atom_group_indices = torch.randint( 0, len(self.atom_groups_data), [number_of_signals], generator=rng ) width_spectrum = random_loguniform( self.spectrum_width_min, self.spectrum_width_max, generator=rng ) height_spectrum = random_loguniform( self.thf_min, self.thf_max, generator=rng ) peaks_parameters_data = [] for atom_group_index in atom_group_indices: relative_intensity, multiplicity1, multiplicity2 = self.atom_groups_data[atom_group_index] position = random_value(self.tff_min, self.tff_max, generator=rng) j1 = random_value(self.multiplicity_j1_min, self.multiplicity_j1_max, generator=rng) j2 = random_value(self.multiplicity_j2_min, self.multiplicity_j2_max, generator=rng) width = width_spectrum * random_loguniform( self.relative_width_min, self.relative_width_max, generator=rng ) height = height_spectrum * relative_intensity * random_loguniform( self.relative_height_min, self.relative_height_max, generator=rng ) gaussian_contribution = random_value(self.trf_min, self.trf_max, generator=rng) peak_parameters = generate_multiplet_parameters( multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin=width, trf_lin=gaussian_contribution, j1=j1, j2=j2 ) peaks_parameters_data.append(peak_parameters) return peaks_parameters_data class TheoreticalMultipletSpectraGenerator: """ Generates theoretical NMR multiplet spectra. This class combines peak parameter generation with spectrum calculation. It can accept either a PeaksParameterDataGenerator instance or parameters to create one. """ def __init__(self, peaks_parameter_generator, pixels=2048, frq_step=11160.7142857 / 32768, relative_frequency_min=-0.4, relative_frequency_max=0.4, frequency_min=None, #if None, the 0 will be in the center of spectrum frequency_max=None, include_tff_relative=False, seed=42 ): # Spectrum-level parameters self.pixels = pixels self.frq_step = frq_step self.relative_frequency_min = relative_frequency_min self.relative_frequency_max = relative_frequency_max self.include_tff_relative = include_tff_relative # Frequency axis self.frq_frq, frq_min, frq_max = self._frequency_axis_from_parameters(frq_step, pixels, frequency_min, frequency_max) self.peaks_parameter_generator = peaks_parameter_generator self.peaks_parameter_generator.set_frq_range(frq_min, frq_max) # self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator def _frequency_axis_from_parameters(self, frq_step, pixels, frequency_min, frequency_max): """frq_step is never None, pixels, frequency_min or frequency_max can be None """ # Option 1: from pixels and frq_step if pixels is not None: assert (frequency_min is None) or (frequency_max is None) if (frequency_min is None) and (frequency_max is None): # if both are None, center at 0 frequency_min = -(pixels // 2) * frq_step elif frequency_min is None: # frequency_max is not None, use it to calculate frequency_min frequency_min = frequency_max - pixels * frq_step frq_frq = torch.arange(0, pixels) * frq_step + frequency_min # Option 2: from frequency_min and frequency_max elif (frequency_min is not None) and (frequency_max is not None): pixels = round((frequency_max - frequency_min) / frq_step) frq_frq = torch.arange(0, pixels) * frq_step + frequency_min else: raise ValueError("Insufficient parameters to determine frequency axis.") return frq_frq, frq_frq[0], frq_frq[-1] def __call__(self, seed=None): """ Generate a theoretical spectrum. Args: seed: Optional seed for reproducibility Returns: Tuple of (spectrum, dict with spectrum_data and frq_frq) """ # Generate peak parameters (peaks_parameter_generator has its own RngGetter) peaks_parameters_data = self.peaks_parameter_generator(seed=seed) # Add tff_relative if requested if self.include_tff_relative: for peak_params in peaks_parameters_data: peak_params["tff_relative"] = value_to_index(peak_params["tff_lin"], self.frq_frq) # Create spectrum from peaks spectrum = spectrum_from_peaks_data(peaks_parameters_data, self.frq_frq) return spectrum, {"spectrum_data": peaks_parameters_data, "frq_frq": self.frq_frq} class PeaksParametersNames(Enum): """Enum for standardized peak parameter names.""" position_hz ="tff_lin" height = "thf_lin" halfwidth_hz = "twf_lin" gaussian_fraction = "trf_lin" @classmethod def keys(cls): return [member.value for member in cls] @classmethod def values(cls): return [member.name for member in cls] class PeaksParametersParser: """class to convert peaks parameters from `{"width_hz": [...], "height": ..., ...}` format to `{"twf_lin": torch.tensor([...]), "thf_lin": ..., ...}` format.""" def __init__(self, alias_position_hz = None, alias_height = None, alias_width_hz = None, alias_gaussian_fraction = None, default_position_hz = None, default_height = None, default_width_hz = None, default_gaussian_fraction = 0., convert_width_to_halfwidth = True ): self.alias_position_hz = alias_position_hz if alias_position_hz is not None else "position_hz" self.alias_height = alias_height if alias_height is not None else "height" self.alias_width_hz = alias_width_hz if alias_width_hz is not None else "width_hz" self.alias_gaussian_fraction = alias_gaussian_fraction if alias_gaussian_fraction is not None else "gaussian_fraction" self.default_position_hz = default_position_hz self.default_height = default_height self.default_width_hz = default_width_hz self.default_gaussian_fraction = default_gaussian_fraction self.convert_width_to_halfwidth = convert_width_to_halfwidth def transform_single_peak(self, peak: dict) -> dict: parsed_peak = { PeaksParametersNames.position_hz.value: peak.get(self.alias_position_hz, self.default_position_hz), PeaksParametersNames.height.value: peak.get(self.alias_height, self.default_height), PeaksParametersNames.halfwidth_hz.value: (0.5 if self.convert_width_to_halfwidth else 1.) * peak.get(self.alias_width_hz, self.default_width_hz), PeaksParametersNames.gaussian_fraction.value: peak.get(self.alias_gaussian_fraction, self.default_gaussian_fraction), } # Validate and convert other peak parameters for k, v in parsed_peak.items(): if v is None: raise ValueError(f"Peak parameter '{k}' is None.") parsed_peak[k] = torch.atleast_1d(v.float() if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.float32)) return parsed_peak def transform(self, spectrum_peaks: list[dict]) -> list[dict]: parsed_peaks = [] for peak in spectrum_peaks: parsed_peaks.append(self.transform_single_peak(peak)) return parsed_peaks def csv_file_to_multiplets_dict(file_path: str) -> list[dict]: peaks_data = pd.read_csv(file_path) multiplets = {k: v.drop(columns="multiplet_name").to_dict(orient='list') for k, v in peaks_data.groupby("multiplet_name")} return multiplets def combine_multiplets(multiplets_list: list[dict]) -> dict: composed_multiplets = {} for multiplets in multiplets_list: for k, v in multiplets.items(): if not k in composed_multiplets: composed_multiplets[k] = v else: composed_multiplets[k].extend(v) return composed_multiplets class MultipletsLibrary: def __init__(self, csv_files_paths: list[str], peak_data_parser: PeaksParametersParser = None, return_name=False): self.csv_files_paths = csv_files_paths self.multiplets_data = {} self.peak_data_parser = peak_data_parser for file_path in csv_files_paths: self.multiplets_data.update(self._get_multiplet_data_from_file(file_path)) self.names = sorted(self.multiplets_data.keys()) self.return_name = return_name def _get_multiplet_data_from_file(self, file_path: str) -> dict: multiplets = csv_file_to_multiplets_dict(file_path) # dict[dict] multiplets_out = {} for k, v in multiplets.items(): multiplets_out[f"{file_path}/{k}"] = self.peak_data_parser.transform([v])[0] if self.peak_data_parser else v return multiplets_out def get_by_name(self, name: str) -> dict: return self.multiplets_data.get(name, None) def __getitem__(self, idx: int) -> dict: name = self.names[idx] multiplet_data = deepcopy(self.multiplets_data[name]) if self.return_name: return name, multiplet_data return multiplet_data def __len__(self): return len(self.multiplets_data) class SectraLibrary(MultipletsLibrary): def _get_multiplet_data_from_file(self, file_path: str) -> dict: multiplets = csv_file_to_multiplets_dict(file_path) # dict[dict] combined_multiplet = combine_multiplets(multiplets.values()) # dict return {f"{file_path}": self.peak_data_parser.transform([combined_multiplet])[0]} class MultipletDataFromMultipletsLibrary: def __init__(self, multiplets_library, tff_min=None, #may be assigned after initialization if the original peak positions are not used tff_max=None, #may be assigned after initialization if the original peak positions are not used use_original_peak_position=True, number_of_signals_min=None, number_of_signals_max=None, relative_frequency_min=None, relative_frequency_max=None, spectrum_width_factor_min=1, spectrum_width_factor_max=1, multiplet_width_factor_min=1, multiplet_width_factor_max=1, multiplet_width_additive_min=0, multiplet_width_additive_max=0, spectrum_height_factor_min=1, spectrum_height_factor_max=1, multiplet_height_factor_min=1, multiplet_height_factor_max=1, multiplet_height_additive_min=0, multiplet_height_additive_max=0, position_shift_min=0, position_shift_max=0, gaussian_fraction_change_min=None, gaussian_fraction_change_max=None, gaussian_fraction_change_additive_min=0., gaussian_fraction_change_additive_max=0., seed=42 ): if (number_of_signals_min is None) != (number_of_signals_max is None): raise ValueError("Both number_of_signals_min and number_of_signals_max should be provided or both should be None.") self.multiplets_library = multiplets_library self.rng_getter = RngGetter(seed=seed) self.tff_min = tff_min self.tff_max = tff_max self.relative_frequency_min = relative_frequency_min self.relative_frequency_max = relative_frequency_max self.use_original_peak_position = use_original_peak_position self.number_of_signals_min = number_of_signals_min self.number_of_signals_max = number_of_signals_max self.spectrum_width_factor_min = spectrum_width_factor_min self.spectrum_width_factor_max = spectrum_width_factor_max self.multiplet_width_factor_min = multiplet_width_factor_min self.multiplet_width_factor_max = multiplet_width_factor_max self.multiplet_width_additive_min = multiplet_width_additive_min self.multiplet_width_additive_max = multiplet_width_additive_max self.spectrum_height_factor_min = spectrum_height_factor_min self.spectrum_height_factor_max = spectrum_height_factor_max self.multiplet_height_factor_min = multiplet_height_factor_min self.multiplet_height_factor_max = multiplet_height_factor_max self.multiplet_height_additive_min = multiplet_height_additive_min self.multiplet_height_additive_max = multiplet_height_additive_max self.position_shift_min = position_shift_min self.position_shift_max = position_shift_max self.gaussian_fraction_change_min = gaussian_fraction_change_min self.gaussian_fraction_change_max = gaussian_fraction_change_max self.gaussian_fraction_change_additive_min = gaussian_fraction_change_additive_min self.gaussian_fraction_change_additive_max = gaussian_fraction_change_additive_max def set_frq_range(self, frq_min, frq_max): frq_amplitude = frq_max - frq_min frq_center = (frq_max + frq_min) / 2 self.tff_min = frq_center + frq_amplitude * self.relative_frequency_min self.tff_max = frq_center + frq_amplitude * self.relative_frequency_max def __call__(self, seed=None): if (not self.use_original_peak_position) and (self.tff_min is None or self.tff_max is None): raise ValueError("for use_original_peak_position=False, tff_min and tff_max must be set before calling the generator.") rng = self.rng_getter.get_rng(seed=seed) # select number of signals and their indices if self.number_of_signals_min is None: number_of_signals = len(self.multiplets_library) multiplets_indices = list(range(len(self.multiplets_library))) else: number_of_signals = torch.randint( self.number_of_signals_min, self.number_of_signals_max + 1, [], generator=rng ) multiplets_indices = torch.randint( 0, len(self.multiplets_library), [number_of_signals], generator=rng ) # spectrum width and height factors spectrum_width_factor = random_loguniform( self.spectrum_width_factor_min, self.spectrum_width_factor_max, generator=rng ) spectrum_height_factor = random_loguniform( self.spectrum_height_factor_min, self.spectrum_height_factor_max, generator=rng ) # get and modify peaks parameters data peaks_parameters_data = [self.multiplets_library[idx] for idx in multiplets_indices] for peak_parameters in peaks_parameters_data: # position if not self.use_original_peak_position: new_position_center = random_value(self.tff_min, self.tff_max, generator=rng) peak_parameters["tff_lin"] += new_position_center - torch.mean(peak_parameters["tff_lin"]) else: position_shift = random_value(self.position_shift_min, self.position_shift_max, generator=rng) peak_parameters["tff_lin"] += position_shift # width multiplet_width_factor = random_loguniform( self.multiplet_width_factor_min, self.multiplet_width_factor_max, generator=rng ) multiplet_width_additive = random_uniform( self.multiplet_width_additive_min, self.multiplet_width_additive_max, generator=rng ) peak_parameters["twf_lin"] = peak_parameters["twf_lin"] * spectrum_width_factor * multiplet_width_factor + multiplet_width_additive # height multiplet_height_factor = random_loguniform( self.multiplet_height_factor_min, self.multiplet_height_factor_max, generator=rng ) multiplet_height_additive = random_uniform( self.multiplet_height_additive_min, self.multiplet_height_additive_max, generator=rng ) peak_parameters["thf_lin"] = peak_parameters["thf_lin"] * spectrum_height_factor * multiplet_height_factor + multiplet_height_additive # gaussian contribution if self.gaussian_fraction_change_min is not None: gaussian_contribution_shift = random_value(self.gaussian_fraction_change_min, self.gaussian_fraction_change_max, generator=rng) gaussian_contribution_additive = random_value(self.gaussian_fraction_change_additive_min, self.gaussian_fraction_change_additive_max, generator=rng) gaussian_contribution_shift += gaussian_contribution_additive peak_parameters["trf_lin"] = torch.clip(peak_parameters["trf_lin"] + gaussian_contribution_shift, 0., 1.) return peaks_parameters_data class ResponseGenerator: def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None, response_function_noise=0.0, flip_response_function=False, seed=42): self.response_function_library = response_function_library self.response_function_stretch_min = response_function_stretch_min self.response_function_stretch_max = response_function_stretch_max self.pad_to = pad_to self.response_function_noise = response_function_noise self.flip_response_function = flip_response_function self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator def __call__(self, seed=None): rng = self.rng_getter.get_rng(seed=seed) response_function = self.response_function_library[torch.randint(0, len(self.response_function_library), [1], generator=rng)][0] padding_size = (response_function.shape[-1] - 1)//2 padding_size = round(random_loguniform(self.response_function_stretch_min, self.response_function_stretch_max, generator=rng)*padding_size) response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear') response_function /= response_function.sum() response_function += torch.randn(response_function.shape, generator=rng) * self.response_function_noise response_function /= response_function.sum() if self.flip_response_function and (torch.rand(1, generator=rng).item() < 0.5): response_function = response_function.flip(-1) if self.pad_to is not None: pad_size_left = (self.pad_to - response_function.shape[-1]) // 2 pad_size_right = self.pad_to - response_function.shape[-1] - pad_size_left response_function = torch.nn.functional.pad(response_function, (pad_size_left, pad_size_right)) return response_function class NoiseGenerator: def __init__(self, spectrum_noise_min=0., spectrum_noise_max=1/64, seed=42): self.spectrum_noise_min = spectrum_noise_min self.spectrum_noise_max = spectrum_noise_max self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator def __call__(self, disturbed_spectrum, seed=None): rng = self.rng_getter.get_rng(seed=seed) return disturbed_spectrum + torch.randn(disturbed_spectrum.shape, generator=rng) * random_value(self.spectrum_noise_min, self.spectrum_noise_max, generator=rng) class BaseGenerator(ABC): """ Single-threaded base generator. For this workload, single-threaded execution is typically faster because: - Thread creation/synchronization overhead > computation time - Python GIL contention during object creation - Memory allocator contention when multiple threads allocate tensors - CPU cache thrashing across cores - Small per-thread workload doesn't amortize thread overhead """ def __init__(self, batch_size=64, seed=None): self.batch_size = batch_size self.seed = seed def set_seed(self, seed): self.seed = seed @abstractmethod def _generate_element(self, seed): pass def __iter__(self): rng = torch.Generator() if self.seed is not None: rng.manual_seed(self.seed) else: rng.seed() while True: batch = [] # Generate unique seeds for each element in the batch if self.seed is not None: element_seeds = [torch.randint(0, 2**31, (1,), generator=rng).item() for _ in range(self.batch_size)] else: element_seeds = [None] * self.batch_size # Single-threaded sequential generation for i in range(self.batch_size): batch.append(self._generate_element(element_seeds[i])) yield self.collate_fn(batch) @abstractmethod def collate_fn(self, batch): pass class BaseGeneratorMultithread(ABC): """ Multithreaded base generator (backup option). Use only if profiling shows benefit for your specific use case (e.g., very large/slow generation functions, I/O-bound operations). """ def __init__(self, batch_size=64, num_workers=4, seed=None, ordered_batch=False): self.batch_size = batch_size self.num_workers = num_workers self.seed = seed self.ordered_batch = ordered_batch def set_seed(self, seed): self.seed = seed def set_ordered_batch(self, ordered_batch): self.ordered_batch = ordered_batch @abstractmethod def _generate_element(self, seed): pass def __iter__(self): rng = torch.Generator() if self.seed is not None: rng.manual_seed(self.seed) else: rng.seed() while True: batch = [] # Generate unique seeds for each element in the batch if self.seed is not None: element_seeds = [torch.randint(0, 2**31, (1,), generator=rng).item() for _ in range(self.batch_size)] else: element_seeds = [None] * self.batch_size with ThreadPoolExecutor(max_workers=self.num_workers) as executor: futures = [executor.submit(self._generate_element, element_seeds[i]) for i in range(self.batch_size)] if self.ordered_batch: # Maintain order: iterate futures in submission order for f in futures: batch.append(f.result()) else: # Faster: process as completed (order may vary) for f in as_completed(futures): batch.append(f.result()) yield self.collate_fn(batch) @abstractmethod def collate_fn(self, batch): pass class Generator(BaseGenerator): def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64, include_spectrum_data=False, include_peak_mask=False, include_response_function=False, input_normalization_height=None, seed=None): super().__init__(batch_size=batch_size, seed=seed) self.clean_spectra_generator = clean_spectra_generator self.response_generator = response_generator self.noise_generator = noise_generator self.include_spectrum_data = include_spectrum_data self.include_peak_mask = include_peak_mask self.include_response_function = include_response_function self.input_normalization_height = input_normalization_height def _generate_element(self, seed): # Generate different seeds for each generator from the provided seed if seed is not None: rng = torch.Generator() rng.manual_seed(seed) clean_seed = torch.randint(0, 2**31, (1,), generator=rng).item() response_seed = torch.randint(0, 2**31, (1,), generator=rng).item() noise_seed = torch.randint(0, 2**31, (1,), generator=rng).item() else: clean_seed = None response_seed = None noise_seed = None clean_spectrum, extra_clean_data = self.clean_spectra_generator(seed=clean_seed) response_function = self.response_generator(seed=response_seed) padding_size = (response_function.shape[-1] - 1)//2 disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size) if self.input_normalization_height is not None: max_val = torch.max(disturbed_spectrum) clean_spectrum = clean_spectrum / max_val * self.input_normalization_height disturbed_spectrum = disturbed_spectrum / max_val * self.input_normalization_height # noise after normalization to better control noise level noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed) out = { 'theoretical_spectrum': clean_spectrum, 'disturbed_spectrum': disturbed_spectrum, 'noised_spectrum': noised_spectrum, } if self.include_spectrum_data: out['theoretical_spectrum_data'] = extra_clean_data['spectrum_data'] out['frq_frq'] = extra_clean_data['frq_frq'] if self.include_peak_mask and extra_clean_data is not None: all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in extra_clean_data['spectrum_data']]) peaks_indices = all_peaks_rel.round().type(torch.int64) out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0) if self.include_response_function: out['response_function'] = response_function return out def collate_fn(self, batch): tensor_keys = set(batch[0].keys()) for k in ['theoretical_spectrum_data', 'frq_frq']: tensor_keys.discard(k) out = {k: torch.stack([item[k] for item in batch]) for k in tensor_keys} if 'theoretical_spectrum_data' in batch[0]: out['theoretical_spectrum_data'] = [item['theoretical_spectrum_data'] for item in batch] if 'frq_frq' in batch[0]: out['frq_frq'] = [item['frq_frq'] for item in batch] return out class PeaksParametersFromSinglets: def __init__(self, singlets_files: list[pd.DataFrame], number_of_signals_min: int = 5, number_of_signals_max: int = 20, use_original_position: bool = True, position_hz_min: Optional[float] = None, position_hz_max: Optional[float] = None, position_hz_change_min: float = 0.0, position_hz_change_max: float = 0.0, relative_frequency_min: float = -0.4, # used only if position_hz_min/max are None relative_frequency_max: float = 0.4, use_original_width: bool = True, width_hz_min: float = 0.2, width_hz_max: float = 2.0, width_factor_min: float = 1.0, width_factor_max: float = 1.0, width_hz_change_min: float = 0.0, width_hz_change_max: float = 0.0, convert_width_to_halfwidth: bool = True, use_original_height: bool = True, height_min: float = 0.1, height_max: float = 10.0, height_factor_min: float = 1.0, height_factor_max: float = 1.0, height_change_min: float = 0.0, height_change_max: float = 0.0, use_original_gaussian_fraction: bool = True, gaussian_fraction_min: float = 0.0, gaussian_fraction_max: float = 1.0, gaussian_fraction_change_min: float = 0.0, gaussian_fraction_change_max: float = 0.0, seed=42 ): self.peaks_rows = pd.concat([pd.read_csv(f) for f in singlets_files], ignore_index=True) # number of signals self.number_of_signals_min = number_of_signals_min self.number_of_signals_max = number_of_signals_max # position self.use_original_position = use_original_position self.position_hz_min = position_hz_min self.position_hz_max = position_hz_max self.position_hz_change_min = position_hz_change_min self.position_hz_change_max = position_hz_change_max self.relative_frequency_min = relative_frequency_min self.relative_frequency_max = relative_frequency_max # width self.use_original_width = use_original_width self.width_hz_min = width_hz_min self.width_hz_max = width_hz_max self.width_factor_min = width_factor_min self.width_factor_max = width_factor_max self.width_hz_change_min = width_hz_change_min self.width_hz_change_max = width_hz_change_max self.convert_width_to_halfwidth = convert_width_to_halfwidth # if True, the original widths will be divided by 2 # height self.use_original_height = use_original_height self.height_min = height_min self.height_max = height_max self.height_factor_min = height_factor_min self.height_factor_max = height_factor_max self.height_change_min = height_change_min self.height_change_max = height_change_max # gaussian fraction self.use_original_gaussian_fraction = use_original_gaussian_fraction self.gaussian_fraction_min = gaussian_fraction_min self.gaussian_fraction_max = gaussian_fraction_max self.gaussian_fraction_change_min = gaussian_fraction_change_min self.gaussian_fraction_change_max = gaussian_fraction_change_max self.rng_getter = RngGetter(seed=seed) def set_frq_range(self, frq_min, frq_max): frq_amplitude = frq_max - frq_min frq_center = (frq_max + frq_min) / 2 self.position_hz_min = frq_center + frq_amplitude * self.relative_frequency_min self.position_hz_max = frq_center + frq_amplitude * self.relative_frequency_max def __call__(self, seed=None) -> list[dict]: rng = self.rng_getter.get_rng(seed=seed) number_of_signals = torch.randint( low=self.number_of_signals_min, high=min(self.number_of_signals_max, len(self.peaks_rows) + 1), size=[], generator=rng ) selected_peaks = self.peaks_rows.sample(n=number_of_signals.item(), random_state=seed) multiplet_data = {} # position if self.use_original_position: multiplet_data[PeaksParametersNames.position_hz.value] = torch.tensor(selected_peaks["position_hz"].values, dtype=torch.float32) + random_uniform_vector(self.position_hz_change_min, self.position_hz_change_max, size=len(selected_peaks)) else: multiplet_data[PeaksParametersNames.position_hz.value] = random_uniform_vector(self.position_hz_min, self.position_hz_max, size=len(selected_peaks)) # width if self.use_original_width: multiplet_data[PeaksParametersNames.halfwidth_hz.value] = (0.5 if self.convert_width_to_halfwidth else 1.)*torch.tensor(selected_peaks["width_hz"].values, dtype=torch.float32) * random_uniform_vector(self.width_factor_min, self.width_factor_max, size=len(selected_peaks)) + random_uniform_vector(self.width_hz_change_min, self.width_hz_change_max, size=len(selected_peaks)) else: multiplet_data[PeaksParametersNames.halfwidth_hz.value] = random_loguniform_vector(self.width_hz_min, self.width_hz_max, size=len(selected_peaks)) # height if self.use_original_height: multiplet_data[PeaksParametersNames.height.value] = torch.tensor(selected_peaks["height"].values, dtype=torch.float32) * random_uniform_vector(self.height_factor_min, self.height_factor_max, size=len(selected_peaks)) + random_uniform_vector(self.height_change_min, self.height_change_max, size=len(selected_peaks)) else: multiplet_data[PeaksParametersNames.height.value] = random_loguniform_vector(self.height_min, self.height_max, size=len(selected_peaks)) # gaussian fraction if self.use_original_gaussian_fraction: multiplet_data[PeaksParametersNames.gaussian_fraction.value] = torch.clamp(torch.tensor(selected_peaks["gaussian_fraction"].values, dtype=torch.float32) + random_uniform_vector(self.gaussian_fraction_change_min, self.gaussian_fraction_change_max, size=len(selected_peaks)), 0.0, 1.0) else: multiplet_data[PeaksParametersNames.gaussian_fraction.value] = random_uniform_vector(self.gaussian_fraction_min, self.gaussian_fraction_max, size=len(selected_peaks)) return [multiplet_data]