# ShimNet / src /generators.py
# Marek Bukowicki
# add shimnet code
# 64b4096
import numpy as np
import torch
import torchdata
# from itertools import islice
def random_value(min_value, max_value):
    """Draw one uniformly distributed sample between ``min_value`` and ``max_value``.

    The sample comes from torch's global random generator (``torch.rand``), so
    it is reproducible with ``torch.manual_seed``.

    Args:
        min_value: Lower bound of the sampling interval.
        max_value: Upper bound of the sampling interval.

    Returns:
        The sampled value as a plain Python number (via ``Tensor.item()``).
    """
    span = max_value - min_value
    unit_sample = torch.rand(1)
    return (min_value + unit_sample * span).item()
def random_loguniform(min_value, max_value):
    """Draw one sample whose logarithm is uniform between the logs of the bounds.

    The logarithm of both bounds is taken, so they are expected to be positive.
    Randomness comes from torch's global generator (``torch.rand``).

    Args:
        min_value: Lower bound of the sampling interval.
        max_value: Upper bound of the sampling interval.

    Returns:
        The sampled value as a plain Python number (via ``Tensor.item()``).
    """
    unit_sample = torch.rand(1)
    log_high = torch.log(torch.tensor(max_value))
    log_low = torch.log(torch.tensor(min_value))
    # min * exp(u * (log(max) - log(min))) == min * (max / min) ** u
    scale = torch.exp(unit_sample * (log_high - log_low))
    return (min_value * scale).item()
def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq: torch.Tensor):
    """Evaluate a sum of mixed Lorentzian/Gaussian peaks on a frequency axis.

    Each peak is the sum of a Lorentzian of height ``thf * (1 - trf)`` and a
    Gaussian of height ``thf * trf``. Both components share the centre ``tff``
    and the width ``twf``, which is the half width at half maximum of either
    shape.

    Args:
        peaks_parameters: Mapping with 1-D tensors (one entry per peak) under
            the keys ``"tff_lin"`` (centre), ``"twf_lin"`` (width),
            ``"thf_lin"`` (height) and ``"trf_lin"`` (Gaussian fraction).
        frq_frq: Frequency axis on which the peaks are evaluated.

    Returns:
        Tensor with a leading dimension of size 1 holding the peaks summed
        over the peak dimension.
    """
    # extract parameters
    centers = peaks_parameters["tff_lin"]
    widths = peaks_parameters["twf_lin"]
    heights = peaks_parameters["thf_lin"]
    gaussian_fraction = peaks_parameters["trf_lin"]

    # split the total height between the two line shapes
    lorentz_heights = heights * (1. - gaussian_fraction)
    gauss_heights = heights * gaussian_fraction
    # standard deviation of a Gaussian whose half width at half maximum equals `widths`
    gauss_sigmas = widths / torch.tensor(2.).log().mul(2.).sqrt()

    # squared distance of every frequency point from every peak centre
    squared_offsets = (frq_frq - centers[:, None]) ** 2
    squared_widths = widths[:, None] ** 2

    # Lorentzian peak contributions
    lorentz_part = squared_widths / (squared_widths + squared_offsets) * lorentz_heights[:, None]
    # Gaussian peak contributions
    gauss_part = torch.exp(-squared_offsets / gauss_sigmas[:, None] ** 2 / 2.) * gauss_heights[:, None]

    # sum the contributions of all peaks
    return (lorentz_part + gauss_part).sum(0, keepdim=True)
