primepake committed on
Commit 92a99c9 · 1 Parent(s): 6378746

update training code

Files changed (1)
  1. dac-vae/train.py +1000 -0
dac-vae/train.py ADDED
@@ -0,0 +1,1000 @@
from comet_ml import Experiment  # imported before torch so Comet can hook auto-logging

import argparse
import os
import time
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, Union

import numpy as np
import soundfile as sf
import torch
import yaml

from model import Discriminator
from model import DACVAE as VAE
from loss import (
    GANLoss,
    L1Loss,
    MelSpectrogramLoss,
    MultiScaleSTFTLoss,
    kl_loss,
)
from torch import nn
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ConstantLR, LinearLR, SequentialLR
from torch.utils.data.distributed import DistributedSampler

from audiotools import AudioSignal
from audiotools.core import util
from audiotools.data import transforms
from audiotools.data.datasets import AudioDataset, AudioLoader, ConcatDataset
from audiotools.ml.decorators import Tracker, timer, when


def ddp_setup():
    print("Setting up DDP")
    init_process_group(backend="nccl", timeout=timedelta(seconds=7200))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def build_transform(
    augment_prob=1.0,
    preprocess=["Identity"],
    augment=["Identity"],
    postprocess=["Identity", "RescaleAudio", "ShiftPhase"],
):
    to_tfm = lambda l: [getattr(transforms, x)() for x in l]
    preprocess = transforms.Compose(*to_tfm(preprocess), name="preprocess")
    augment = transforms.Compose(*to_tfm(augment), name="augment", prob=augment_prob)
    postprocess = transforms.Compose(*to_tfm(postprocess), name="postprocess")
    return transforms.Compose(preprocess, augment, postprocess)


def build_dataset(sample_rate, folders=None, **kwargs):
    if folders is None:
        folders = {}
    datasets = []
    # Build the transform once, outside the loop, so it is defined even when
    # `folders` is empty and is shared consistently across sub-datasets
    transform = build_transform()
    for _, v in folders.items():
        loader = AudioLoader(sources=v)
        dataset = AudioDataset(
            loader, sample_rate, num_channels=2, transform=transform, **kwargs
        )
        datasets.append(dataset)
    dataset = ConcatDataset(datasets)
    dataset.transform = transform
    return dataset


@dataclass
class State:
    generator: DDP
    optimizer_g: Union[AdamW, Adam]
    scheduler_g: torch.optim.lr_scheduler._LRScheduler

    discriminator: DDP
    optimizer_d: Union[AdamW, Adam]
    scheduler_d: torch.optim.lr_scheduler._LRScheduler

    stft_loss: MultiScaleSTFTLoss
    mel_loss: MelSpectrogramLoss
    gan_loss: GANLoss
    waveform_loss: L1Loss

    train_dataset: AudioDataset
    val_dataset: AudioDataset

    tracker: Tracker
    lambdas: Dict[str, float]

    # ema: EMA  # add EMA to State


class ResumableDistributedSampler(DistributedSampler):  # pragma: no cover
    """Distributed sampler that can be resumed from a given start index."""

    def __init__(self, dataset, start_idx: int = 0, **kwargs):
        super().__init__(dataset, **kwargs)
        # Start index; lets an experiment resume at the sample where it stopped
        self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0

    def __iter__(self):
        for i, idx in enumerate(super().__iter__()):
            if i >= self.start_idx:
                yield idx
        self.start_idx = 0  # reset so the next epoch starts from the beginning


def prepare_dataloader(
    dataset: AudioDataset,
    world_size: int,
    local_rank: int,
    start_idx: int = 0,
    shuffle: bool = True,
    **kwargs,
):
    # sampler = ResumableDistributedSampler(
    #     dataset,
    #     start_idx,
    #     num_replicas=world_size,
    #     rank=local_rank,
    #     shuffle=shuffle,
    # )

    sampler = None
    if start_idx > 0:
        # Simple resumable sampler: rotate the index list so the epoch
        # continues from start_idx and wraps around
        indices = list(range(start_idx, len(dataset))) + list(range(start_idx))
        sampler = torch.utils.data.SubsetRandomSampler(indices)

    # if "num_workers" in kwargs:
    #     kwargs["num_workers"] = max(kwargs["num_workers"] // world_size, 1)
    # kwargs["batch_size"] = max(kwargs["batch_size"] // world_size, 1)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        shuffle=(sampler is None and shuffle),  # only shuffle if no sampler is set
        num_workers=24,  # per-rank workers; no distributed sampler is used here
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=8,
        drop_last=True,
        **kwargs,
    )
    return dataloader


class Trainer:
    def __init__(self, args) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        torch.backends.cudnn.benchmark = True
        torch.cuda.set_device(self.local_rank)
        torch.cuda.empty_cache()

        with open(args.config_path, "r") as f:
            configs = yaml.safe_load(f)
        print("configs: ", configs)
        self.configs = configs

        self.gan_start_step = configs.get("gan_start_step", 0)

        self.num_iters = configs.get("num_iters", 250000)

        self.generator = VAE(**configs["vae"])

        self.discriminator = Discriminator(**configs["discriminator"])

        total_steps = configs["num_samples"] // configs["batch_size"]

        if configs["optimizer"]["scheduler"] == "linearlr":
            self.optimizer_g, self.scheduler_g = self.get_scheduler(
                self.generator, total_steps, configs["optimizer"]
            )
        else:
            self.optimizer_g, self.scheduler_g = self.get_constant_scheduler(
                self.generator, total_steps
            )

        if configs["disc_optimizer"]["scheduler"] == "constantlr":
            self.optimizer_d, self.scheduler_d = self.get_constant_scheduler(
                self.discriminator, total_steps
            )
        else:
            self.optimizer_d, self.scheduler_d = self.get_scheduler(
                self.discriminator, total_steps, configs["disc_optimizer"]
            )

        save_path = args.save_path
        os.makedirs(save_path, exist_ok=True)
        self.save_path = save_path

        if self.local_rank == 0:
            print(f"Rank {self.local_rank}: Initializing Comet.ml")
            experiment = Experiment(
                api_key=os.environ.get("COMET_API_KEY"),  # set COMET_API_KEY in your environment
                project_name="DACVAE",
                workspace=os.environ.get("COMET_WORKSPACE"),  # optional: set workspace
                # experiment_key=args.run_id,  # use run_id as the experiment key
            )
            experiment.log_parameters(configs)  # log the configuration
            writer = experiment
        else:
            writer = None

        print(f"Rank {self.local_rank}: Setting up tracker")
        self.tracker = Tracker(
            writer=writer, log_file=f"{save_path}/log.txt", rank=self.local_rank
        )
        self.val_idx = configs.get("val_idx", [0, 1, 2, 3, 4, 5, 6, 7])
        self.save_iters = configs.get("save_iters", 1000)
        self.sample_freq = configs.get("sample_freq", 10000)
        self.valid_freq = configs.get("valid_freq", 1000)

        self.tracker.print(self.generator)
        self.tracker.print(self.discriminator)

        self.waveform_loss = L1Loss()
        self.stft_loss = MultiScaleSTFTLoss(**configs["MultiScaleSTFTLoss"])
        self.mel_loss = MelSpectrogramLoss(**configs["MelSpectrogramLoss"])

        print(f"{self.global_rank} Loading datasets...")
        sample_rate = configs["vae"]["sample_rate"]
        train_folders = dict(configs.get("train_folders", {}))
        val_folders = dict(configs.get("val_folders", {}))
        self.batch_size = configs["batch_size"]
        self.val_batch_size = configs["val_batch_size"]
        self.num_workers = configs["num_workers"]

        print(f"Rank {self.local_rank}: Building train dataset")
        self.train_dataset = build_dataset(
            sample_rate, train_folders, **configs["train"]
        )
        print(f"Rank {self.local_rank}: Building val dataset")
        self.val_dataset = build_dataset(sample_rate, val_folders, **configs["val"])

        self.lambdas = configs["lambdas"]

        if args.resume:
            checkpoint_dir = os.path.join(args.save_path, args.tag)
            self.resume_from_checkpoint(checkpoint_dir)

        self.gan_loss = GANLoss(self.discriminator)
        print("self.tracker.step: ", self.tracker.step)

        self.generator = self.generator.to(self.local_rank)
        self.discriminator = self.discriminator.to(self.local_rank)

        self.generator = nn.SyncBatchNorm.convert_sync_batchnorm(self.generator)
        self.discriminator = nn.SyncBatchNorm.convert_sync_batchnorm(self.discriminator)

        # Wrap models with DDP
        self.generator = DDP(self.generator, device_ids=[self.local_rank])
        self.discriminator = DDP(self.discriminator, device_ids=[self.local_rank])

        # ema_decay = self.configs.get("ema_decay", 0.999)  # add to the config YAML or use default
        # self.ema = EMA(self.unwrap(self.generator), decay=ema_decay, device=self.local_rank)

        self.state = State(
            generator=self.generator,
            optimizer_g=self.optimizer_g,
            scheduler_g=self.scheduler_g,
            discriminator=self.discriminator,
            optimizer_d=self.optimizer_d,
            scheduler_d=self.scheduler_d,
            tracker=self.tracker,
            train_dataset=self.train_dataset,
            val_dataset=self.val_dataset,
            stft_loss=self.stft_loss.to(self.local_rank),
            mel_loss=self.mel_loss.to(self.local_rank),
            gan_loss=self.gan_loss.to(self.local_rank),
            waveform_loss=self.waveform_loss.to(self.local_rank),
            lambdas=self.lambdas,
            # ema=self.ema,  # add EMA to state
        )
        train_dataloader = prepare_dataloader(
            self.train_dataset,
            world_size=self.world_size,
            local_rank=self.local_rank,
            start_idx=self.state.tracker.step,  # resume from the tracked step
            batch_size=self.batch_size,
            collate_fn=self.state.train_dataset.collate,
        )

        self.len_train = len(train_dataloader)

        self.train_dataloader = self.get_infinite_loader(train_dataloader)

        if self.global_rank == 0:
            self.val_dataloader = prepare_dataloader(
                self.state.val_dataset,
                world_size=1,
                local_rank=0,
                start_idx=0,
                shuffle=False,
                batch_size=self.val_batch_size,
                collate_fn=self.state.val_dataset.collate,
            )

        self.seed = 0
        self.val_real_audio = []
        self.val_gen_audio = []
        self.initial_norm = configs.get("initial_norm", float("inf"))
        self.max_norm = configs.get("max_norm", float("inf"))
        self.initial_norm_d = configs.get("initial_norm_d", float("inf"))
        self.max_norm_d = configs.get("max_norm_d", float("inf"))

        self.init_logs_penalty = self.state.lambdas["logs_penalty"]
        self.init_lipschitz_penalty = self.state.lambdas["lipschitz_penalty"]
        self.kl_max_beta = self.state.lambdas["kl/loss"]
        self.hold_base_steps = configs.get("hold_base_steps", 200000)

    def get_scheduler(self, model, total_steps, configs):
        warmup_steps = configs.get("warmup_steps", 0)
        if configs["type"].lower() == "adamw":
            optimizer = AdamW(
                model.parameters(),
                lr=configs["lr"],
                weight_decay=configs["weight_decay"],
            )
        else:
            optimizer = Adam(
                model.parameters(),
                lr=configs["lr"],
                weight_decay=configs["weight_decay"],
            )

        # Warm up from near-zero to the base lr
        warmup = LinearLR(
            optimizer,
            start_factor=1e-9,
            end_factor=1.0,  # go up to the base lr
            total_iters=warmup_steps,
        )
        remaining_iters = total_steps - warmup_steps
        constant = ConstantLR(
            optimizer,
            factor=1.0,  # keep the learning rate constant at the base lr
            total_iters=remaining_iters,
        )

        scheduler = SequentialLR(
            optimizer, schedulers=[warmup, constant], milestones=[warmup_steps]
        )
        return optimizer, scheduler

    def get_constant_scheduler(self, model, total_steps):
        if self.configs["optimizer"]["type"].lower() == "adamw":
            optimizer = AdamW(
                model.parameters(),
                lr=self.configs["optimizer"]["lr"],
                weight_decay=self.configs["optimizer"]["weight_decay"],
            )
        else:
            optimizer = Adam(
                model.parameters(),
                lr=self.configs["optimizer"]["lr"],
                weight_decay=self.configs["optimizer"]["weight_decay"],
            )
        scheduler = ConstantLR(
            optimizer,
            factor=1.0,  # keep the learning rate constant
            total_iters=total_steps,
        )
        return optimizer, scheduler

    def get_infinite_loader(self, dataset):
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        print(f"Rank {rank}: Starting infinite loader")
        # Resumption is handled by the sampler built in prepare_dataloader
        iterator = iter(dataset)
        while True:
            try:
                batch = next(iterator)
                if batch is None:
                    print(f"Rank {rank}: Skipping None batch")
                    continue
                yield batch
            except StopIteration:
                iterator = iter(dataset)  # reset the iterator at the end of the dataset

    def log_grad_norms(self, output, norm_threshold=1.0):
        """
        Log gradient norms for key DACVAE components to aid debugging.
        Tracks pre-clipping norms for the encoder, decoder, and selected blocks.

        Args:
            output (dict): Dictionary to store gradient norm logs.
            norm_threshold (float): Only log norms above this threshold to reduce noise.
        """
        # Initialize dictionaries for norms
        submodule_norms = {
            "en_conv_post": 0.0,
            "de_conv_pre": 0.0,
            "encoder_initial_conv": 0.0,
            "encoder_final_conv": 0.0,
            "encoder_snake1d_alpha": 0.0,
            "decoder_initial_conv": 0.0,
            "decoder_final_conv": 0.0,
            "decoder_snake1d_alpha": 0.0,
        }
        norm_values = []  # for distributional statistics

        # Initialize norms for a few representative blocks (first and last)
        num_enc_blocks = len(self.state.generator.module.encoder_rates)
        num_dec_blocks = len(self.state.generator.module.decoder_rates)
        for i in [0, num_enc_blocks - 1]:  # first and last encoder blocks
            submodule_norms.update(
                {
                    f"encoder_block_{i}": 0.0,
                    f"encoder_block_{i}_snake1d": 0.0,
                    f"encoder_block_{i}_conv1d": 0.0,
                }
            )
        for i in [0, num_dec_blocks - 1]:  # first and last decoder blocks
            submodule_norms.update(
                {
                    f"decoder_block_{i}": 0.0,
                    f"decoder_block_{i}_snake1d": 0.0,
                    f"decoder_block_{i}_conv_transpose": 0.0,
                }
            )

        # Indices of the final conv layers in the encoder/decoder stacks
        enc_final_conv_idx = num_enc_blocks + 2
        dec_final_conv_idx = num_dec_blocks * 2 + 1

        # Iterate through parameters, accumulating squared norms per submodule
        for name, param in self.state.generator.named_parameters():
            if param.grad is not None:
                norm = param.grad.norm().item()
                norm_values.append(norm)

                # DACVAE projection layers
                if "en_conv_post" in name:
                    submodule_norms["en_conv_post"] += norm**2
                elif "de_conv_pre" in name:
                    submodule_norms["de_conv_pre"] += norm**2

                # Encoder components
                if "encoder.block.0" in name:
                    submodule_norms["encoder_initial_conv"] += norm**2
                elif f"encoder.block.{enc_final_conv_idx}" in name:
                    submodule_norms["encoder_final_conv"] += norm**2
                elif "encoder" in name and "alpha" in name:
                    submodule_norms["encoder_snake1d_alpha"] += norm**2
                for i in [0, num_enc_blocks - 1]:
                    block_idx = i + 1
                    if f"encoder.block.{block_idx}" in name:
                        submodule_norms[f"encoder_block_{i}"] += norm**2
                        if "block.3" in name:  # Snake1d
                            submodule_norms[f"encoder_block_{i}_snake1d"] += norm**2
                        elif "block.4" in name:  # WNConv1d
                            submodule_norms[f"encoder_block_{i}_conv1d"] += norm**2

                # Decoder components
                if "decoder.model.0" in name:
                    submodule_norms["decoder_initial_conv"] += norm**2
                elif f"decoder.model.{dec_final_conv_idx}" in name:
                    submodule_norms["decoder_final_conv"] += norm**2
                elif "decoder" in name and "alpha" in name:
                    submodule_norms["decoder_snake1d_alpha"] += norm**2
                for i in [0, num_dec_blocks - 1]:
                    block_idx = i * 2 + 1
                    if f"decoder.model.{block_idx}" in name:
                        submodule_norms[f"decoder_block_{i}"] += norm**2
                        if "block.0" in name:  # Snake1d
                            submodule_norms[f"decoder_block_{i}_snake1d"] += norm**2
                        elif "block.1" in name:  # WNConvTranspose1d
                            submodule_norms[f"decoder_block_{i}_conv_transpose"] += norm**2

        # Take the square root of the summed squares and log if above threshold
        for key in submodule_norms:
            norm = submodule_norms[key] ** 0.5
            if norm > norm_threshold:
                output[f"grad_norm/{key}"] = norm

        # Log pre-clipping norm statistics
        if norm_values:
            output["grad_norm/pre_clip_max"] = max(norm_values)
            output["grad_norm/pre_clip_mean"] = sum(norm_values) / len(norm_values)
            output["grad_norm/pre_clip_95th_percentile"] = (
                torch.tensor(norm_values).quantile(0.95).item()
            )

    def compute_lipschitz_penalty(self, lambda_lip=0.01):
        penalty = 0.0
        for name, param in self.state.generator.named_parameters():
            if (
                ("decoder" in name or "de_conv_pre" in name)
                and param.grad is not None
                and "weight" in name
            ):
                grad_norm = param.grad.norm(2)
                penalty += grad_norm**2
        return lambda_lip * penalty

    def compute_gradient_penalty(self, recons, z):
        # Compute gradients of the decoder output w.r.t. the latents
        grads = torch.autograd.grad(
            outputs=recons,
            inputs=z,
            grad_outputs=torch.ones_like(recons),
            create_graph=True,
            retain_graph=True,
        )[0]
        grad_norm = grads.norm(2, dim=[1, 2]).mean()
        return 0.1 * grad_norm  # weight for the penalty

    def cosine_decay_with_warmup(
        self,
        cur_step,
        base_value,
        total_steps,
        final_value,
        warmup_value=0.0,
        warmup_steps=0,
        hold_base_steps=0,
    ):
        """Cosine schedule with warmup, adapted from R3GAN."""
        # Ensure cur_step is a tensor
        cur_step = torch.tensor(cur_step, dtype=torch.float32)

        # Compute the decay term
        denom = float(total_steps - warmup_steps - hold_base_steps)
        if denom <= 0:
            raise ValueError(
                "total_steps must be greater than warmup_steps + hold_base_steps"
            )
        phase = torch.pi * (cur_step - warmup_steps - hold_base_steps) / denom
        decay = 0.5 * (1 + torch.cos(phase))

        # Compute the current value
        cur_value = base_value + (1 - decay) * (final_value - base_value)

        # Hold the base value for the first hold_base_steps
        if hold_base_steps > 0:
            cur_value = torch.where(
                cur_step > warmup_steps + hold_base_steps,
                cur_value,
                torch.tensor(base_value, dtype=torch.float32),
            )

        # Linear warmup from warmup_value to base_value
        if warmup_steps > 0:
            slope = (base_value - warmup_value) / warmup_steps
            warmup_v = slope * cur_step + warmup_value
            cur_value = torch.where(cur_step < warmup_steps, warmup_v, cur_value)

        # Clamp to final_value past total_steps
        cur_value = torch.where(
            cur_step > total_steps,
            torch.tensor(final_value, dtype=torch.float32),
            cur_value,
        )

        return cur_value.item()  # return as a plain float

    def smooth_increase(
        self,
        step: int,
        initial_beta: float = 0.01,
        final_beta: float = 0.0,
        total_steps: int = 50000,
    ) -> float:
        """Linearly interpolate beta from initial_beta to final_beta over total_steps."""
        progress = min(step / total_steps, 1.0)
        beta = initial_beta + progress * (final_beta - initial_beta)
        return beta

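
Aside: a standalone sketch (not part of train.py) of the cosine schedule above, handy for eyeballing its shape; the step counts and values below are illustrative only.

# Same math as cosine_decay_with_warmup, in plain Python.
import math

def cosine_decay(cur_step, base_value, total_steps, final_value,
                 warmup_value=0.0, warmup_steps=0, hold_base_steps=0):
    if warmup_steps > 0 and cur_step < warmup_steps:
        # Linear warmup from warmup_value to base_value
        return warmup_value + (base_value - warmup_value) * cur_step / warmup_steps
    if cur_step <= warmup_steps + hold_base_steps:
        return base_value  # hold the base value
    if cur_step > total_steps:
        return final_value  # clamp after the schedule ends
    denom = total_steps - warmup_steps - hold_base_steps
    phase = math.pi * (cur_step - warmup_steps - hold_base_steps) / denom
    decay = 0.5 * (1 + math.cos(phase))
    return base_value + (1 - decay) * (final_value - base_value)

for step in (0, 100_000, 200_000, 225_000, 250_000):
    print(step, cosine_decay(step, base_value=1.0, total_steps=250_000,
                             final_value=0.1, hold_base_steps=200_000))
# -> 1.0, 1.0, 1.0, 0.55, 0.1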
    @timer()
    def train_loop(self, batch):
        print(f"Rank {self.local_rank}: Starting train_loop")

        self.max_gen_norm = self.cosine_decay_with_warmup(
            cur_step=self.tracker.step,
            base_value=self.initial_norm,  # e.g., 100
            total_steps=self.num_iters,  # e.g., 250000
            final_value=self.max_norm,
            warmup_value=self.initial_norm,
            warmup_steps=0,
            hold_base_steps=self.hold_base_steps,
        )

        self.max_d_norm = self.cosine_decay_with_warmup(
            cur_step=self.tracker.step,
            base_value=self.initial_norm_d,  # e.g., 100
            total_steps=self.num_iters,  # e.g., 250000
            final_value=self.max_norm_d,
            warmup_value=self.initial_norm_d,
            warmup_steps=0,
            hold_base_steps=self.hold_base_steps,
        )

        self.state.generator.train()
        if self.tracker.step >= self.gan_start_step:
            self.state.discriminator.train()
            print(
                f"Rank {self.local_rank}: Discriminator training mode: {self.state.discriminator.training}"
            )
        output = {}
        timing_logs = {}

        output["max_gen_norm"] = self.max_gen_norm
        output["max_d_norm"] = self.max_d_norm

        train_loop_start = time.time()

        # Batch preparation
        batch_prepare_start = time.time()
        batch = util.prepare_batch(batch, self.local_rank)
        timing_logs["batch_prepare"] = time.time() - batch_prepare_start

        # Data transformation
        transform_start = time.time()
        with torch.no_grad():
            signal = self.train_dataset.transform(
                batch["signal"].clone(), **batch["transform_args"]
            )
            signal.audio_data = torch.clamp(signal.audio_data, -1.0, 1.0)
        timing_logs["transform"] = time.time() - transform_start

        # Generator forward
        gen_forward_start = time.time()
        out = self.state.generator(signal.audio_data, signal.sample_rate)
        recons = AudioSignal(out["audio"], signal.sample_rate)
        timing_logs["gen_forward"] = time.time() - gen_forward_start
        z, mu, logs = out["z"], out["mu"], out["logs"]
        z.requires_grad_(True)
        logs_reg = torch.mean(logs.abs())  # penalize large logs

        output["kl/loss"] = kl_loss(logs, mu)
        output["logs_penalty"] = logs_reg

        kl_beta = self.cosine_decay_with_warmup(
            cur_step=self.tracker.step,
            base_value=self.kl_max_beta,
            total_steps=self.num_iters,  # e.g., 250000
            final_value=0.1,
            warmup_value=self.initial_norm,  # inert, since warmup_steps=0
            warmup_steps=0,
            hold_base_steps=self.hold_base_steps,
        )

        output["kl/beta"] = kl_beta

        logs_penalty_weight = self.cosine_decay_with_warmup(
            cur_step=self.tracker.step,
            base_value=self.init_logs_penalty,  # initial weight for logs_penalty
            total_steps=self.num_iters,  # e.g., 250000
            final_value=self.init_logs_penalty * 0.01,  # 1% of the initial weight
            warmup_value=self.init_logs_penalty,
            warmup_steps=0,
            hold_base_steps=self.hold_base_steps,
        )
        lipschitz_penalty_weight = self.cosine_decay_with_warmup(
            cur_step=self.tracker.step,
            base_value=self.init_lipschitz_penalty,  # initial weight for lipschitz_penalty
            total_steps=self.num_iters,  # e.g., 250000
            final_value=self.init_lipschitz_penalty * 0.01,  # 1% of the initial weight
            warmup_value=self.init_lipschitz_penalty,
            warmup_steps=0,
            hold_base_steps=self.hold_base_steps,
        )

        # Discriminator loss
        if self.tracker.step >= self.gan_start_step:
            print(f"Rank {self.local_rank}: Discriminator loss")
            disc_loss_start = time.time()
            output["adv/disc_loss"] = self.state.gan_loss.discriminator_loss(
                recons, signal
            )
            timing_logs["disc_loss"] = time.time() - disc_loss_start

            # Discriminator backward
            disc_backward_start = time.time()
            self.state.optimizer_d.zero_grad(set_to_none=True)
            output["adv/disc_loss"].backward()
            output["other/grad_norm_d"] = torch.nn.utils.clip_grad_norm_(
                self.state.discriminator.parameters(), self.max_d_norm
            )
            self.state.optimizer_d.step()
            self.state.scheduler_d.step()
            timing_logs["disc_backward"] = time.time() - disc_backward_start

            # DDP synchronization for the discriminator
            disc_ddp_sync_start = time.time()
            # if torch.distributed.is_initialized():
            #     torch.distributed.barrier()
            timing_logs["disc_ddp_sync"] = time.time() - disc_ddp_sync_start
            (
                output["adv/gen_loss"],
                output["adv/feat_loss"],
            ) = self.state.gan_loss.generator_loss(recons, signal)

        # Generator losses
        gen_loss_start = time.time()
        output["stft/loss"] = self.state.stft_loss(recons, signal)

        output["mel/loss"] = self.state.mel_loss(recons, signal)
        output["waveform/loss"] = self.state.waveform_loss(recons, signal)

        output["lipschitz_penalty"] = self.compute_lipschitz_penalty(lambda_lip=0.01)
        output["grad_penalty"] = self.compute_gradient_penalty(recons.audio_data, z)

        loss_keys = [
            "stft/loss",
            "mel/loss",
            "waveform/loss",
            "kl/loss",
            "logs_penalty",
            "lipschitz_penalty",
            "grad_penalty",
        ]
        if self.tracker.step >= self.gan_start_step:
            loss_keys.extend(["adv/gen_loss", "adv/feat_loss"])

        loss_weights = {k: self.state.lambdas.get(k, 1.0) for k in loss_keys}
        loss_weights["kl/loss"] = kl_beta
        loss_weights["logs_penalty"] = logs_penalty_weight
        loss_weights["lipschitz_penalty"] = lipschitz_penalty_weight

        # Log the loss weights
        output.update({f"loss_weight/{k}": v for k, v in loss_weights.items()})

        output["loss"] = sum(
            loss_weights[k] * output[k] for k in loss_keys if k in output
        )
        timing_logs["gen_loss"] = time.time() - gen_loss_start

        # Generator backward
        print(f"Rank {self.local_rank}: Updating generator")
        gen_backward_start = time.time()
        self.state.optimizer_g.zero_grad(set_to_none=True)
        output["loss"].backward()

        encoder_grad_norm = torch.nn.utils.clip_grad_norm_(
            self.state.generator.module.encoder.parameters(), self.max_gen_norm
        )
        decoder_grad_norm = torch.nn.utils.clip_grad_norm_(
            self.state.generator.module.decoder.parameters(), self.max_gen_norm
        )

        if self.tracker.step % 2 == 0:  # log every 2 iterations
            self.log_grad_norms(output, norm_threshold=0.0)

        output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
            self.state.generator.parameters(), self.max_gen_norm
        )

        # Log gradient norms
        output["other/grad_norm_encoder"] = (
            encoder_grad_norm.item()
            if torch.is_tensor(encoder_grad_norm)
            else encoder_grad_norm
        )
        output["other/grad_norm_decoder"] = (
            decoder_grad_norm.item()
            if torch.is_tensor(decoder_grad_norm)
            else decoder_grad_norm
        )

        self.state.optimizer_g.step()
        self.state.scheduler_g.step()
        timing_logs["gen_backward"] = time.time() - gen_backward_start

        # self.state.ema.update()

        # DDP synchronization for the generator
        gen_ddp_sync_start = time.time()
        # if torch.distributed.is_initialized():
        #     torch.distributed.barrier()
        timing_logs["gen_ddp_sync"] = time.time() - gen_ddp_sync_start

        # Other metrics
        output["other/learning_rate"] = self.state.optimizer_g.param_groups[0]["lr"]
        output["other/batch_size"] = signal.batch_size * self.world_size

        # Total train_loop time
        timing_logs["total_train_loop"] = time.time() - train_loop_start
        output.update({f"time/{k}": v for k, v in timing_logs.items()})

        print(f"Rank {self.local_rank}: train_loop complete")
        return dict(sorted(output.items()))

    def checkpoint(self):
        step = self.state.tracker.step
        tags = ["latest"]
        if step % self.save_iters == 0:
            tags.append(f"{step // 1000}k")

        self.state.tracker.print(f"Saving checkpoint at step {step}")

        # Prepare everything for saving
        checkpoint = {
            "generator": self.unwrap(self.state.generator).state_dict(),
            "discriminator": self.unwrap(self.state.discriminator).state_dict(),
            "optimizer_g": self.state.optimizer_g.state_dict(),
            "optimizer_d": self.state.optimizer_d.state_dict(),
            "scheduler_g": self.state.scheduler_g.state_dict(),
            "scheduler_d": self.state.scheduler_d.state_dict(),
            "tracker": self.state.tracker.state_dict(),
            # "ema": self.state.ema.state_dict(),  # save EMA state
            "step": step,
            "config": self.configs,
            "metadata": {
                "logs": self.state.tracker.history,
                "step": step,
                "config": self.configs,
            },
        }

        # Save for each tag (latest, 120k, etc.). Folders are keyed by tag only,
        # so resume_from_checkpoint(save_path/tag) can find them later.
        for tag in tags:
            save_folder = f"{self.save_path}/{tag}"
            os.makedirs(save_folder, exist_ok=True)
            save_path = os.path.join(save_folder, "checkpoint.pt")
            torch.save(checkpoint, save_path)
            self.state.tracker.print(f"Checkpoint saved: {save_path}")

    def resume_from_checkpoint(self, load_folder):
        checkpoint_path = os.path.join(load_folder, "checkpoint.pt")
        assert os.path.exists(
            checkpoint_path
        ), f"Checkpoint {checkpoint_path} not found!"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Load model state dicts
        self.unwrap(self.generator).load_state_dict(checkpoint["generator"])
        self.unwrap(self.discriminator).load_state_dict(checkpoint["discriminator"])

        # Load optimizer state dicts, then move their tensors to this rank's device
        self.optimizer_g.load_state_dict(checkpoint["optimizer_g"])
        self.optimizer_d.load_state_dict(checkpoint["optimizer_d"])

        for state in self.optimizer_g.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.local_rank)
        for state in self.optimizer_d.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.local_rank)

        self.scheduler_g.load_state_dict(checkpoint["scheduler_g"])
        self.scheduler_d.load_state_dict(checkpoint["scheduler_d"])

        # Load EMA state
        # if "ema" in checkpoint:
        #     self.ema.load_state_dict(checkpoint["ema"])

        # Load tracker/logs/step
        self.tracker.load_state_dict(checkpoint["tracker"])
        self.tracker.step = checkpoint.get("step", 0)

        self.tracker.print(
            f"Checkpoint loaded from {checkpoint_path} at step {self.tracker.step}"
        )

    def unwrap(self, model):
        if hasattr(model, "module"):
            return model.module
        return model
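
Aside: a minimal sketch (not part of train.py) of reloading a checkpoint for inference. It assumes the default --save_path of "ckpts" and the "latest" tag; the generator weights are stored unwrapped (no DDP "module." prefix), matching checkpoint() above.

import torch
from model import DACVAE as VAE

ckpt = torch.load("ckpts/latest/checkpoint.pt", map_location="cpu")
vae = VAE(**ckpt["config"]["vae"])      # rebuild the model from the saved config
vae.load_state_dict(ckpt["generator"])  # state_dict was saved via unwrap()
vae.eval()
print("checkpoint step:", ckpt["step"])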

    @torch.no_grad()
    def save_samples(self, val_idx):
        print(f"Rank {self.local_rank}: Starting save_samples")
        self.state.tracker.print("Saving audio samples to Comet")
        self.state.generator.eval()

        # Apply EMA weights
        # self.state.ema.apply_shadow()

        samples = [self.val_dataset[idx] for idx in val_idx]
        batch = self.val_dataset.collate(samples)
        batch = util.prepare_batch(batch, self.local_rank)
        signal = self.val_dataset.transform(
            batch["signal"].clone(), **batch["transform_args"]
        )

        out = self.state.generator(signal.audio_data, signal.sample_rate)
        recons = AudioSignal(out["audio"], signal.sample_rate)

        # Restore original weights
        # self.state.ema.restore()

        audio_dict = {"recons": recons}
        # if self.state.tracker.step == 0:
        audio_dict["signal"] = signal

        for k, v in audio_dict.items():
            for nb in range(v.batch_size):
                audio_data = v[nb].cpu().audio_data
                if audio_data.dim() == 3:
                    audio_data = audio_data.squeeze(0)
                elif audio_data.dim() == 1:
                    audio_data = audio_data.unsqueeze(0)

                audio_data = audio_data.numpy().astype(np.float32)
                if audio_data.max() > 1.0 or audio_data.min() < -1.0:
                    audio_data /= np.abs(audio_data).max()

                sample_rate = int(v[nb].sample_rate)
                if sample_rate <= 0:
                    raise ValueError(f"Invalid sample rate: {sample_rate}")

                # Save audio to a temporary file, log it, then clean up
                temp_file = f"temp_audio_{k}_{nb}.wav"
                sf.write(temp_file, audio_data.T, sample_rate)

                self.state.tracker.writer.log_audio(
                    temp_file,
                    metadata={
                        "caption": f"{k} sample {nb}",
                        "sample_rate": sample_rate,
                        "step": self.state.tracker.step,
                    },
                    step=self.state.tracker.step,
                )

                os.remove(temp_file)

    def train(self):
        print(f"Rank {self.local_rank}: Starting train")
        util.seed(self.seed)

        max_iters = self.num_iters
        train_loop = self.tracker.log("train", "value", history=False)(
            self.tracker.track("train", max_iters, completed=self.state.tracker.step)(
                self.train_loop
            )
        )

        save_samples = when(lambda: self.local_rank == 0)(self.save_samples)
        checkpoint = when(lambda: self.global_rank == 0)(self.checkpoint)

        with self.tracker.live:
            for self.tracker.step, batch in enumerate(
                self.train_dataloader, start=self.state.tracker.step
            ):
                self.tracker.print(
                    f"Rank {self.global_rank}: Iteration {self.tracker.step}/{max_iters}"
                )
                output = train_loop(batch)

                if self.global_rank == 0:
                    for k, v in output.items():
                        value = v.item() if torch.is_tensor(v) else v
                        self.tracker.writer.log_metric(k, value, step=self.tracker.step)

                last_iter = self.tracker.step == max_iters - 1

                if self.tracker.step % self.sample_freq == 0 or last_iter:
                    # torch.distributed.barrier()
                    save_samples(self.val_idx)
                    checkpoint()

                if last_iter:
                    break


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Distributed DAC training")
    parser.add_argument(
        "--config_path",
        type=str,
        default="config.yml",
        help="Path to the config YAML",
    )
    parser.add_argument(
        "--run_id", type=str, required=True, help="Run ID for the experiment tracker"
    )
    parser.add_argument(
        "--resume", action="store_true", help="Resume training from a checkpoint"
    )
    parser.add_argument(
        "--load_weights", action="store_true", help="Load weights from a checkpoint"
    )
    parser.add_argument(
        "--save_path", type=str, default="ckpts", help="Path to save checkpoints"
    )
    parser.add_argument(
        "--tag", type=str, default="latest", help="Tag of the checkpoint to resume from"
    )
    args = parser.parse_args()

    ddp_setup()
    trainer = Trainer(args)
    trainer.train()
    destroy_process_group()
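
Aside: a hypothetical minimal config sketch, inferred from the keys train.py reads; every value below is a placeholder, and the vae / discriminator / loss sections must match the constructor signatures in model.py and loss.py. The shell command in the trailing comment assumes torchrun, which sets the LOCAL_RANK / RANK / WORLD_SIZE variables the script expects.

# Placeholder config writer; adjust all values to your setup.
import yaml

config = {
    "vae": {"sample_rate": 44100},  # plus the remaining DACVAE kwargs
    "discriminator": {},            # Discriminator kwargs
    "MultiScaleSTFTLoss": {},
    "MelSpectrogramLoss": {},
    "optimizer": {"type": "adamw", "lr": 1e-4, "weight_decay": 1e-2,
                  "scheduler": "linearlr", "warmup_steps": 1000},
    "disc_optimizer": {"type": "adamw", "lr": 1e-4, "weight_decay": 1e-2,
                       "scheduler": "constantlr"},
    "num_samples": 10_000_000,  # num_samples // batch_size = scheduler steps
    "num_iters": 250_000,
    "gan_start_step": 0,
    "hold_base_steps": 200_000,
    "batch_size": 16,
    "val_batch_size": 8,
    "num_workers": 24,
    "save_iters": 1000,
    "sample_freq": 10_000,
    "valid_freq": 1000,
    "initial_norm": 100.0, "max_norm": 1.0,
    "initial_norm_d": 100.0, "max_norm_d": 1.0,
    "lambdas": {"stft/loss": 1.0, "mel/loss": 15.0, "waveform/loss": 1.0,
                "kl/loss": 1.0, "logs_penalty": 0.1, "lipschitz_penalty": 0.01,
                "adv/gen_loss": 1.0, "adv/feat_loss": 2.0},
    "train_folders": {"music": ["/data/train/music"]},  # name -> list of sources
    "val_folders": {"music": ["/data/val/music"]},
    "train": {},  # extra AudioDataset kwargs for the train split
    "val": {},    # extra AudioDataset kwargs for the val split
}

with open("config.yml", "w") as f:
    yaml.safe_dump(config, f)

# Launch, one process per GPU (torchrun provides LOCAL_RANK/RANK/WORLD_SIZE):
#   COMET_API_KEY=... torchrun --nproc_per_node=8 train.py \
#       --config_path config.yml --run_id demo --save_path ckpts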