Marek Bukowicki commited on
Commit
4710de6
·
1 Parent(s): 2495192

separate peaks data generation from spectra generation

Browse files
configs/shimnet_600_modular.yaml CHANGED
@@ -25,27 +25,29 @@ data:
25
  batch_size: null # to be set in training script
26
  clean_spectra_generator:
27
  _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
28
- atom_groups_data_file: data/multiplets_10000_parsed.txt
29
  pixels: 2048
30
  frq_step: ${metadata.frq_step}
31
- number_of_signals_min: 2
32
- number_of_signals_max: 5
33
- spectrum_width_min: 0.2
34
- spectrum_width_max: 1.0
35
- relative_width_min: 1.0
36
- relative_width_max: 2.0
37
- relative_height_min: 0.5
38
- relative_height_max: 4
39
  relative_frequency_min: -0.4
40
  relative_frequency_max: 0.4
41
- thf_min: 0.5
42
- thf_max: 2
43
- trf_min: 0.0
44
- trf_max: 1.0
45
- multiplicity_j1_min: 0.0
46
- multiplicity_j1_max: 15
47
- multiplicity_j2_min: 0.0
48
- multiplicity_j2_max: 15
 
 
 
 
 
 
 
 
 
 
 
49
  response_generator:
50
  _target_: shimnet.generators.ResponseGenerator
51
  response_function_library:
 
25
  batch_size: null # to be set in training script
26
  clean_spectra_generator:
27
  _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
 
28
  pixels: 2048
29
  frq_step: ${metadata.frq_step}
 
 
 
 
 
 
 
 
30
  relative_frequency_min: -0.4
31
  relative_frequency_max: 0.4
32
+ peaks_parameter_generator:
33
+ _target_: shimnet.generators.PeaksParameterDataGenerator
34
+ atom_groups_data_file: data/multiplets_10000_parsed.txt
35
+ number_of_signals_min: 2
36
+ number_of_signals_max: 5
37
+ spectrum_width_min: 0.2
38
+ spectrum_width_max: 1.0
39
+ relative_width_min: 1.0
40
+ relative_width_max: 2.0
41
+ relative_height_min: 0.5
42
+ relative_height_max: 4
43
+ thf_min: 0.5
44
+ thf_max: 2
45
+ trf_min: 0.0
46
+ trf_max: 1.0
47
+ multiplicity_j1_min: 0.0
48
+ multiplicity_j1_max: 15
49
+ multiplicity_j2_min: 0.0
50
+ multiplicity_j2_max: 15
51
  response_generator:
52
  _target_: shimnet.generators.ResponseGenerator
53
  response_function_library:
shimnet/generators.py CHANGED
@@ -12,27 +12,37 @@ def random_value(min_value, max_value, generator=None):
12
  def random_loguniform(min_value, max_value, generator=None):