# First eight rows of Pascal's triangle (binomial coefficients); the row at
# index k has k + 1 entries.
pascal_triangle = [
    (1,),
    (1, 1),
    (1, 2, 1),
    (1, 3, 3, 1),
    (1, 4, 6, 4, 1),
    (1, 5, 10, 10, 5, 1),
    (1, 6, 15, 20, 15, 6, 1),
    (1, 7, 21, 35, 35, 21, 7, 1),
]
# The same rows as tensors, each rescaled so that its entries sum to 1.
normalized_pascal_triangle = [torch.tensor(row) / sum(row) for row in pascal_triangle]
def pascal_multiplicity(multiplicity):
    """Return line offsets and relative intensities of a binomial multiplet.

    Args:
        multiplicity: Number of lines in the multiplet; used as
            ``multiplicity - 1`` to index ``normalized_pascal_triangle``.

    Returns:
        Tuple ``(shifts, intensities)``: ``shifts`` are unit-spaced line
        positions centred on zero, ``intensities`` is the matching normalized
        Pascal-triangle row.
    """
    weights = normalized_pascal_triangle[multiplicity - 1]
    line_count = len(weights)
    # centre the unit-spaced line positions around zero
    midpoint = (line_count - 1) / 2
    offsets = torch.arange(line_count) - midpoint
    return offsets, weights
def double_multiplicity(multiplicity1, multiplicity2, j1=1, j2=1):
    """Combine two binomial multiplets into one multiplet-of-multiplets pattern.

    Every line of the first multiplet (spacing ``j1``) is split by the second
    multiplet (spacing ``j2``): shifts add up and intensities multiply.

    Args:
        multiplicity1: Multiplicity of the first splitting.
        multiplicity2: Multiplicity of the second splitting.
        j1: Line spacing applied to the first multiplet's unit shifts.
        j2: Line spacing applied to the second multiplet's unit shifts.

    Returns:
        Tuple ``(shifts, intensities)`` of flattened 1-D tensors, one entry
        per pair of lines (first multiplet varies slowest).
    """
    offsets_a, weights_a = pascal_multiplicity(multiplicity1)
    offsets_b, weights_b = pascal_multiplicity(multiplicity2)
    # outer sum / outer product over (line of multiplet 1, line of multiplet 2)
    shift_grid = j1 * offsets_a[:, None] + j2 * offsets_b[None, :]
    weight_grid = weights_a[:, None] * weights_b[None, :]
    return shift_grid.flatten(), weight_grid.flatten()
def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_lin, j1, j2):
    """Build per-line peak parameters for one multiplet-of-multiplets signal.

    Args:
        multiplicity: Pair ``(multiplicity1, multiplicity2)`` of the two splittings.
        tff_lin: Centre position of the whole multiplet.
        thf_lin: Total height, distributed over the lines by their intensities.
        twf_lin: Width shared by all lines.
        trf_lin: Gaussian fraction shared by all lines.
        j1: Line spacing of the first splitting.
        j2: Line spacing of the second splitting.

    Returns:
        Dict with one entry per line under the keys ``"tff_lin"``,
        ``"thf_lin"``, ``"twf_lin"`` and ``"trf_lin"``.
    """
    line_shifts, line_weights = double_multiplicity(multiplicity[0], multiplicity[1], j1, j2)
    line_count = len(line_shifts)

    parameters = {}
    parameters["tff_lin"] = line_shifts + tff_lin
    parameters["thf_lin"] = line_weights * thf_lin
    # width and Gaussian fraction are identical for every line of the multiplet
    parameters["twf_lin"] = torch.full((line_count,), twf_lin)
    parameters["trf_lin"] = torch.full((line_count,), trf_lin)
    return parameters
def value_to_index(values, table):
    """Map values to fractional indices of an evenly spaced lookup table.

    The mapping is linear: ``table[0]`` maps to index 0 and ``table[-1]`` to
    index ``len(table) - 1``. The result is not rounded or cast to integers.

    Args:
        values: Values to convert.
        table: Axis the indices refer to; only its first element, last element
            and length are used.

    Returns:
        Fractional index positions of ``values`` within ``table``.
    """
    origin = table[0]
    span = table[-1] - origin
    last_index = len(table) - 1
    return (values - origin) / span * last_index
