Localsong commited on
Commit
15b323d
·
verified ·
1 Parent(s): eeed815

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +508 -0
train.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import DataLoader
5
+ from torchvision.utils import make_grid, save_image
6
+ from tqdm import tqdm
7
+ from ddt_model import LocalSongModel
8
+ from transformers import get_cosine_schedule_with_warmup
9
+ from datasets import load_from_disk
10
+ from accelerate import Accelerator
11
+ import os
12
+ import argparse
13
+ from torch.utils.tensorboard import SummaryWriter
14
+ from datetime import datetime
15
+ from collections import deque
16
+ import torchaudio
17
+ import re
18
+ import sys
19
+ import math
20
+ from tag_embedder import TagEmbedder
21
+
22
+ # Import MusicDCAE
23
+ from acestep.music_dcae.music_dcae_pipeline import MusicDCAE
24
+
25
+ # Import Muon optimizer
26
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27
+ import timm.optim
28
+
29
+ import os
30
+
31
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
32
+
33
def save(model, optimizer, scheduler, global_step, accelerator):
    """Persist a training checkpoint (main process only) and prune old ones.

    Writes checkpoints/checkpoint_<step>.pth containing model, optimizer and
    (fix) scheduler state, then keeps only the 5 newest checkpoints on disk.
    """
    if accelerator.is_main_process:
        checkpoint_dir = "checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Unwrap so the saved state_dict has no DDP/compile wrapper prefixes.
        unwrapped_model = accelerator.unwrap_model(model)

        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")
        save_dict = {
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }
        # Fix: scheduler state was never saved, so a resumed run silently
        # restarted the cosine LR schedule from step 0.
        if scheduler is not None:
            save_dict['scheduler_state_dict'] = scheduler.state_dict()

        accelerator.save(save_dict, checkpoint_path)
        print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

        # Prune: keep the 5 most recent checkpoints, delete the rest.
        checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pth")],
                             key=lambda x: int(x.split("_")[1].split(".")[0]), reverse=True)

        for old_checkpoint in checkpoints[5:]:
            os.remove(os.path.join(checkpoint_dir, old_checkpoint))
            print(f"Removed old checkpoint: {old_checkpoint}")
