Marek Bukowicki committed on
Commit
2495192
·
1 Parent(s): 7544717

rewrite datapipe as modular

Browse files
Files changed (3) hide show
  1. configs/shimnet_600_modular.yaml +68 -0
  2. shimnet/generators.py +330 -19
  3. train.py +23 -15
configs/shimnet_600_modular.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: ShimNetWithSCRF
3
+ kwargs:
4
+ rensponse_length: 81
5
+ resnponse_head_dims:
6
+ - 128
7
+ training:
8
+ - batch_size: 64
9
+ learning_rate: 0.001
10
+ max_iters: 1600000
11
+ - batch_size: 512
12
+ learning_rate: 0.001
13
+ max_iters: 25600000
14
+ - batch_size: 512
15
+ learning_rate: 0.0005
16
+ max_iters: 12800000
17
+ losses_weights:
18
+ clean: 1.0
19
+ noised: 1.0
20
+ response: 1.0
21
+ data:
22
+ _target_: shimnet.generators.Generator
23
+ include_response_function: true
24
+ seed: null # null means random seed
25
+ batch_size: null # to be set in training script
26
+ clean_spectra_generator:
27
+ _target_: shimnet.generators.TheoreticalMultipletSpectraGenerator
28
+ atom_groups_data_file: data/multiplets_10000_parsed.txt
29
+ pixels: 2048
30
+ frq_step: ${metadata.frq_step}
31
+ number_of_signals_min: 2
32
+ number_of_signals_max: 5
33
+ spectrum_width_min: 0.2
34
+ spectrum_width_max: 1.0
35
+ relative_width_min: 1.0
36
+ relative_width_max: 2.0
37
+ relative_height_min: 0.5
38
+ relative_height_max: 4
39
+ relative_frequency_min: -0.4
40
+ relative_frequency_max: 0.4
41
+ thf_min: 0.5
42
+ thf_max: 2
43
+ trf_min: 0.0
44
+ trf_max: 1.0
45
+ multiplicity_j1_min: 0.0
46
+ multiplicity_j1_max: 15
47
+ multiplicity_j2_min: 0.0
48
+ multiplicity_j2_max: 15
49
+ response_generator:
50
+ _target_: shimnet.generators.ResponseGenerator
51
+ response_function_library:
52
+ _target_: shimnet.generators.ResponseLibrary
53
+ response_files:
54
+ - data/scrf_81_600MHz.pt
55
+ response_function_stretch_min: 1.0
56
+ response_function_stretch_max: 1.0
57
+ response_function_noise: 0.0
58
+ flip_response_function: false
59
+ noise_generator:
60
+ _target_: shimnet.generators.NoiseGenerator
61
+ spectrum_noise_min: 0.0
62
+ spectrum_noise_max: 0.015625
63
+ logging:
64
+ step: 1000000
65
+ num_plots: 32
66
+ metadata:
67
+ frq_step: 0.30048
68
+ spectrometer_frequency: 600.0
shimnet/generators.py CHANGED
@@ -1,13 +1,16 @@
1
  import numpy as np
2
  import torch
3
  import torchdata
 
 
 
4
  # from itertools import islice
5
 
6
- def random_value(min_value, max_value):
7
- return (min_value + torch.rand(1) * (max_value - min_value)).item()
8
 
9
- def random_loguniform(min_value, max_value):
10
- return (min_value * torch.exp(torch.rand(1) * (torch.log(torch.tensor(max_value)) - torch.log(torch.tensor(min_value))))).item()
11
 
12
  def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq:torch.Tensor):
13
  # extract parameters
@@ -75,23 +78,24 @@ def generate_theoretical_spectrum(
75
  multiplicity_j1_min, multiplicity_j1_max,
76
  multiplicity_j2_min, multiplicity_j2_max,
77
  atom_groups_data,
78
- frq_frq
 
79
  ):
80
- number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [])
81
- atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals])
82
- width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max)
83
- height_spectrum = random_loguniform(thf_min, thf_max)
84
 
85
  peak_parameters_data = []
86
  theoretical_spectrum = None
87
  for atom_group_index in atom_group_indices:
88
  relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
89
- position = random_value(tff_min, tff_max)
90
- j1 = random_value(multiplicity_j1_min, multiplicity_j1_max)
91
- j2 = random_value(multiplicity_j2_min, multiplicity_j2_max)
92
- width = width_spectrum*random_loguniform(relative_width_min, relative_width_max)
93
- height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max)
94
- gaussian_contribution = random_value(trf_min, trf_max)
95
 