def generate_theoretical_spectrum(
    number_of_signals_min, number_of_signals_max,
    spectrum_width_min, spectrum_width_max,
    relative_width_min, relative_width_max,
    tff_min, tff_max,
    thf_min, thf_max,
    trf_min, trf_max,
    relative_height_min, relative_height_max,
    multiplicity_j1_min, multiplicity_j1_max,
    multiplicity_j2_min, multiplicity_j2_max,
    atom_groups_data,
    frq_frq
):
    """Generate one random theoretical spectrum made of several multiplets.

    A random number of signals is drawn (bounds inclusive). Each signal picks
    a random row ``(relative_intensity, multiplicity1, multiplicity2)`` from
    ``atom_groups_data`` and gets a random position, couplings, width, height
    and Gaussian fraction. One base width and one base height are drawn per
    spectrum and then scaled per signal.

    Args:
        number_of_signals_min, number_of_signals_max: Inclusive bounds on the
            number of signals.
        spectrum_width_min, spectrum_width_max: Log-uniform bounds of the
            per-spectrum base width.
        relative_width_min, relative_width_max: Log-uniform bounds of the
            per-signal width factor.
        tff_min, tff_max: Uniform bounds of the signal position.
        thf_min, thf_max: Log-uniform bounds of the per-spectrum base height.
        trf_min, trf_max: Uniform bounds of the Gaussian fraction.
        relative_height_min, relative_height_max: Log-uniform bounds of the
            per-signal height factor.
        multiplicity_j1_min, multiplicity_j1_max: Uniform bounds of the first
            line spacing.
        multiplicity_j2_min, multiplicity_j2_max: Uniform bounds of the second
            line spacing.
        atom_groups_data: Indexable collection of
            ``(relative_intensity, multiplicity1, multiplicity2)`` rows.
        frq_frq: Frequency axis of the spectrum.

    Returns:
        Tuple ``(spectrum, signals)``: the summed spectrum (``None`` when zero
        signals were drawn) and a list with one parameter dict per signal.
        Each dict also holds ``"tff_relative"``, the fractional index of every
        line on ``frq_frq``.
    """
    # NOTE: the order of the random draws below is kept fixed so that seeded
    # runs stay reproducible.
    signal_count = torch.randint(number_of_signals_min, number_of_signals_max+1, [])
    group_choices = torch.randint(0, len(atom_groups_data), [signal_count])
    base_width = random_loguniform(spectrum_width_min, spectrum_width_max)
    base_height = random_loguniform(thf_min, thf_max)

    signals = []
    spectrum = None
    for group_choice in group_choices:
        relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[group_choice]
        position = random_value(tff_min, tff_max)
        j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
        j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
        width = base_width*random_loguniform(relative_width_min, relative_width_max)
        height = base_height*relative_intensity*random_loguniform(relative_height_min, relative_height_max)
        gaussian_contribution = random_value(trf_min, trf_max)

        signal_parameters = generate_multiplet_parameters(
            multiplicity=(multiplicity1, multiplicity2),
            tff_lin=position,
            thf_lin=height,
            twf_lin=width,
            trf_lin=gaussian_contribution,
            j1=j1,
            j2=j2,
        )
        signal_parameters["tff_relative"] = value_to_index(signal_parameters["tff_lin"], frq_frq)
        signals.append(signal_parameters)

        contribution = calculate_theoretical_spectrum(signal_parameters, frq_frq)
        # the first contribution becomes the accumulator; later ones are added in place
        spectrum = contribution if spectrum is None else spectrum.add_(contribution)
    return spectrum, signals