56
+
57
+
58
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    """Load model/optimizer/scheduler state from checkpoint_path.

    Returns:
        The global step stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    unwrapped_model = accelerator.unwrap_model(model)
    # Strip the torch.compile wrapper prefix so compiled and eager
    # checkpoints are interchangeable.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    print("MISSING:", missing)
    print("UNEXPECTED:", unexpected)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer loaded")

    # Fix: scheduler state was previously discarded on resume, resetting the
    # LR schedule. Older checkpoints without the key are still accepted.
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler loaded")

    global_step = checkpoint['global_step']
    print(f"Resumed from step {global_step}")
    return global_step
74
+
75
def resume(model, optimizer, scheduler, accelerator):
    """Locate the newest checkpoint in ./checkpoints and load it.

    Returns the restored global step, or 0 when no checkpoint is available.
    """
    checkpoint_dir = "checkpoints"

    # Guard clause: nothing to resume from.
    if not os.path.exists(checkpoint_dir):
        if accelerator.is_main_process:
            print("Checkpoint directory not found. Starting from scratch.")
        return 0

    candidates = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pth")]
    if not candidates:
        if accelerator.is_main_process:
            print("No checkpoints found. Starting from scratch.")
        return 0

    # Pick the checkpoint with the highest embedded step number.
    newest = max(candidates, key=lambda name: int(name.split("_")[1].split(".")[0]))
    checkpoint_path = os.path.join(checkpoint_dir, newest)
    if accelerator.is_main_process:
        print(f"Resuming from checkpoint: {checkpoint_path}")

    return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)
94
+
95
class AudioVAE:
    """Convenience wrapper around MusicDCAE with per-channel latent whitening."""

    # Per-channel statistics of the DCAE latent space (mean / std).
    _MEAN = [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]
    _STD = [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]

    def __init__(self, device):
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device

        # Broadcastable (1, C, 1, 1) views for normalizing 4-D latents.
        self.latent_mean = torch.tensor(self._MEAN, device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(self._STD, device=device).view(1, -1, 1, 1)

    def encode(self, audio):
        """Encode a (B, 2, T) 48 kHz audio batch into normalized latents."""
        with torch.no_grad():
            batch_size = audio.shape[0]
            num_frames = audio.shape[2]
            lengths = torch.tensor([num_frames] * batch_size).to(self.device)
            raw_latents, _ = self.model.encode(audio, lengths, sr=48000)
            # Whiten: (latents - mean) / std
            return (raw_latents - self.latent_mean) / self.latent_std

    def decode(self, latents):
        """Decode normalized latents back into a batched audio tensor."""
        with torch.no_grad():
            # Undo whitening before handing off to the DCAE decoder.
            raw_latents = latents * self.latent_std + self.latent_mean
            _, audio_list = self.model.decode(raw_latents, sr=48000)
            # NOTE(review): stacking assumes every decoded clip has the same
            # length — appears to hold for fixed-width latents; verify upstream.
            return torch.stack(audio_list).to(self.device)
123
+
124
class RF:
    """Rectified-flow style training/sampling wrapper around a denoising model.

    The model is invoked as ``model(z_t, t, cond)``. Judging by how its output
    is converted to a velocity below (``(z_t - pred) / t``), it is expected to
    predict the clean sample x0 rather than the velocity directly — TODO
    confirm against the model implementation.
    """
    def __init__(self, model, time_sampling="sigmoid"):
        # model: callable(z_t, t, cond); see class docstring for the assumed contract.
        self.model = model
        # One of "sigmoid", "warped", "uniform" — selects the timestep distribution.
        self.time_sampling = time_sampling

    def sample_timesteps(self, batch, device):
        """Sample timesteps based on the configured strategy."""
        if self.time_sampling == "sigmoid":
            # Logit-normal: sigmoid of a standard normal, mass concentrated near t=0.5.
            return torch.sigmoid(torch.randn((batch,), device=device))
        elif self.time_sampling == "warped":
            # Resolution-dependent warp of a uniform draw; alpha scales with the
            # hard-coded token count pm = 128*16*16 relative to a 4096 baseline.
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.rand(batch, device=device)
            # Monotone map [0,1] -> [0,1] that shifts mass toward larger t for alpha > 1.
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        elif self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def forward(self, x, cond):
        """Compute the training loss for a batch of clean latents ``x`` under ``cond``.

        Returns a scalar MSE loss tensor.
        """
        b = x.size(0)

        t = self.sample_timesteps(b, x.device)

        # Reshape t to broadcast over all non-batch dimensions of x.
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        # Linear interpolation path: t=0 is clean data, t=1 is pure noise.
        zt = (1 - texp) * x + texp * z1

        x_pred = self.model(zt, t, cond)

        # Both sides share the 1/(t + 0.05) factor, so this is an x0-prediction
        # MSE reweighted to emphasize small t; the 0.05 offset guards against
        # blow-up as t -> 0.
        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        loss = F.mse_loss(target, v_pred)

        return loss

    def get_sampling_timesteps(self, steps, device):
        """Generate timesteps for sampling (descending from 1.0, excluding 0)."""
        if self.time_sampling == "uniform" or self.time_sampling == "sigmoid":
            return torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        elif self.time_sampling == "warped":
            # Same warp as in training so the integration grid matches the
            # timestep distribution the model was trained on.
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        """Euler-integrate from noise ``z`` at t=1 down to t=0.

        Applies classifier-free guidance with scale ``cfg`` when ``null_cond``
        is given. Returns the list of intermediate states (index -1 is final).
        """
        b = z.size(0)
        device = z.device
        latent_shape = [b, *([1] * len(z.shape[1:]))]

        timesteps = self.get_sampling_timesteps(sample_steps, device)
        images = [z]

        for idx in range(sample_steps):
            t_curr = timesteps[idx]
            # Last step integrates down to exactly t=0.
            t_next = timesteps[idx + 1] if idx + 1 < sample_steps else torch.tensor(0.0, device=device)
            dt = t_curr - t_next
            t = t_curr.expand(b)

            # Convert the model's x0 prediction into a velocity estimate.
            vc = self.model(z, t, cond)
            vc = (z - vc) / t_curr
            if null_cond is not None:
                # Classifier-free guidance: extrapolate from the unconditional velocity.
                vu = self.model(z, t, null_cond)
                vu = (z - vu) / t_curr
                vc = vu + cfg * (vc - vu)

            # Euler step toward t=0.
            z = z - dt * vc
            images.append(z)
        return images
196
+
197
def save_audio_samples(audio_batch, sample_rate, filename):
    """Write the first clip of ``audio_batch`` to audio_samples/<filename> as WAV."""
    out_dir = "audio_samples"
    os.makedirs(out_dir, exist_ok=True)

    # Only the first sample of the batch is saved; move it off the accelerator.
    waveform = audio_batch[0].cpu()  # (2, T) for stereo

    out_path = os.path.join(out_dir, filename)
    torchaudio.save(out_path, waveform, sample_rate)
    print(f"Saved audio sample: {out_path}")
208
+
209
def parse_args(argv=None):
    """Build and parse the command-line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (argparse's behavior for ``None``), and exists so the parser can be
            exercised without touching process arguments.

    Returns:
        argparse.Namespace with the parsed options.
    """

    def _str2bool(value):
        # argparse's type=bool treats ANY non-empty string as True, so
        # '--resume False' used to silently resolve to True. Parse common
        # falsy spellings explicitly instead.
        if isinstance(value, bool):
            return value
        return value.strip().lower() not in ('false', '0', 'no', 'n', 'off', '')

    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')

    # Model / data geometry
    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder (if None, uses --dim)')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')

    # Optimization
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')

    # Logging / checkpointing / sampling cadence
    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    # Fix: was type=bool, which made '--resume False' evaluate to True.
    parser.add_argument('--resume', type=_str2bool, default=True, help='Resume training from checkpoint')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')

    return parser.parse_args(argv)
237
+
238
def main():
    """Train LocalSongModel on cached audio latents with a rectified-flow loss.

    Handles dataset loading/filtering, tag conditioning, distributed setup via
    Accelerate, periodic audio sampling, TensorBoard logging and checkpointing.
    """
    args = parse_args()

    # bf16 mixed precision only on GPU; fall back to full precision on CPU.
    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")

    is_main_process = accelerator.is_main_process

    # TensorBoard writer lives only on the main process.
    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    # Filter out audio samples shorter than subsection_length (unless padding is enabled)
    if not args.pad_to_length:
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            # NOTE(review): threshold is 2x subsection_length, stricter than the
            # crop itself requires — presumably intentional; confirm.
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)

        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    else:
        if is_main_process:
            print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Latent normalization parameters (per-channel); must stay in sync with AudioVAE.
    latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]).view(1, -1, 1, 1)
    latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]).view(1, -1, 1, 1)

    # Initialize tag embedder for converting metadata to tag indices
    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    # Custom collate function to randomly sample subsections from variable-width audio latents
    def collate_fn(batch):
        """Crop (or pad) each latent to subsection_length, normalize, attach tags."""
        subsection_length = args.subsection_length
        # Fix: was hard-coded to False, silently ignoring the --pad_to_length flag.
        pad_to_length = args.pad_to_length

        sampled_latents = []
        album_names = []
        song_names = []
        tags = []  # List of tag lists for each sample

        for item in batch:
            latent = item['latents']
            if len(latent.shape) == 3:  # Add batch dimension if missing
                latent = latent.unsqueeze(0)

            # Get the width of the current latent
            _, _, _, width = latent.shape

            if width < subsection_length and pad_to_length:
                # Pad the latent to subsection_length with zeros on the right.
                pad_amount = subsection_length - width
                sampled_latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
            else:
                # Randomly sample a starting position; filtering guarantees
                # width >= subsection_length when padding is disabled.
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()
                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]

            sampled_latents.append(sampled_latent.squeeze(0))  # Remove batch dim for stacking
            album_name = item['album_name']
            song_name = item['song_name']
            album_names.append(album_name)
            song_names.append(song_name)

            sample_tags = tag_embedder.get_tags(album_name, song_name)
            tags.append(sample_tags)

        # Stack latents and normalize
        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std

        return {
            'latents': normalized_latents,
            'tags': tags,
            # Fix: metadata was collected but never returned, which made the
            # fixed-batch debug print below raise KeyError.
            'album_names': album_names,
            'song_names': song_names,
        }

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        # Fix: persistent_workers=True with num_workers=0 raises ValueError,
        # which broke the CPU fallback path.
        persistent_workers=torch.cuda.is_available() and args.num_workers > 0,
        num_workers=args.num_workers if torch.cuda.is_available() else 0,
        pin_memory=True,
        collate_fn=collate_fn
    )

    channels = args.channels

    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),  # Audio patch size (16 in height, 1 in width)
        num_classes=num_classes,  # Number of tag classes
        max_tags=8,  # Maximum number of tags per sample
    )

    vae = AudioVAE(accelerator.device)

    rf = RF(model, time_sampling=args.time_sampling)

    optimizer = timm.optim.Muon(model.parameters(), lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.epochs * len(dataloader))

    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )

    # RF must call the prepared (possibly wrapped/compiled) model.
    rf.model = model

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6}M")

    os.makedirs("audio_samples", exist_ok=True)
    num_samples = args.num_samples

    # Fixed inputs for periodic qualitative sampling (main process only).
    fixed_latents = None
    fixed_tags = None
    fixed_noise = None

    if is_main_process:
        data_iter = iter(dataloader)
        fixed_batch = next(data_iter)
        fixed_latents = fixed_batch["latents"][:num_samples]

        # Fix: 'album_names' is now returned by collate_fn (was a KeyError).
        print("Fixed ids:", fixed_batch["album_names"][:num_samples])

        # Fix: fixed_tags was previously initialized to [] and never filled,
        # so periodic sampling ran with no conditioning at all.
        fixed_tags = fixed_batch["tags"][:num_samples]

        # Create reverse mapping from tag indices to strings
        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}

        # Print string labels for fixed tags
        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, f"<unknown:{idx}>") for idx in tag_list]
            print(f"  Sample {i}: {labels}")

        # Create noise with same channel/height layout as the fixed latents
        B, C, H, W = fixed_latents.shape
        fixed_noise = torch.randn(num_samples, C, H, args.subsection_length, device=accelerator.device)

        fixed_latents = fixed_latents.to(accelerator.device)

    if is_main_process:
        print("Begin training")

    mse_loss_window = deque(maxlen=100)  # running window for smoothed loss logging
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs):

        pbar = tqdm(dataloader) if is_main_process else dataloader
        for batch in pbar:
            x = batch["latents"]

            # Get tags from batch
            tags = batch["tags"]

            # Apply classifier-free guidance dropout (10% chance to drop all tags)
            dropout_tags = []
            for tag_list in tags:
                if torch.rand(1).item() < 0.1:
                    # Replace with empty list (will be padded to [0] in embed_condition)
                    dropout_tags.append([])
                else:
                    dropout_tags.append(tag_list)

            # Tags will be embedded inside the model's forward pass
            c = dropout_tags

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)

                loss = mse_loss

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            if is_main_process:

                mse_loss_window.append(mse_loss.item())

                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)

                if isinstance(pbar, tqdm):
                    pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": optimizer.param_groups[0]['lr']})

                if writer is not None:
                    writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()

                with torch.no_grad():
                    # Conditional sampling with the fixed tags; unconditional
                    # branch uses empty tag lists for classifier-free guidance.
                    cond = fixed_tags
                    null_cond = [[] for _ in range(len(cond))]

                    sampled_latents = rf.sample(fixed_noise, cond, null_cond)[-1]

                    # Decode latents to audio
                    try:
                        sampled_audio = vae.decode(sampled_latents)

                        # Save the first few generated samples
                        for i in range(min(8, sampled_audio.shape[0])):
                            save_audio_samples(
                                sampled_audio[i:i+1],
                                48000,
                                f"sample_{global_step}_generated_{i}.wav"
                            )

                        # On the first sampling pass, also save originals for comparison
                        if global_step == args.sample_every:
                            original_audio = vae.decode(fixed_latents)
                            for i in range(min(8, original_audio.shape[0])):
                                save_audio_samples(
                                    original_audio[i:i+1],
                                    48000,
                                    f"sample_{global_step}_original_{i}.wav"
                                )

                    except Exception as e:
                        # Best-effort: a decode failure must not kill training.
                        print(f"Error during audio generation: {e}")

                model.train()

    print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)

    if writer is not None:
        writer.close()

if __name__ == '__main__':
    main()