13
  return (min_value * torch.exp(torch.rand(1, generator=generator) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
14
 
15
- def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq:torch.Tensor):
16
- # extract parameters
17
- tff_lin = peaks_parameters["tff_lin"]
18
- twf_lin = peaks_parameters["twf_lin"]
19
- thf_lin = peaks_parameters["thf_lin"]
20
- trf_lin = peaks_parameters["trf_lin"]
21
-
22
- lwf_lin = twf_lin
23
- lhf_lin = thf_lin * (1. - trf_lin)
24
- gwf_lin = twf_lin
25
- gdf_lin = gwf_lin / torch.tensor(2.).log().mul(2.).sqrt()
26
- ghf_lin = thf_lin * trf_lin
27
- # calculate Lorenz peaks contriubutions
28
- lsf_linfrq = lwf_lin[:, None] ** 2 / (lwf_lin[:, None] ** 2 + (frq_frq - tff_lin[:, None]) ** 2) * lhf_lin[:, None]
29
- # calculate Gaussian peaks contriubutions
30
- gsf_linfrq = torch.exp(-(frq_frq - tff_lin[:, None]) ** 2 / gdf_lin[:, None] ** 2 / 2.) * ghf_lin[:, None]
31
- tsf_linfrq = lsf_linfrq + gsf_linfrq
32
- # sum peaks contriubutions
33
- tsf_frq = tsf_linfrq.sum(0, keepdim = True)
34
- return tsf_frq
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  pascal_triangle = [(1,), (1,1), (1,2,1), (1,3,3,1), (1,4,6,4,1), (1,5,10,10,5,1), (1,6,15,20,15,6,1), (1,7, 21,35,35,21,7,1)]
38
  normalized_pascal_triangle = [torch.tensor(x)/sum(x) for x in pascal_triangle]
@@ -65,8 +75,8 @@ def generate_multiplet_parameters(multiplicity, tff_lin, thf_lin, twf_lin, trf_l
65
  def value_to_index(values, table):
66
  span = table[-1] - table[0]
67
  indices = ((values - table[0])/span * (len(table)-1)) #.round().type(torch.int64)
68
- return indices
69
-
70
  def generate_theoretical_spectrum(
71
  number_of_signals_min, number_of_signals_max,
72
  spectrum_width_min, spectrum_width_max,
@@ -328,20 +338,44 @@ class RngGetter:
328
  rng = self.rng
329
  return rng
330
 
331
- class TheoreticalMultipletSpectraGenerator:
332
- def __init__(self, atom_groups_data_file=None, pixels=2048, frq_step=11160.7142857 / 32768,
333
- number_of_signals_min=1, number_of_signals_max=8,
334
- spectrum_width_min=0.2, spectrum_width_max=1, relative_width_min=1, relative_width_max=2,
335
- relative_height_min=1, relative_height_max=1, relative_frequency_min=-0.4, relative_frequency_max=0.4,
336
- thf_min=1/16, thf_max=16, trf_min=0, trf_max=1, multiplicity_j1_min=0, multiplicity_j1_max=15,
337
- multiplicity_j2_min=0, multiplicity_j2_max=15, seed=42, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  # Read atom_groups_data from file
339
  if atom_groups_data_file is None:
340
  self.atom_groups_data = np.ones((1,3), dtype=int)
341
  else:
342
  self.atom_groups_data = np.atleast_2d(np.loadtxt(atom_groups_data_file, usecols=(1,2,3), dtype=int))
343
- self.pixels = pixels
344
- self.frq_step = frq_step
 
345
  self.number_of_signals_min = number_of_signals_min
346
  self.number_of_signals_max = number_of_signals_max
347
  self.spectrum_width_min = spectrum_width_min
@@ -350,8 +384,6 @@ class TheoreticalMultipletSpectraGenerator:
350
  self.relative_width_max = relative_width_max
351
  self.relative_height_min = relative_height_min
352
  self.relative_height_max = relative_height_max
353
- self.relative_frequency_min = relative_frequency_min
354
- self.relative_frequency_max = relative_frequency_max
355
  self.thf_min = thf_min
356
  self.thf_max = thf_max
357
  self.trf_min = trf_min
@@ -360,36 +392,135 @@ class TheoreticalMultipletSpectraGenerator:
360
  self.multiplicity_j1_max = multiplicity_j1_max
361
  self.multiplicity_j2_min = multiplicity_j2_min
362
  self.multiplicity_j2_max = multiplicity_j2_max
363
- self.frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
364
- self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
 
 
 
 
365
 
366
  def __call__(self, seed=None):
 
 
 
 
 
 
 
 
 
367
  rng = self.rng_getter.get_rng(seed=seed)
368
-
369
- spectrum, spectrum_data = generate_theoretical_spectrum(
370
- number_of_signals_min=self.number_of_signals_min,
371
- number_of_signals_max=self.number_of_signals_max,
372
- spectrum_width_min=self.spectrum_width_min,
373
- spectrum_width_max=self.spectrum_width_max,
374
- relative_width_min=self.relative_width_min,
375
- relative_width_max=self.relative_width_max,
376
- tff_min=self.relative_frequency_min * self.pixels * self.frq_step,
377
- tff_max=self.relative_frequency_max * self.pixels * self.frq_step,
378
- thf_min=self.thf_min,
379
- thf_max=self.thf_max,
380
- trf_min=self.trf_min,
381
- trf_max=self.trf_max,
382
- relative_height_min=self.relative_height_min,
383
- relative_height_max=self.relative_height_max,
384
- multiplicity_j1_min=self.multiplicity_j1_min,
385
- multiplicity_j1_max=self.multiplicity_j1_max,
386
- multiplicity_j2_min=self.multiplicity_j2_min,
387
- multiplicity_j2_max=self.multiplicity_j2_max,
388
- atom_groups_data=self.atom_groups_data,
389
- frq_frq=self.frq_frq,
390
  generator=rng
391
  )
392
- return spectrum, {"spectrum_data": spectrum_data, "frq_frq": self.frq_frq}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
 
394
  class ResponseGenerator:
395
  def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
 
12
  def random_loguniform(min_value, max_value, generator=None):
13
  return (min_value * torch.exp(torch.rand(1, generator=generator) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
14
 
15
+ def spectrum_from_peaks_data(peaks_parameters: dict | list, frq_frq:torch.Tensor, relative_frequency=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ if isinstance(peaks_parameters, dict):
18
+ peaks_parameters = [peaks_parameters]
19
+
20
+ spectrum = torch.zeros((1, frq_frq.shape[0]))
21
+ for peak_params in peaks_parameters:
22
+ # extract parameters
23
+ if relative_frequency:
24
+ tff_lin = frq_frq[0] + peak_params["tff_relative"]*(frq_frq[1]-frq_frq[0])
25
+ else:
26
+ tff_lin = peak_params["tff_lin"]
27
+ twf_lin = peak_params["twf_lin"]
28
+ thf_lin = peak_params["thf_lin"]
29
+ trf_lin = peak_params["trf_lin"]
30
+
31
+ lwf_lin = twf_lin
32
+ lhf_lin = thf_lin * (1. - trf_lin)
33
+ gwf_lin = twf_lin
34
+ gdf_lin = gwf_lin / torch.tensor(2.).log().mul(2.).sqrt()
35
+ ghf_lin = thf_lin * trf_lin
36
+ # calculate Lorenz peaks contriubutions
37
+ lsf_linfrq = lwf_lin[:, None] ** 2 / (lwf_lin[:, None] ** 2 + (frq_frq - tff_lin[:, None]) ** 2) * lhf_lin[:, None]
38
+ # calculate Gaussian peaks contriubutions
39
+ gsf_linfrq = torch.exp(-(frq_frq - tff_lin[:, None]) ** 2 / gdf_lin[:, None] ** 2 / 2.) * ghf_lin[:, None]
40
+ tsf_linfrq = lsf_linfrq + gsf_linfrq
41
+ # sum peaks contriubutions
42
+ spectrum += tsf_linfrq.sum(0, keepdim = True)
43
+ return spectrum
44
+
45
+ calculate_theoretical_spectrum = spectrum_from_peaks_data # Alias for backward compatibility
46
 
47
  pascal_triangle = [(1,), (1,1), (1,2,1), (1,3,3,1), (1,4,6,4,1), (1,5,10,10,5,1), (1,6,15,20,15,6,1), (1,7, 21,35,35,21,7,1)]
48
  normalized_pascal_triangle = [torch.tensor(x)/sum(x) for x in pascal_triangle]
 
75
  def value_to_index(values, table):
76
  span = table[-1] - table[0]
77
  indices = ((values - table[0])/span * (len(table)-1)) #.round().type(torch.int64)
78
+ return indices
79
+
80
  def generate_theoretical_spectrum(
81
  number_of_signals_min, number_of_signals_max,
82
  spectrum_width_min, spectrum_width_max,
 
338
  rng = self.rng
339
  return rng
340
 
341
+
342
+ class PeaksParameterDataGenerator:
343
+ """
344
+ Generates peak parameter data for NMR multiplets.
345
+
346
+ This class is responsible for generating the parameters that describe individual peaks
347
+ in an NMR spectrum (frequencies, heights, widths, Gaussian/Lorentzian ratio).
348
+ """
349
+ def __init__(self,
350
+ tff_min=None, #may be assigned after initialization
351
+ tff_max=None, #may be assigned after initialization
352
+ atom_groups_data_file=None,
353
+ number_of_signals_min=1,
354
+ number_of_signals_max=8,
355
+ spectrum_width_min=0.2,
356
+ spectrum_width_max=1,
357
+ relative_width_min=1,
358
+ relative_width_max=2,
359
+ relative_height_min=1,
360
+ relative_height_max=1,
361
+ thf_min=1/16,
362
+ thf_max=16,
363
+ trf_min=0,
364
+ trf_max=1,
365
+ multiplicity_j1_min=0,
366
+ multiplicity_j1_max=15,
367
+ multiplicity_j2_min=0,
368
+ multiplicity_j2_max=15,
369
+ seed=42
370
+ ):
371
  # Read atom_groups_data from file
372
  if atom_groups_data_file is None:
373
  self.atom_groups_data = np.ones((1,3), dtype=int)
374
  else:
375
  self.atom_groups_data = np.atleast_2d(np.loadtxt(atom_groups_data_file, usecols=(1,2,3), dtype=int))
376
+
377
+ self.tff_min = tff_min
378
+ self.tff_max = tff_max
379
  self.number_of_signals_min = number_of_signals_min
380
  self.number_of_signals_max = number_of_signals_max
381
  self.spectrum_width_min = spectrum_width_min
 
384
  self.relative_width_max = relative_width_max
385
  self.relative_height_min = relative_height_min
386
  self.relative_height_max = relative_height_max
 
 
387
  self.thf_min = thf_min
388
  self.thf_max = thf_max
389
  self.trf_min = trf_min
 
392
  self.multiplicity_j1_max = multiplicity_j1_max
393
  self.multiplicity_j2_min = multiplicity_j2_min
394
  self.multiplicity_j2_max = multiplicity_j2_max
395
+
396
+ self.rng_getter = RngGetter(seed=seed)
397
+
398
+ def set_tff_range(self, tff_min, tff_max):
399
+ self.tff_min = tff_min
400
+ self.tff_max = tff_max
401
 
402
  def __call__(self, seed=None):
403
+ """
404
+ Generate peak parameters data.
405
+
406
+ Args:
407
+ seed: Optional seed for reproducibility
408
+
409
+ Returns:
410
+ List of dicts containing peak parameters (without tff_relative)
411
+ """
412
  rng = self.rng_getter.get_rng(seed=seed)
413
+
414
+ number_of_signals = torch.randint(
415
+ self.number_of_signals_min,
416
+ self.number_of_signals_max + 1,
417
+ [],
418
+ generator=rng
419
+ )
420
+ atom_group_indices = torch.randint(
421
+ 0,
422
+ len(self.atom_groups_data),
423
+ [number_of_signals],
424
+ generator=rng
425
+ )
426
+ width_spectrum = random_loguniform(
427
+ self.spectrum_width_min,
428
+ self.spectrum_width_max,
 
 
 
 
 
 
429
  generator=rng
430
  )
431
+ height_spectrum = random_loguniform(
432
+ self.thf_min,
433
+ self.thf_max,
434
+ generator=rng
435
+ )
436
+
437
+ peaks_parameters_data = []
438
+ for atom_group_index in atom_group_indices:
439
+ relative_intensity, multiplicity1, multiplicity2 = self.atom_groups_data[atom_group_index]
440
+ position = random_value(self.tff_min, self.tff_max, generator=rng)
441
+ j1 = random_value(self.multiplicity_j1_min, self.multiplicity_j1_max, generator=rng)
442
+ j2 = random_value(self.multiplicity_j2_min, self.multiplicity_j2_max, generator=rng)
443
+ width = width_spectrum * random_loguniform(
444
+ self.relative_width_min,
445
+ self.relative_width_max,
446
+ generator=rng
447
+ )
448
+ height = height_spectrum * relative_intensity * random_loguniform(
449
+ self.relative_height_min,
450
+ self.relative_height_max,
451
+ generator=rng
452
+ )
453
+ gaussian_contribution = random_value(self.trf_min, self.trf_max, generator=rng)
454
+
455
+ peak_parameters = generate_multiplet_parameters(
456
+ multiplicity=(multiplicity1, multiplicity2),
457
+ tff_lin=position,
458
+ thf_lin=height,
459
+ twf_lin=width,
460
+ trf_lin=gaussian_contribution,
461
+ j1=j1,
462
+ j2=j2
463
+ )
464
+ peaks_parameters_data.append(peak_parameters)
465
+
466
+ return peaks_parameters_data
467
+
468
+ class TheoreticalMultipletSpectraGenerator:
469
+ """
470
+ Generates theoretical NMR multiplet spectra.
471
+
472
+ This class combines peak parameter generation with spectrum calculation.
473
+ It can accept either a PeaksParameterDataGenerator instance or parameters to create one.
474
+ """
475
+ def __init__(self,
476
+ peaks_parameter_generator,
477
+ pixels=2048,
478
+ frq_step=11160.7142857 / 32768,
479
+ relative_frequency_min=-0.4,
480
+ relative_frequency_max=0.4,
481
+ include_tff_relative=False,
482
+ seed=42
483
+ ):
484
+
485
+ # Spectrum-level parameters
486
+ self.pixels = pixels
487
+ self.frq_step = frq_step
488
+ self.relative_frequency_min = relative_frequency_min
489
+ self.relative_frequency_max = relative_frequency_max
490
+ self.include_tff_relative = include_tff_relative
491
+ self.frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
492
+
493
+ self.peaks_parameter_generator = peaks_parameter_generator
494
+ self.peaks_parameter_generator.set_tff_range(
495
+ tff_min=relative_frequency_min * pixels * frq_step,
496
+ tff_max=relative_frequency_max * pixels * frq_step
497
+ )
498
+
499
+ # self.rng_getter = RngGetter(seed=seed) # self.rng_getter.get_rng(seed=seed) to get random generator
500
+
501
+ def __call__(self, seed=None):
502
+ """
503
+ Generate a theoretical spectrum.
504
+
505
+ Args:
506
+ seed: Optional seed for reproducibility
507
+
508
+ Returns:
509
+ Tuple of (spectrum, dict with spectrum_data and frq_frq)
510
+ """
511
+ # Generate peak parameters (peaks_parameter_generator has its own RngGetter)
512
+ peaks_parameters_data = self.peaks_parameter_generator(seed=seed)
513
+
514
+
515
+ # Add tff_relative if requested
516
+ if self.include_tff_relative:
517
+ for peak_params in peaks_parameters_data:
518
+ peak_params["tff_relative"] = value_to_index(peak_params["tff_lin"], self.frq_frq)
519
+
520
+ # Create spectrum from peaks
521
+ spectrum = spectrum_from_peaks_data(peaks_parameters_data, self.frq_frq)
522
+
523
+ return spectrum, {"spectrum_data": peaks_parameters_data, "frq_frq": self.frq_frq}
524
 
525
  class ResponseGenerator:
526
  def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,