def theoretical_generator(
    atom_groups_data,
    pixels=2048, frq_step=11160.7142857 / 32768,
    number_of_signals_min=1, number_of_signals_max=8,
    spectrum_width_min=0.2, spectrum_width_max=1,
    relative_width_min=1, relative_width_max=2,
    relative_height_min=1, relative_height_max=1,
    relative_frequency_min=-0.4, relative_frequency_max=0.4,
    thf_min=1/16, thf_max=16,
    trf_min=0, trf_max=1,
    multiplicity_j1_min=0, multiplicity_j1_max=15,
    multiplicity_j2_min=0, multiplicity_j2_max=15,
):
    """Endlessly yield random theoretical spectra on a fixed frequency axis.

    The frequency axis has ``pixels`` points spaced by ``frq_step`` and is
    centred around zero. Signal positions are limited to the fraction of the
    full axis width given by ``relative_frequency_min`` and
    ``relative_frequency_max``. All other arguments are forwarded unchanged to
    ``generate_theoretical_spectrum``.

    Yields:
        The ``(spectrum, signals)`` tuples returned by
        ``generate_theoretical_spectrum``.
    """
    axis_width = pixels * frq_step
    frequency_axis = torch.arange(-pixels // 2, pixels // 2) * frq_step

    # the arguments are the same for every spectrum, so assemble them once
    spectrum_settings = dict(
        number_of_signals_min=number_of_signals_min,
        number_of_signals_max=number_of_signals_max,
        spectrum_width_min=spectrum_width_min,
        spectrum_width_max=spectrum_width_max,
        relative_width_min=relative_width_min,
        relative_width_max=relative_width_max,
        relative_height_min=relative_height_min,
        relative_height_max=relative_height_max,
        tff_min=relative_frequency_min * axis_width,
        tff_max=relative_frequency_max * axis_width,
        thf_min=thf_min, thf_max=thf_max,
        trf_min=trf_min, trf_max=trf_max,
        multiplicity_j1_min=multiplicity_j1_min,
        multiplicity_j1_max=multiplicity_j1_max,
        multiplicity_j2_min=multiplicity_j2_min,
        multiplicity_j2_max=multiplicity_j2_max,
        atom_groups_data=atom_groups_data,
        frq_frq=frequency_axis,
    )
    while True:
        yield generate_theoretical_spectrum(**spectrum_settings)
class ResponseLibrary:
    """Flat, index-addressable collection of response functions from several files.

    Every file is loaded with ``torch.load`` and its leading dimensions are
    flattened so that the first dimension enumerates the entries. Entries of
    all files are addressed through one global index.
    """

    def __init__(self, reponse_files, normalize=True):
        """Load the response-function tensors.

        Args:
            reponse_files: Iterable of paths (or file objects) accepted by
                ``torch.load``.
            normalize: If true, divide every entry by its sum over the last
                dimension.
        """
        self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in reponse_files]
        if normalize:
            self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
        lengths = [len(data) for data in self.data]
        # global index at which each file's entries start
        self.start_indices = torch.cumsum(torch.tensor([0] + lengths[:-1]), 0)
        self.total_length = sum(lengths)

    def __getitem__(self, idx):
        """Return the entry at global index ``idx``.

        Negative indices count from the end, as for Python sequences.

        Args:
            idx: Integer index, or a single-element integer tensor (in which
                case the result keeps a leading dimension of size 1).

        Raises:
            ValueError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Resolve negative indices first: the per-file lookup below only works
        # for non-negative global positions.
        position = idx + self.total_length if idx < 0 else idx
        if position < 0 or position >= self.total_length:
            raise ValueError(f'index {idx} out of range')
        # last file whose start index is <= position
        tensor_index = torch.searchsorted(self.start_indices, position, right=True) - 1
        return self.data[tensor_index][position - self.start_indices[tensor_index]]

    def __len__(self):
        """Return the total number of entries over all files."""
        return self.total_length
def generator(
    theoretical_generator_params,
    response_function_library,
    response_function_stretch_min=0.5,
    response_function_stretch_max=2.0,
    response_function_noise=0.,
    spectrum_noise_min=0.,
    spectrum_noise_max=1/64,
    include_spectrum_data=False,
    include_peak_mask=False,
    include_response_function=False,
    flip_response_function=False
):
    """Endlessly yield theoretical spectra with their distorted and noisy versions.

    For every theoretical spectrum a random response function is taken from
    the library, randomly stretched, optionally perturbed with noise and
    flipped, and convolved with the spectrum. Gaussian noise of random
    amplitude is then added.

    Args:
        theoretical_generator_params: Keyword arguments for ``theoretical_generator``.
        response_function_library: Indexable collection with ``len()``; it is
            indexed with a single-element integer tensor.
            NOTE(review): entries are presumably shaped (1, 1, length), as
            needed by ``interpolate(mode='linear')`` and ``conv1d`` -- confirm.
        response_function_stretch_min, response_function_stretch_max:
            Log-uniform bounds of the response-function stretch factor.
        response_function_noise: Standard deviation of the noise added to the
            response function.
        spectrum_noise_min, spectrum_noise_max: Uniform bounds of the standard
            deviation of the noise added to the spectrum.
        include_spectrum_data: Add the per-signal parameter dicts under
            ``"theoretical_spectrum_data"``.
        include_peak_mask: Add ``"peaks_mask"``, a mask with ones at the
            (rounded) positions of all lines lying on the frequency axis.
        include_response_function: Add the response function actually used.
        flip_response_function: Reverse the response function with probability 0.5.

    Yields:
        Dict with ``"theoretical_spectrum"``, ``"disturbed_spectrum"`` and
        ``"noised_spectrum"``, plus the optional entries described above.
    """
    for theoretical_spectrum, theoretical_spectrum_data in theoretical_generator(**theoretical_generator_params):
        # get response function
        response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
        # stretch response function; the length stays odd (2 * padding + 1)
        padding_size = (response_function.shape[-1] - 1)//2
        padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size)
        response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
        response_function /= response_function.sum() # normalize sum of response function to 1
        # add noise to response function
        response_function += torch.randn(response_function.shape) * response_function_noise
        response_function /= response_function.sum() # normalize sum of response function to 1
        if flip_response_function and (torch.rand(1).item() < 0.5):
            response_function = response_function.flip(-1)
        # disturbed spectrum; the padding keeps the output as long as the input
        disturbed_spectrum = torch.nn.functional.conv1d(theoretical_spectrum, response_function, padding=padding_size)
        # add noise
        noised_spectrum = disturbed_spectrum + torch.randn(disturbed_spectrum.shape) * random_value(spectrum_noise_min, spectrum_noise_max)

        out = {
            'theoretical_spectrum': theoretical_spectrum,
            'disturbed_spectrum': disturbed_spectrum,
            'noised_spectrum': noised_spectrum,
        }
        if include_response_function:
            out['response_function'] = response_function
        if include_spectrum_data:
            out["theoretical_spectrum_data"] = theoretical_spectrum_data
        if include_peak_mask:
            all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in theoretical_spectrum_data])
            peaks_indices = all_peaks_rel.round().type(torch.int64)
            n_points = out["theoretical_spectrum"].shape[1]
            # Multiplet lines can be shifted beyond the frequency axis; such
            # indices would make scatter fail, so only on-axis lines are marked.
            on_axis = (peaks_indices >= 0) & (peaks_indices < n_points)
            out["peaks_mask"] = torch.scatter(torch.zeros(n_points), 0, peaks_indices[on_axis], 1.).unsqueeze(0)

        yield out
def collate_with_spectrum_data(batch):
    """Collate samples whose ``'theoretical_spectrum_data'`` entry cannot be stacked.

    All other entries are stacked along a new leading dimension with
    ``torch.stack``; the ``'theoretical_spectrum_data'`` entries are gathered
    into a plain list, one element per sample.

    Args:
        batch: Non-empty sequence of sample dicts that share the same keys.

    Returns:
        Dict with the stacked tensors and the list of spectrum data.

    Raises:
        KeyError: If the first sample has no ``'theoretical_spectrum_data'`` entry.
    """
    spectrum_key = 'theoretical_spectrum_data'
    stackable_keys = set(batch[0])
    stackable_keys.remove(spectrum_key)

    collated = {}
    for key in stackable_keys:
        collated[key] = torch.stack([sample[key] for sample in batch])
    collated[spectrum_key] = [sample[spectrum_key] for sample in batch]
    return collated
def get_datapipe(
    response_functions_files,
    atom_groups_data_file=None,
    batch_size=64,
    pixels=2048, frq_step=11160.7142857 / 32768,
    number_of_signals_min=1, number_of_signals_max=8,
    spectrum_width_min=0.2, spectrum_width_max=1,
    relative_width_min=1, relative_width_max=2,
    relative_height_min=1, relative_height_max=1,
    relative_frequency_min=-0.4, relative_frequency_max=0.4,
    thf_min=1/16, thf_max=16,
    trf_min=0, trf_max=1,
    multiplicity_j1_min=0, multiplicity_j1_max=15,
    multiplicity_j2_min=0, multiplicity_j2_max=15,
    response_function_stretch_min=0.5,
    response_function_stretch_max=2.0,
    response_function_noise=0.,
    spectrum_noise_min=0.,
    spectrum_noise_max=1/64,
    include_spectrum_data=False,
    include_peak_mask=False,
    include_response_function=False,
    flip_response_function=False
):
    """Build a batched torchdata pipe of randomly generated spectra.

    Args:
        response_functions_files: Files passed to ``ResponseLibrary``.
        atom_groups_data_file: Optional text file read with ``np.loadtxt``;
            columns 1, 2 and 3 are used as integer
            ``(relative_intensity, multiplicity1, multiplicity2)`` rows. When
            ``None``, a single all-ones row (singlets only) is used.
        batch_size: Number of samples per batch.
        include_spectrum_data: Also selects ``collate_with_spectrum_data`` as
            the collate function instead of the pipe's default.

    The remaining arguments are forwarded unchanged to ``theoretical_generator``
    (spectrum parameters) and ``generator`` (response-function, noise and
    output options).

    Returns:
        The collated, batched datapipe.
    """
    if atom_groups_data_file is None:
        # singlets
        atom_groups_data = np.ones((1,3), dtype=int)
    else:
        atom_groups_data = np.loadtxt(atom_groups_data_file, usecols=(1,2,3), dtype=int)

    response_function_library = ResponseLibrary(response_functions_files)

    theoretical_generator_params = {
        "atom_groups_data": atom_groups_data,
        "pixels": pixels,
        "frq_step": frq_step,
        "number_of_signals_min": number_of_signals_min,
        "number_of_signals_max": number_of_signals_max,
        "spectrum_width_min": spectrum_width_min,
        "spectrum_width_max": spectrum_width_max,
        "relative_width_min": relative_width_min,
        "relative_width_max": relative_width_max,
        "relative_height_min": relative_height_min,
        "relative_height_max": relative_height_max,
        "relative_frequency_min": relative_frequency_min,
        "relative_frequency_max": relative_frequency_max,
        "thf_min": thf_min,
        "thf_max": thf_max,
        "trf_min": trf_min,
        "trf_max": trf_max,
        "multiplicity_j1_min": multiplicity_j1_min,
        "multiplicity_j1_max": multiplicity_j1_max,
        "multiplicity_j2_min": multiplicity_j2_min,
        "multiplicity_j2_max": multiplicity_j2_max,
    }
    sample_stream = generator(
        theoretical_generator_params=theoretical_generator_params,
        response_function_library=response_function_library,
        response_function_stretch_min=response_function_stretch_min,
        response_function_stretch_max=response_function_stretch_max,
        response_function_noise=response_function_noise,
        spectrum_noise_min=spectrum_noise_min,
        spectrum_noise_max=spectrum_noise_max,
        include_spectrum_data=include_spectrum_data,
        include_peak_mask=include_peak_mask,
        include_response_function=include_response_function,
        flip_response_function=flip_response_function,
    )

    # the per-signal parameter dicts cannot be stacked, so they need a custom collate
    if include_spectrum_data:
        collate_fn = collate_with_spectrum_data
    else:
        collate_fn = None

    pipe = torchdata.datapipes.iter.IterableWrapper(sample_stream, deepcopy=False)
    pipe = pipe.batch(batch_size)
    pipe = pipe.collate(collate_fn=collate_fn)
    return pipe