96
  peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin= width, trf_lin= gaussian_contribution, j1=j1, j2=j2)
97
  peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
@@ -143,8 +147,8 @@ def theoretical_generator(
143
  )
144
 
145
  class ResponseLibrary:
146
- def __init__(self, reponse_files, normalize=True):
147
- self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in reponse_files]
148
  if normalize:
149
  self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
150
  lengths = [len(data) for data in self.data]
@@ -159,6 +163,10 @@ class ResponseLibrary:
159
 
160
  def __len__(self):
161
  return self.total_length
 
 
 
 
162
 
163
  def generator(
164
  theoretical_generator_params,
@@ -179,7 +187,7 @@ def generator(
179
  response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
180
  # stretch response function
181
  padding_size = (response_function.shape[-1] - 1)//2
182
- padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size) #torch.randint(round(padding_size*response_function_stretch_min), round(padding_size*response_function_stretch_max), [1]).item()
183
  response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
184
  response_function /= response_function.sum() # normalize sum of response function to 1
185
  # add noise to response function
@@ -277,4 +285,307 @@ def get_datapipe(
277
  pipe = pipe.batch(batch_size)
278
  pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)
279
 
280
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
  import torchdata
4
+ from concurrent.futures import ThreadPoolExecutor, as_completed
5
+ from abc import ABC, abstractmethod
6
+
7
  # from itertools import islice
8
 
9
def random_value(min_value, max_value, generator=None):
    """Draw one float uniformly from [min_value, max_value).

    An optional ``torch.Generator`` makes the draw reproducible.
    """
    span = max_value - min_value
    u = torch.rand(1, generator=generator)
    return (min_value + u * span).item()
11
 
12
def random_loguniform(min_value, max_value, generator=None):
    """Draw one float log-uniformly from [min_value, max_value).

    Both bounds must be positive. An optional ``torch.Generator`` makes
    the draw reproducible.
    """
    log_lo = torch.log(torch.tensor(min_value))
    log_hi = torch.log(torch.tensor(max_value))
    u = torch.rand(1, generator=generator)
    return (min_value * torch.exp(u * (log_hi - log_lo))).item()
14
 
15
  def calculate_theoretical_spectrum(peaks_parameters: dict, frq_frq:torch.Tensor):
16
  # extract parameters
 
78
  multiplicity_j1_min, multiplicity_j1_max,
79
  multiplicity_j2_min, multiplicity_j2_max,
80
  atom_groups_data,
81
+ frq_frq,
82
+ generator=None
83
  ):
84
+ number_of_signals = torch.randint(number_of_signals_min, number_of_signals_max+1, [], generator=generator)
85
+ atom_group_indices = torch.randint(0, len(atom_groups_data), [number_of_signals], generator=generator)
86
+ width_spectrum = random_loguniform(spectrum_width_min, spectrum_width_max, generator=generator)
87
+ height_spectrum = random_loguniform(thf_min, thf_max, generator=generator)
88
 
89
  peak_parameters_data = []
90
  theoretical_spectrum = None
91
  for atom_group_index in atom_group_indices:
92
  relative_intensity, multiplicity1, multiplicity2 = atom_groups_data[atom_group_index]
93
+ position = random_value(tff_min, tff_max, generator=generator)
94
+ j1 = random_value(multiplicity_j1_min, multiplicity_j1_max, generator=generator)
95
+ j2 = random_value(multiplicity_j2_min, multiplicity_j2_max, generator=generator)
96
+ width = width_spectrum*random_loguniform(relative_width_min, relative_width_max, generator=generator)
97
+ height = height_spectrum*relative_intensity*random_loguniform(relative_height_min, relative_height_max, generator=generator)
98
+ gaussian_contribution = random_value(trf_min, trf_max, generator=generator)
99
 
100
  peaks_parameters = generate_multiplet_parameters(multiplicity=(multiplicity1, multiplicity2), tff_lin=position, thf_lin=height, twf_lin= width, trf_lin= gaussian_contribution, j1=j1, j2=j2)
101
  peaks_parameters["tff_relative"] = value_to_index(peaks_parameters["tff_lin"], frq_frq)
 
147
  )
148
 
149
  class ResponseLibrary:
150
+ def __init__(self, response_files, normalize=True):
151
+ self.data = [torch.load(f, map_location='cpu', weights_only=True).flatten(0,-4) for f in response_files]
152
  if normalize:
153
  self.data = [data/torch.sum(data, dim=(-1,), keepdim=True) for data in self.data]
154
  lengths = [len(data) for data in self.data]
 
163
 
164
  def __len__(self):
165
  return self.total_length
166
+
167
+ @property
168
+ def max_response_length(self):
169
+ return max([data.shape[-1] for data in self.data])
170
 
171
  def generator(
172
  theoretical_generator_params,
 
187
  response_function = response_function_library[torch.randint(0, len(response_function_library), [1])][0]
188
  # stretch response function
189
  padding_size = (response_function.shape[-1] - 1)//2
190
+ padding_size = round(random_loguniform(response_function_stretch_min, response_function_stretch_max)*padding_size) #torch.randint(round(padding_size*response_function_stretch_min), round(paddingSize*response_function_stretch_max), [1]).item()
191
  response_function = torch.nn.functional.interpolate(response_function, size=2*padding_size+1, mode='linear')
192
  response_function /= response_function.sum() # normalize sum of response function to 1
193
  # add noise to response function
 
285
  pipe = pipe.batch(batch_size)
286
  pipe = pipe.collate(collate_fn=collate_with_spectrum_data if include_spectrum_data else None)
287
 
288
+ return pipe
289
+
290
+ # response_functions_files,
291
+ # atom_groups_data_file=None,
292
+ # batch_size=64,
293
+ # pixels=2048, frq_step=11160.7142857 / 32768,
294
+ # number_of_signals_min=1, number_of_signals_max=8,
295
+ # spectrum_width_min=0.2, spectrum_width_max=1,
296
+ # relative_width_min=1, relative_width_max=2,
297
+ # relative_height_min=1, relative_height_max=1,
298
+ # relative_frequency_min=-0.4, relative_frequency_max=0.4,
299
+ # thf_min=1/16, thf_max=16,
300
+ # trf_min=0, trf_max=1,
301
+ # multiplicity_j1_min=0, multiplicity_j1_max=15,
302
+ # multiplicity_j2_min=0, multiplicity_j2_max=15,
303
+ # response_function_stretch_min=0.5,
304
+ # response_function_stretch_max=2.0,
305
+ # response_function_noise=0.,
306
+ # spectrum_noise_min=0.,
307
+ # spectrum_noise_max=1/64,
308
+ # include_spectrum_data=False,
309
+ # include_peak_mask=False,
310
+ # include_response_function=False,
311
+ # flip_response_function=False
312
+
313
+
314
class RngGetter:
    """Factory for ``torch.Generator`` objects with optional fixed seeding."""

    def __init__(self, seed=42):
        # Shared instance generator: seeded deterministically when a seed
        # is given, otherwise from system entropy.
        self.rng = torch.Generator()
        if seed is None:
            self.rng.seed()
        else:
            self.rng.manual_seed(seed)

    def get_rng(self, seed=None):
        """Return a generator.

        With an explicit ``seed`` a fresh, deterministically seeded generator
        is built; otherwise the shared instance generator is returned.
        """
        if seed is None:
            return self.rng
        rng = torch.Generator()
        rng.manual_seed(seed)
        return rng
330
+
331
class TheoreticalMultipletSpectraGenerator:
    """Samples clean theoretical multiplet spectra.

    Atom-group statistics (relative intensity and two multiplicities) are
    read from a whitespace-separated text file (columns 1-3); every other
    argument bounds the random sampling ranges used on each call.
    """

    def __init__(self, atom_groups_data_file=None, pixels=2048, frq_step=11160.7142857 / 32768,
                 number_of_signals_min=1, number_of_signals_max=8,
                 spectrum_width_min=0.2, spectrum_width_max=1, relative_width_min=1, relative_width_max=2,
                 relative_height_min=1, relative_height_max=1, relative_frequency_min=-0.4, relative_frequency_max=0.4,
                 thf_min=1/16, thf_max=16, trf_min=0, trf_max=1, multiplicity_j1_min=0, multiplicity_j1_max=15,
                 multiplicity_j2_min=0, multiplicity_j2_max=15, seed=42, **kwargs):
        # Without a data file, fall back to a single all-ones atom group
        # (intensity 1, multiplicities 1 and 1).
        if atom_groups_data_file is None:
            self.atom_groups_data = np.ones((1, 3), dtype=int)
        else:
            self.atom_groups_data = np.atleast_2d(
                np.loadtxt(atom_groups_data_file, usecols=(1, 2, 3), dtype=int))
        self.pixels = pixels
        self.frq_step = frq_step
        # Sampling bounds, stored verbatim for use in __call__.
        self.number_of_signals_min = number_of_signals_min
        self.number_of_signals_max = number_of_signals_max
        self.spectrum_width_min = spectrum_width_min
        self.spectrum_width_max = spectrum_width_max
        self.relative_width_min = relative_width_min
        self.relative_width_max = relative_width_max
        self.relative_height_min = relative_height_min
        self.relative_height_max = relative_height_max
        self.relative_frequency_min = relative_frequency_min
        self.relative_frequency_max = relative_frequency_max
        self.thf_min = thf_min
        self.thf_max = thf_max
        self.trf_min = trf_min
        self.trf_max = trf_max
        self.multiplicity_j1_min = multiplicity_j1_min
        self.multiplicity_j1_max = multiplicity_j1_max
        self.multiplicity_j2_min = multiplicity_j2_min
        self.multiplicity_j2_max = multiplicity_j2_max
        # Frequency axis centered on zero, `pixels` samples wide.
        self.frq_frq = torch.arange(-pixels // 2, pixels // 2) * frq_step
        # Per-call RNGs come from here (see RngGetter.get_rng).
        self.rng_getter = RngGetter(seed=seed)

    def __call__(self, seed=None):
        """Sample one spectrum.

        Returns ``(spectrum, {"spectrum_data": ..., "frq_frq": ...})``,
        where the dict carries the sampled peak parameters and the axis.
        """
        rng = self.rng_getter.get_rng(seed=seed)
        # Peak positions are expressed as a fraction of the full axis span.
        axis_span = self.pixels * self.frq_step
        spectrum, spectrum_data = generate_theoretical_spectrum(
            number_of_signals_min=self.number_of_signals_min,
            number_of_signals_max=self.number_of_signals_max,
            spectrum_width_min=self.spectrum_width_min,
            spectrum_width_max=self.spectrum_width_max,
            relative_width_min=self.relative_width_min,
            relative_width_max=self.relative_width_max,
            tff_min=self.relative_frequency_min * axis_span,
            tff_max=self.relative_frequency_max * axis_span,
            thf_min=self.thf_min,
            thf_max=self.thf_max,
            trf_min=self.trf_min,
            trf_max=self.trf_max,
            relative_height_min=self.relative_height_min,
            relative_height_max=self.relative_height_max,
            multiplicity_j1_min=self.multiplicity_j1_min,
            multiplicity_j1_max=self.multiplicity_j1_max,
            multiplicity_j2_min=self.multiplicity_j2_min,
            multiplicity_j2_max=self.multiplicity_j2_max,
            atom_groups_data=self.atom_groups_data,
            frq_frq=self.frq_frq,
            generator=rng,
        )
        return spectrum, {"spectrum_data": spectrum_data, "frq_frq": self.frq_frq}
393
+
394
class ResponseGenerator:
    """Samples a response function from a library and augments it.

    Augmentations, in order: log-uniform stretch (resize to a new odd
    length), additive Gaussian noise, optional random left-right flip, and
    optional symmetric zero-padding to a fixed length. The response is
    normalized to unit sum before and after the noise step.
    """

    def __init__(self, response_function_library, response_function_stretch_min=1., response_function_stretch_max=1., pad_to=None,
                 response_function_noise=0.0, flip_response_function=False, seed=42):
        self.response_function_library = response_function_library
        self.response_function_stretch_min = response_function_stretch_min
        self.response_function_stretch_max = response_function_stretch_max
        self.pad_to = pad_to
        self.response_function_noise = response_function_noise
        self.flip_response_function = flip_response_function
        # Per-call RNGs come from here (see RngGetter.get_rng).
        self.rng_getter = RngGetter(seed=seed)

    def __call__(self, seed=None):
        """Sample one augmented response function tensor."""
        rng = self.rng_getter.get_rng(seed=seed)
        library = self.response_function_library
        pick = torch.randint(0, len(library), [1], generator=rng)
        response = library[pick][0]
        # Stretch: scale the half-width log-uniformly, keeping the length odd.
        half = (response.shape[-1] - 1) // 2
        half = round(random_loguniform(self.response_function_stretch_min,
                                       self.response_function_stretch_max,
                                       generator=rng) * half)
        response = torch.nn.functional.interpolate(response, size=2 * half + 1, mode='linear')
        response /= response.sum()  # normalize sum to 1
        # Perturb with Gaussian noise, then renormalize.
        response += torch.randn(response.shape, generator=rng) * self.response_function_noise
        response /= response.sum()
        # Optional mirror with probability 0.5.
        if self.flip_response_function and (torch.rand(1, generator=rng).item() < 0.5):
            response = response.flip(-1)
        # Optional symmetric zero-padding to a fixed total length.
        if self.pad_to is not None:
            left = (self.pad_to - response.shape[-1]) // 2
            right = self.pad_to - response.shape[-1] - left
            response = torch.nn.functional.pad(response, (left, right))
        return response
422
+
423
class NoiseGenerator:
    """Adds white Gaussian noise with a randomly drawn amplitude."""

    def __init__(self, spectrum_noise_min=0., spectrum_noise_max=1/64, seed=42):
        # Uniform range for the per-call noise amplitude.
        self.spectrum_noise_min = spectrum_noise_min
        self.spectrum_noise_max = spectrum_noise_max
        # Per-call RNGs come from here (see RngGetter.get_rng).
        self.rng_getter = RngGetter(seed=seed)

    def __call__(self, disturbed_spectrum, seed=None):
        """Return the input plus scaled standard-normal noise of the same shape."""
        rng = self.rng_getter.get_rng(seed=seed)
        noise = torch.randn(disturbed_spectrum.shape, generator=rng)
        amplitude = random_value(self.spectrum_noise_min, self.spectrum_noise_max, generator=rng)
        return disturbed_spectrum + noise * amplitude
432
+
433
class BaseGenerator(ABC):
    """Single-threaded base generator yielding collated batches forever.

    Single-threaded execution is typically faster for this workload:
    thread creation/synchronization overhead exceeds the per-element
    computation time, and GIL/allocator contention plus cache thrashing
    eat any parallel gain on such small per-thread workloads.
    """

    def __init__(self, batch_size=64, seed=None):
        self.batch_size = batch_size
        self.seed = seed

    def set_seed(self, seed):
        """Set (or clear, with None) the master seed used by __iter__."""
        self.seed = seed

    @abstractmethod
    def _generate_element(self, seed):
        """Produce one sample given a per-element seed (may be None)."""

    @abstractmethod
    def collate_fn(self, batch):
        """Combine a list of generated elements into one batch."""

    def __iter__(self):
        # Master RNG drives per-element seeds so batches are reproducible
        # when a seed is set; otherwise elements are unseeded.
        master = torch.Generator()
        if self.seed is not None:
            master.manual_seed(self.seed)
        else:
            master.seed()

        while True:
            if self.seed is None:
                element_seeds = [None] * self.batch_size
            else:
                element_seeds = [torch.randint(0, 2**31, (1,), generator=master).item()
                                 for _ in range(self.batch_size)]
            # Sequential generation, one element per seed.
            batch = [self._generate_element(s) for s in element_seeds]
            yield self.collate_fn(batch)
479
+
480
+
481
class BaseGeneratorMultithread(ABC):
    """Multithreaded base generator (backup option).

    Use only if profiling shows benefit for your specific use case
    (e.g., very large/slow generation functions, I/O-bound operations).
    """

    def __init__(self, batch_size=64, num_workers=4, seed=None, ordered_batch=False):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.ordered_batch = ordered_batch

    def set_seed(self, seed):
        """Set (or clear, with None) the master seed used by __iter__."""
        self.seed = seed

    def set_ordered_batch(self, ordered_batch):
        """Toggle deterministic result ordering (submission order) in batches."""
        self.ordered_batch = ordered_batch

    @abstractmethod
    def _generate_element(self, seed):
        """Produce one sample given a per-element seed (may be None)."""

    @abstractmethod
    def collate_fn(self, batch):
        """Combine a list of generated elements into one batch."""

    def __iter__(self):
        # Master RNG drives per-element seeds so batches are reproducible
        # when a seed is set; otherwise elements are unseeded.
        rng = torch.Generator()
        if self.seed is not None:
            rng.manual_seed(self.seed)
        else:
            rng.seed()

        # Create the pool ONCE and reuse it across batches. The previous
        # version constructed a fresh ThreadPoolExecutor inside the loop,
        # paying thread startup/teardown cost on every single batch.
        # (The pool is shut down when the generator is closed/collected,
        # via GeneratorExit unwinding this `with`.)
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            while True:
                if self.seed is not None:
                    element_seeds = [torch.randint(0, 2**31, (1,), generator=rng).item()
                                     for _ in range(self.batch_size)]
                else:
                    element_seeds = [None] * self.batch_size

                futures = [executor.submit(self._generate_element, s) for s in element_seeds]
                if self.ordered_batch:
                    # Maintain order: collect in submission order.
                    batch = [f.result() for f in futures]
                else:
                    # Faster: collect as workers finish (order may vary).
                    batch = [f.result() for f in as_completed(futures)]
                yield self.collate_fn(batch)
536
+
537
class Generator(BaseGenerator):
    """End-to-end sample generator: clean spectrum -> response convolution -> noise.

    Composes three sub-generators (clean spectra, response function, noise)
    and yields dict batches through BaseGenerator's iteration protocol.
    """

    # Keys whose values are ragged / non-tensor and are therefore collated
    # as plain per-sample lists instead of stacked tensors.
    _LIST_KEYS = ('theoretical_spectrum_data', 'frq_frq')

    def __init__(self, clean_spectra_generator, response_generator, noise_generator, batch_size=64,
                 include_spectrum_data=False, include_peak_mask=False, include_response_function=False, seed=None):
        super().__init__(batch_size=batch_size, seed=seed)
        self.clean_spectra_generator = clean_spectra_generator
        self.response_generator = response_generator
        self.noise_generator = noise_generator
        self.include_spectrum_data = include_spectrum_data
        self.include_peak_mask = include_peak_mask
        self.include_response_function = include_response_function

    def _generate_element(self, seed):
        """Generate one sample dict; `seed` (or None) drives all randomness."""
        # Derive independent child seeds so the three sub-generators each
        # get their own deterministic stream.
        if seed is not None:
            rng = torch.Generator()
            rng.manual_seed(seed)
            clean_seed = torch.randint(0, 2**31, (1,), generator=rng).item()
            response_seed = torch.randint(0, 2**31, (1,), generator=rng).item()
            noise_seed = torch.randint(0, 2**31, (1,), generator=rng).item()
        else:
            clean_seed = response_seed = noise_seed = None

        clean_spectrum, extra_clean_data = self.clean_spectra_generator(seed=clean_seed)
        response_function = self.response_generator(seed=response_seed)
        # 'same'-size convolution: pad by half the (odd) response length.
        padding_size = (response_function.shape[-1] - 1) // 2
        disturbed_spectrum = torch.nn.functional.conv1d(clean_spectrum, response_function, padding=padding_size)
        noised_spectrum = self.noise_generator(disturbed_spectrum, seed=noise_seed)
        out = {
            'theoretical_spectrum': clean_spectrum,
            'disturbed_spectrum': disturbed_spectrum,
            'noised_spectrum': noised_spectrum,
        }
        if self.include_spectrum_data:
            out['theoretical_spectrum_data'] = extra_clean_data['spectrum_data']
            out['frq_frq'] = extra_clean_data['frq_frq']
        if self.include_peak_mask and extra_clean_data is not None:
            # Binary mask marking the (rounded) peak positions on the axis.
            all_peaks_rel = torch.cat([peak_data["tff_relative"] for peak_data in extra_clean_data['spectrum_data']])
            peaks_indices = all_peaks_rel.round().type(torch.int64)
            out["peaks_mask"] = torch.scatter(torch.zeros(out["theoretical_spectrum"].shape[1]), 0, peaks_indices, 1.).unsqueeze(0)
        if self.include_response_function:
            out['response_function'] = response_function
        return out

    def collate_fn(self, batch):
        """Stack tensor entries across the batch; keep ragged entries as lists.

        Iterates keys in their insertion order. The previous version built a
        `set` of keys, which made the output dict's key order depend on hash
        seeding and thus vary between runs of a seeded pipeline.
        """
        out = {}
        for key in batch[0]:
            if key in self._LIST_KEYS:
                out[key] = [item[key] for item in batch]
            else:
                out[key] = torch.stack([item[key] for item in batch])
        return out
train.py CHANGED
@@ -6,7 +6,7 @@ from hydra.utils import instantiate
6
  import datetime
7
  import sys
8
  import matplotlib.pyplot as plt
9
-
10
 
11
  import matplotlib
12
  matplotlib.use('Agg')
@@ -15,8 +15,6 @@ matplotlib.use('Agg')
15
  import warnings
16
  warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
17
 
18
- # from shiment import models
19
- from shimnet.generators import get_datapipe
20
  from shimnet.predict_utils import Defaults as PredictDefaults
21
 
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -64,6 +62,19 @@ model_weights_file = run_dir / f'model.pt'
64
  optimizer = torch.optim.Adam(model.parameters())
65
  optimizer_weights_file = run_dir / f'optimizer.pt'
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def evaluate_model(stage=0, epoch=0):
68
  plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
69
  plot_dir.mkdir(exist_ok=True, parents=True)
@@ -72,11 +83,12 @@ def evaluate_model(stage=0, epoch=0):
72
  torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
73
 
74
  num_plots = config.logging.num_plots
75
- pipe = get_datapipe(
76
- **config.data,
77
- include_response_function=True,
78
- batch_size=num_plots
79
- )
 
80
  batch = next(iter(pipe))
81
 
82
  with torch.no_grad():
@@ -154,18 +166,14 @@ for i_stage, training_stage in enumerate(config.training):
154
  if optimizer_weights_file.is_file():
155
  optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
156
  optimizer.param_groups[0]['lr'] = training_stage.learning_rate
157
-
158
- pipe = get_datapipe(
159
- **config.data,
160
- include_response_function=True,
161
- batch_size=training_stage.batch_size
162
- )
163
 
164
  losses_history = []
165
  losses_history_limit = 64*100 // training_stage.batch_size
166
 
167
  last_evaluation = 0
168
- for epoch, batch in pipe.enumerate():
169
 
170
  # logging
171
  iters_done = epoch*training_stage.batch_size
 
6
  import datetime
7
  import sys
8
  import matplotlib.pyplot as plt
9
+ from copy import deepcopy
10
 
11
  import matplotlib
12
  matplotlib.use('Agg')
 
15
  import warnings
16
  warnings.filterwarnings("ignore", category=UserWarning, module='torchdata')
17
 
 
 
18
  from shimnet.predict_utils import Defaults as PredictDefaults
19
 
20
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
62
  optimizer = torch.optim.Adam(model.parameters())
63
  optimizer_weights_file = run_dir / f'optimizer.pt'
64
 
65
def get_datapipe(config_data, batch_size, alter_seed_by=None):
    """Instantiate the data pipeline from config with the given batch size.

    `alter_seed_by` offsets the configured seed so different training stages
    draw different data streams.
    """
    data_config = deepcopy(config_data)
    data_config.batch_size = batch_size

    # Optionally shift the seed per stage.
    if alter_seed_by is not None and "seed" in data_config:
        if data_config.seed is None:
            # NOTE(review): a null seed means "random"; substituting the
            # offset makes it deterministic (stage index as seed) — confirm
            # this is the intended behavior.
            data_config.seed = alter_seed_by
        else:
            data_config.seed = config_data.seed + alter_seed_by
    return instantiate(data_config)
77
+
78
  def evaluate_model(stage=0, epoch=0):
79
  plot_dir = run_dir / "plots" / f"{stage}_{epoch}"
80
  plot_dir.mkdir(exist_ok=True, parents=True)
 
83
  torch.save(optimizer.state_dict(), plot_dir / "optimizer.pt")
84
 
85
  num_plots = config.logging.num_plots
86
+ pipe = get_datapipe(config.data, batch_size=num_plots)
87
+ # if possible, set seed and ordered batch for reproducibility
88
+ if hasattr(pipe, 'set_seed'):
89
+ pipe.set_seed(42)
90
+ if hasattr(pipe, 'set_ordered_batch'):
91
+ pipe.set_ordered_batch(True)
92
  batch = next(iter(pipe))
93
 
94
  with torch.no_grad():
 
166
  if optimizer_weights_file.is_file():
167
  optimizer.load_state_dict(torch.load(optimizer_weights_file, weights_only=True))
168
  optimizer.param_groups[0]['lr'] = training_stage.learning_rate
169
+
170
+ pipe = get_datapipe(config.data, batch_size=training_stage.batch_size, alter_seed_by=i_stage)
 
 
 
 
171
 
172
  losses_history = []
173
  losses_history_limit = 64*100 // training_stage.batch_size
174
 
175
  last_evaluation = 0
176
+ for epoch, batch in enumerate(pipe):
177
 
178
  # logging
179
  iters_done = epoch*training_stage.batch_size