Respair committed on
Commit
6ff8078
·
verified ·
1 Parent(s): ce00803

Delete train_boson.py

Browse files
Files changed (1) hide show
  1. train_boson.py +0 -891
train_boson.py DELETED
@@ -1,891 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Training script for Boson Audio Codec with DAC-inspired losses
4
- """
5
-
6
- import os
7
- import json
8
- import argparse
9
- import random
10
- from pathlib import Path
11
- from datetime import datetime
12
- import numpy as np
13
- import pandas as pd
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- import torch.distributed as dist
18
- from torch.nn.parallel import DistributedDataParallel as DDP
19
- from torch.utils.data import Dataset, DataLoader
20
- from torch.utils.data.distributed import DistributedSampler
21
- from torch.utils.tensorboard import SummaryWriter
22
- import torchaudio
23
- import librosa
24
- from tqdm import tqdm
25
- from audiotools import AudioSignal, STFTParams
26
-
27
- # Import from the provided codebase
28
- from higgs_audio_tokenizer import HiggsAudioTokenizer
29
- from quantization.distrib import broadcast_tensors, sync_buffer, is_distributed, world_size, rank
30
- from quantization.ddp_utils import set_random_seed, is_logging_process, get_timestamp
31
-
32
- # Import DAC losses and discriminator
33
- import sys
34
- sys.path.append('.') # Add current directory to path
35
- from loss import L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss, GANLoss
36
- from discriminator import Discriminator
37
-
38
-
39
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Cosine-annealing learning-rate schedule with a linear warmup phase.

    The LR ramps linearly from 0 to each param group's base LR over
    ``warmup_steps``, then decays along a half cosine from the base LR down
    to ``eta_min`` by ``total_steps``.

    Args:
        optimizer: wrapped torch optimizer.
        warmup_steps: number of linear-warmup steps.
        total_steps: total steps of the schedule (warmup + decay).
        eta_min: floor LR reached at the end of the cosine decay.
        last_epoch: standard scheduler resume index (-1 = fresh start).
    """

    def __init__(self, optimizer, warmup_steps, total_steps, eta_min=1e-6, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup: 0 -> base_lr over warmup_steps.
            warmup_factor = self.last_epoch / self.warmup_steps
            return [base_lr * warmup_factor for base_lr in self.base_lrs]
        # Cosine annealing: base_lr -> eta_min over the remaining steps.
        # max(1, ...) guards division by zero when total_steps == warmup_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp progress to 1.0 so the LR stays pinned at eta_min if training
        # runs past total_steps (unclamped, the cosine would rebound upward).
        progress = min(1.0, (self.last_epoch - self.warmup_steps) / decay_steps)
        cosine_factor = 0.5 * (1 + np.cos(np.pi * progress))
        return [self.eta_min + (base_lr - self.eta_min) * cosine_factor for base_lr in self.base_lrs]
57
-
58
-
59
class AudioDataset(Dataset):
    """Audio dataset backed by a CSV whose first column lists file paths.

    Each item is a fixed-length mono waveform tensor of shape
    (1, segment_length) plus the source path.
    """

    def __init__(self, csv_path, sample_rate=44100, segment_duration=2.0, is_train=True):
        self.df = pd.read_csv(csv_path)
        self.sample_rate = sample_rate
        self.segment_duration = segment_duration
        self.segment_length = int(sample_rate * segment_duration)
        self.is_train = is_train

        # Keep only paths (first CSV column) that actually exist on disk.
        self.audio_paths = [
            row.iloc[0] for _, row in self.df.iterrows() if os.path.exists(row.iloc[0])
        ]
        print(f"Found {len(self.audio_paths)} valid audio files")

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        try:
            # Decode, resample to the target rate and downmix to mono.
            audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)

            if len(audio) > self.segment_length:
                # Random crop during training; deterministic head crop for eval.
                if self.is_train:
                    start = random.randint(0, len(audio) - self.segment_length)
                else:
                    start = 0
                audio = audio[start:start + self.segment_length]
            else:
                # Zero-pad clips shorter than one segment.
                audio = np.pad(audio, (0, self.segment_length - len(audio)))

            return torch.FloatTensor(audio).unsqueeze(0), audio_path

        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Fall back to silence so a single bad file doesn't kill the epoch.
            return torch.zeros(1, self.segment_length), audio_path
106
-
107
-
108
- class BosonTrainer:
109
    def __init__(self, args):
        """Wire up the whole training run: DDP, model, data, optimizers,
        schedulers, losses, TensorBoard, and optional checkpoint resume.

        Args:
            args: argparse.Namespace produced by main() with all CLI options.
        """
        self.args = args
        self.distributed = False

        # Check if we're in a distributed environment (WORLD_SIZE is set by
        # torchrun / torch.distributed.launch).
        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1:
            self.distributed = True
            self.setup_ddp()
            self.device = torch.device(f'cuda:{args.local_rank}')
        else:
            # Single GPU mode
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # NOTE(review): set_device(0) is called even when CUDA is
            # unavailable and self.device is 'cpu' — confirm this doesn't
            # raise on CPU-only hosts.
            torch.cuda.set_device(0)
            set_random_seed(args.seed)

        # Load model/codec hyperparameters from the JSON config.
        with open(args.config, 'r') as f:
            self.config = json.load(f)

        # Initialize models (discriminator only when adversarial training is on).
        self.model = self.build_model()
        self.discriminator = self.build_discriminator() if args.use_discriminator else None

        # Setup data loaders
        self.train_loader, self.val_loader = self.setup_data_loaders()

        # Setup optimizers
        self.optimizer_g = torch.optim.AdamW(
            self.model.parameters(),
            lr=args.learning_rate,
            betas=(0.5, 0.9),
            weight_decay=args.weight_decay
        )

        if self.discriminator is not None:
            self.optimizer_d = torch.optim.AdamW(
                self.discriminator.parameters(),
                lr=args.learning_rate * 2,  # Typically discriminator learns faster
                betas=(0.5, 0.9),
                weight_decay=args.weight_decay
            )

        # Total optimizer steps — used as the cosine schedule horizon.
        self.total_steps = args.num_epochs * len(self.train_loader)

        # Setup schedulers with warmup
        self.scheduler_g = CosineWarmupScheduler(
            self.optimizer_g,
            warmup_steps=args.warmup_steps,
            total_steps=self.total_steps,
            eta_min=1e-6
        )

        if self.discriminator is not None:
            self.scheduler_d = CosineWarmupScheduler(
                self.optimizer_d,
                warmup_steps=args.warmup_steps,
                total_steps=self.total_steps,
                eta_min=1e-6
            )

        # Setup losses
        self.setup_losses()

        # TensorBoard writer only on the main process.
        if not self.distributed or rank() == 0:
            self.writer = SummaryWriter(
                log_dir=os.path.join(args.output_dir, 'logs', get_timestamp())
            )

        self.global_step = 0
        self.start_epoch = 0

        # Load checkpoint if exists
        if args.resume:
            self.load_checkpoint()
185
-
186
- def setup_ddp(self):
187
- """Initialize DDP"""
188
- if 'LOCAL_RANK' in os.environ:
189
- self.args.local_rank = int(os.environ['LOCAL_RANK'])
190
- dist.init_process_group(backend='nccl')
191
- torch.cuda.set_device(self.args.local_rank)
192
- set_random_seed(self.args.seed + rank())
193
-
194
    def build_model(self):
        """Instantiate the HiggsAudioTokenizer from the JSON config and,
        in distributed mode, synchronize initial weights and wrap in DDP.

        Returns:
            The model (possibly DDP-wrapped), moved to self.device.
        """
        print(self.config)
        model = HiggsAudioTokenizer(
            n_filters=self.config['n_filters'],
            D=self.config['D'],
            target_bandwidths=self.config['target_bandwidths'],
            ratios=self.config['ratios'],
            sample_rate=self.config['sample_rate'],
            bins=self.config['bins'],
            n_q=self.config['n_q'],
            codebook_dim=self.config.get('codebook_dim', None),
            # 'semantic_techer' (sic) matches the upstream config key spelling.
            semantic_techer=self.config['semantic_techer'],
            device=self.device
        ).to(self.device)

        if self.distributed:
            # Broadcast model parameters to ensure all ranks have same initialization
            broadcast_tensors(model.parameters())
            # Wrap with DDP
            model = DDP(model, device_ids=[self.args.local_rank])

        return model
218
-
219
    def build_discriminator(self):
        """Instantiate the DAC-style discriminator (multi-period + multi-FFT,
        no multi-rate branches) and DDP-wrap it in distributed mode.

        Returns:
            The discriminator (possibly DDP-wrapped) on self.device.
        """
        # Use sample rate from config
        discriminator = Discriminator(
            rates=[],  # No multi-rate discriminator for now
            periods=[2, 3, 5, 7, 11],
            fft_sizes=[2048, 1024, 512],
            sample_rate=self.config['sample_rate'],
        ).to(self.device)

        if self.distributed:
            # Same-init broadcast before DDP wrapping, mirroring build_model.
            broadcast_tensors(discriminator.parameters())
            discriminator = DDP(discriminator, device_ids=[self.args.local_rank])

        return discriminator
234
-
235
    def setup_losses(self):
        """Create reconstruction / GAN loss modules and their scalar weights.

        Sets: self.l1_loss, self.stft_loss, self.mel_loss, optionally
        self.gan_loss, and self.loss_weights (keyed by loss name).
        """
        # Basic losses
        self.l1_loss = L1Loss()
        self.stft_loss = MultiScaleSTFTLoss(
            window_lengths=[2048, 1024, 512, 256, 128],
            loss_fn=nn.L1Loss(),
            clamp_eps=1e-5,
            mag_weight=1.0,
            log_weight=1.0,
        )
        self.mel_loss = MelSpectrogramLoss(
            n_mels=[150, 80],
            window_lengths=[2048, 512],
            mel_fmin=[0.0, 0.0],
            mel_fmax=[None, None],
            clamp_eps=1e-5,
            mag_weight=1.0,
            log_weight=1.0,
        )

        # GAN loss if using discriminator
        if self.discriminator is not None:
            self.gan_loss = GANLoss(self.discriminator)

        # Loss weights (matching DAC's proven configuration)
        self.loss_weights = {
            'rec': 1.,        # Waveform L1 loss
            'stft': 1.,       # Multi-scale STFT loss
            # 'mel': 15.0,    # Mel-spectrogram loss (ENABLE it after 20-25k steps)
            'mel': 0.0,       # Mel-spectrogram loss (DISABLED)
            'commit': 0.25,   # Commitment loss
            'semantic': 1.,   # Semantic loss
            'gen': 1.,        # Generator adversarial loss
            'feat': 1.0,      # Feature matching loss
        }
271
-
272
    def setup_data_loaders(self):
        """Split the input CSV 90/10 into train/val and build DataLoaders.

        Returns:
            (train_loader, val_loader) — DistributedSampler-backed when
            running under DDP, plain shuffling otherwise.
        """
        # Split data into train/val (first 90% of rows -> train).
        df = pd.read_csv(self.args.data_csv)
        n_total = len(df)
        n_train = int(n_total * 0.9)

        # Create temporary CSV files for train/val split.
        # NOTE(review): fixed /tmp paths collide between concurrent jobs on
        # the same host, and on multi-node runs only rank 0's node gets the
        # files — confirm single-node usage or move under args.output_dir.
        train_csv = '/tmp/train_audio.csv'
        val_csv = '/tmp/val_audio.csv'

        # Only the main process writes the split files.
        if not self.distributed or rank() == 0:
            df[:n_train].to_csv(train_csv, index=False)
            df[n_train:].to_csv(val_csv, index=False)

        # Synchronize so other ranks don't read before the files exist.
        if self.distributed:
            dist.barrier()

        # Create datasets
        train_dataset = AudioDataset(
            train_csv,
            sample_rate=self.config['sample_rate'],
            segment_duration=self.args.segment_duration,
            is_train=True
        )

        val_dataset = AudioDataset(
            val_csv,
            sample_rate=self.config['sample_rate'],
            segment_duration=self.args.segment_duration,
            is_train=False
        )

        # Create samplers and loaders
        if self.distributed:
            train_sampler = DistributedSampler(train_dataset, shuffle=True)
            val_sampler = DistributedSampler(val_dataset, shuffle=False)
        else:
            train_sampler = None
            val_sampler = None

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            sampler=train_sampler,
            shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
            num_workers=self.args.num_workers,
            pin_memory=True,
            drop_last=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
            sampler=val_sampler,
            shuffle=False,
            num_workers=self.args.num_workers,
            pin_memory=True,
            drop_last=False
        )

        return train_loader, val_loader
335
-
336
- def is_main_process(self):
337
- """Check if this is the main process"""
338
- return not self.distributed or rank() == 0
339
-
340
    def train_epoch(self, epoch):
        """Run one training epoch (generator always; discriminator once
        self.global_step passes args.discriminator_start_step).

        Args:
            epoch: zero-based epoch index (used for tqdm and sampler seeding).

        Returns:
            dict of per-loss averages over the epoch's batches.
        """
        self.model.train()
        if self.discriminator is not None:
            self.discriminator.train()

        # Reseed the distributed sampler so shuffling differs per epoch.
        if self.distributed:
            self.train_loader.sampler.set_epoch(epoch)

        total_losses = {
            'total': 0, 'rec': 0, 'stft': 0, 'mel': 0,
            'commit': 0, 'semantic': 0, 'gen': 0, 'feat': 0, 'disc': 0
        }

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch}', disable=not self.is_main_process())

        for batch_idx, (audio, paths) in enumerate(pbar):
            audio = audio.to(self.device)

            # Create AudioSignal objects for loss computation
            audio_signal = AudioSignal(audio, self.config['sample_rate'])

            # Forward pass with a randomly sampled target bandwidth
            # (bandwidth dropout across the configured levels).
            bw_idx = random.randint(0, len(self.config['target_bandwidths']) - 1)
            bw = self.config['target_bandwidths'][bw_idx]

            output, commit_loss, semantic_loss, _ = self.model(audio, bw)
            recons_signal = AudioSignal(output, self.config['sample_rate'])

            # Check if discriminator should be active (after discriminator_start_step)
            use_discriminator = (self.discriminator is not None and
                                 self.global_step >= self.args.discriminator_start_step)

            # Train discriminator first if using GAN and past the start step
            if use_discriminator and self.global_step % self.args.disc_interval == 0:
                self.optimizer_d.zero_grad()
                disc_loss = self.gan_loss.discriminator_loss(recons_signal, audio_signal)
                disc_loss.backward()
                # Looser clip (10.0) than the generator's 1.0.
                torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 10.0)
                self.optimizer_d.step()
                self.scheduler_d.step()
                total_losses['disc'] += disc_loss.item()

            # Train generator
            losses = {}

            # Reconstruction losses
            losses['rec'] = self.l1_loss(recons_signal, audio_signal)
            losses['stft'] = self.stft_loss(recons_signal, audio_signal)
            # losses['mel'] = self.mel_loss(recons_signal, audio_signal)
            # Mel loss currently disabled (weight 0); placeholder keeps the
            # logging/averaging keys uniform.
            losses['mel'] = torch.tensor(0.0, device=self.device)  # 15.
            losses['commit'] = commit_loss
            losses['semantic'] = semantic_loss

            # GAN losses if discriminator is active
            if use_discriminator:
                gen_loss, feat_loss = self.gan_loss.generator_loss(recons_signal, audio_signal)
                losses['gen'] = gen_loss
                losses['feat'] = feat_loss
            else:
                # Set to zero for logging purposes
                losses['gen'] = torch.tensor(0.0, device=self.device)
                losses['feat'] = torch.tensor(0.0, device=self.device)

            # Total weighted loss (gen/feat excluded until the GAN kicks in).
            total_loss = sum(self.loss_weights.get(k, 0) * v for k, v in losses.items()
                             if k not in ['gen', 'feat'] or use_discriminator)

            # Backward pass
            self.optimizer_g.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer_g.step()
            self.scheduler_g.step()

            # Update metrics
            total_losses['total'] += total_loss.item()
            for k, v in losses.items():
                total_losses[k] += v.item()

            # Update progress bar
            if self.is_main_process():
                pbar.set_postfix({
                    'loss': f'{total_loss.item():.4f}',
                    'rec': f'{losses["rec"].item():.4f}',
                    'mel': f'{losses["mel"].item():.4f}',
                    'commit_loss': f'{losses["commit"].item():.4f}',
                    'semantic_loss': f'{losses["semantic"].item():.4f}',
                    'lr': f'{self.scheduler_g.get_last_lr()[0]:.9f}',
                    'disc': 'ON' if use_discriminator else 'OFF',
                    'step': self.global_step
                })

            # Log to tensorboard
            if self.is_main_process() and self.global_step % self.args.log_interval == 0:
                for k, v in losses.items():
                    self.writer.add_scalar(f'train/{k}_loss', v.item(), self.global_step)
                self.writer.add_scalar('train/total_loss', total_loss.item(), self.global_step)
                self.writer.add_scalar('train/lr', self.scheduler_g.get_last_lr()[0], self.global_step)
                self.writer.add_scalar('train/bandwidth', bw, self.global_step)
                self.writer.add_scalar('train/discriminator_active', float(use_discriminator), self.global_step)
                if use_discriminator:
                    # Running mean of the disc loss so far this epoch.
                    self.writer.add_scalar('train/disc_loss', total_losses['disc'] / max(1, batch_idx), self.global_step)

            # Save checkpoint at step intervals
            if self.global_step > 0 and self.global_step % self.args.save_step_interval == 0:
                self.save_checkpoint_step(self.global_step)
                if self.is_main_process():
                    print(f"\nSaved checkpoint at step {self.global_step}")

            self.global_step += 1

        # Return average losses
        n_batches = len(self.train_loader)
        return {k: v / n_batches for k, v in total_losses.items()}
455
-
456
    @torch.no_grad()
    def validate(self, epoch):
        """Run the validation loop at a fixed bandwidth and log audio samples.

        Args:
            epoch: epoch index used as the TensorBoard global step.

        Returns:
            dict of per-loss averages over the validation set.
        """
        self.model.eval()

        total_losses = {
            'total': 0, 'rec': 0, 'stft': 0, 'mel': 0,
            'commit': 0, 'semantic': 0
        }

        # Store audio samples for tensorboard
        audio_samples = {'train': [], 'val': []}

        for batch_idx, (audio, paths) in enumerate(tqdm(self.val_loader, desc='Validation', disable=not self.is_main_process())):
            audio = audio.to(self.device)
            audio_signal = AudioSignal(audio, self.config['sample_rate'])

            # Use medium bandwidth for validation
            # (index 2 — assumes at least 3 configured bandwidths; TODO confirm)
            bw = self.config['target_bandwidths'][2]

            output, commit_loss, semantic_loss, _ = self.model(audio, bw)
            recons_signal = AudioSignal(output, self.config['sample_rate'])

            # Compute losses (mel is evaluated here even though its training
            # weight is currently 0).
            losses = {
                'rec': self.l1_loss(recons_signal, audio_signal),
                'stft': self.stft_loss(recons_signal, audio_signal),
                'mel': self.mel_loss(recons_signal, audio_signal),
                'commit': commit_loss,
                'semantic': semantic_loss
            }

            total_loss = sum(self.loss_weights.get(k, 0) * v for k, v in losses.items())

            total_losses['total'] += total_loss.item()
            for k, v in losses.items():
                total_losses[k] += v.item()

            # Collect audio samples for tensorboard (first 3 from validation)
            if self.is_main_process() and len(audio_samples['val']) < 3:
                audio_samples['val'].append({
                    'original': audio[0].cpu(),
                    'reconstructed': output[0].cpu(),
                    'path': paths[0]
                })

        # Get train samples for comparison
        if self.is_main_process():
            self.model.eval()
            for batch_idx, (audio, paths) in enumerate(self.train_loader):
                if len(audio_samples['train']) >= 3:
                    break
                audio = audio.to(self.device)
                bw = self.config['target_bandwidths'][2]
                output, _, _, _ = self.model(audio, bw)
                audio_samples['train'].append({
                    'original': audio[0].cpu(),
                    'reconstructed': output[0].cpu(),
                    'path': paths[0]
                })

        # Log audio samples to tensorboard
        if self.is_main_process():
            for split in ['train', 'val']:
                for idx, sample in enumerate(audio_samples[split]):
                    self.writer.add_audio(
                        f'{split}/original_{idx}',
                        sample['original'],
                        epoch,
                        sample_rate=self.config['sample_rate']
                    )
                    self.writer.add_audio(
                        f'{split}/reconstructed_{idx}',
                        sample['reconstructed'],
                        epoch,
                        sample_rate=self.config['sample_rate']
                    )

        # Average losses
        n_batches = len(self.val_loader)
        val_metrics = {k: v / n_batches for k, v in total_losses.items()}

        # Log validation metrics
        if self.is_main_process():
            for key, value in val_metrics.items():
                self.writer.add_scalar(f'val/{key}_loss', value, epoch)

        return val_metrics
544
-
545
    def save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint (epoch-based): latest.pth, optionally
        best.pth, and a periodic epoch_{N}.pth. Main process only.

        Args:
            epoch: epoch index stored in the checkpoint.
            is_best: also write checkpoints/best.pth when True.
        """
        if not self.is_main_process():
            return

        # Unwrap DDP so the saved state_dict has no 'module.' prefix.
        model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict()

        # Get current learning rates for verification
        current_lr_g = self.scheduler_g.get_last_lr()[0]

        checkpoint = {
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': model_state,
            'optimizer_g_state_dict': self.optimizer_g.state_dict(),
            'scheduler_g_state_dict': self.scheduler_g.state_dict(),
            'scheduler_g_last_epoch': self.scheduler_g.last_epoch,  # Explicitly save this
            'current_lr_g': current_lr_g,  # Save for verification
            'config': self.config,
            'args': self.args
        }

        if self.discriminator is not None:
            disc_state = self.discriminator.module.state_dict() if self.distributed else self.discriminator.state_dict()
            current_lr_d = self.scheduler_d.get_last_lr()[0]
            checkpoint['discriminator_state_dict'] = disc_state
            checkpoint['optimizer_d_state_dict'] = self.optimizer_d.state_dict()
            checkpoint['scheduler_d_state_dict'] = self.scheduler_d.state_dict()
            checkpoint['scheduler_d_last_epoch'] = self.scheduler_d.last_epoch
            checkpoint['current_lr_d'] = current_lr_d

        # Save latest checkpoint
        checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth')
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = os.path.join(self.args.output_dir, 'checkpoints', 'best.pth')
            torch.save(checkpoint, best_path)

        # Save periodic checkpoint
        if epoch % self.args.save_interval == 0:
            epoch_path = os.path.join(self.args.output_dir, 'checkpoints', f'epoch_{epoch}.pth')
            torch.save(checkpoint, epoch_path)
590
-
591
-
592
- def save_checkpoint_step(self, step):
593
- """Save model checkpoint (step-based)"""
594
- if not self.is_main_process():
595
- return
596
-
597
- # Get current epoch from training loop
598
- current_epoch = step // len(self.train_loader)
599
-
600
- model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict()
601
-
602
- # Get current learning rates for verification
603
- current_lr_g = self.scheduler_g.get_last_lr()[0]
604
-
605
- checkpoint = {
606
- 'epoch': current_epoch,
607
- 'global_step': step,
608
- 'model_state_dict': model_state,
609
- 'optimizer_g_state_dict': self.optimizer_g.state_dict(),
610
- 'scheduler_g_state_dict': self.scheduler_g.state_dict(),
611
- 'scheduler_g_last_epoch': self.scheduler_g.last_epoch, # Explicitly save this
612
- 'current_lr_g': current_lr_g, # Save for verification
613
- 'config': self.config,
614
- 'args': self.args
615
- }
616
-
617
- if self.discriminator is not None:
618
- disc_state = self.discriminator.module.state_dict() if self.distributed else self.discriminator.state_dict()
619
- current_lr_d = self.scheduler_d.get_last_lr()[0]
620
- checkpoint['discriminator_state_dict'] = disc_state
621
- checkpoint['optimizer_d_state_dict'] = self.optimizer_d.state_dict()
622
- checkpoint['scheduler_d_state_dict'] = self.scheduler_d.state_dict()
623
- checkpoint['scheduler_d_last_epoch'] = self.scheduler_d.last_epoch
624
- checkpoint['current_lr_d'] = current_lr_d
625
-
626
- # Save step-based checkpoint
627
- step_path = os.path.join(self.args.output_dir, 'checkpoints', f'step_{step}.pth')
628
- torch.save(checkpoint, step_path)
629
-
630
- # Also update latest checkpoint
631
- latest_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth')
632
- torch.save(checkpoint, latest_path)
633
-
634
- # Keep only the last N step-based checkpoints to save disk space
635
- if self.args.keep_last_n_steps > 0:
636
- checkpoint_dir = os.path.join(self.args.output_dir, 'checkpoints')
637
- step_checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.startswith('step_')])
638
- if len(step_checkpoints) > self.args.keep_last_n_steps:
639
- for old_checkpoint in step_checkpoints[:-self.args.keep_last_n_steps]:
640
- os.remove(os.path.join(checkpoint_dir, old_checkpoint))
641
-
642
-
643
    def load_checkpoint(self):
        """Resume training state from checkpoints/latest.pth if it exists.

        Restores model/discriminator weights, optimizer and scheduler state
        (including the scheduler's last_epoch so the LR curve continues where
        it left off), start_epoch and global_step, then prints a verification
        report comparing restored vs. expected learning rates.
        """
        checkpoint_path = os.path.join(self.args.output_dir, 'checkpoints', 'latest.pth')
        if os.path.exists(checkpoint_path):
            print(f"Loading checkpoint from {checkpoint_path}")
            # weights_only=False: the checkpoint pickles config/args objects.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

            # Load model state (into the underlying module when DDP-wrapped).
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint['model_state_dict'])

            # Load optimizer state
            self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])

            # Load scheduler state
            self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict'])

            # Restore scheduler's last_epoch from checkpoint
            if 'scheduler_g_last_epoch' in checkpoint:
                self.scheduler_g.last_epoch = checkpoint['scheduler_g_last_epoch']
            else:
                # Fallback: use global_step if the explicit value wasn't saved
                self.scheduler_g.last_epoch = checkpoint['global_step']

            # Force scheduler to recompute its internal state
            self.scheduler_g._last_lr = self.scheduler_g.get_lr()

            # Load discriminator if present
            if self.discriminator is not None and 'discriminator_state_dict' in checkpoint:
                if self.distributed:
                    self.discriminator.module.load_state_dict(checkpoint['discriminator_state_dict'])
                else:
                    self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
                self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
                self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict'])

                # Restore discriminator scheduler's last_epoch
                if 'scheduler_d_last_epoch' in checkpoint:
                    self.scheduler_d.last_epoch = checkpoint['scheduler_d_last_epoch']
                else:
                    self.scheduler_d.last_epoch = checkpoint['global_step']

                self.scheduler_d._last_lr = self.scheduler_d.get_lr()

            # Restore training state
            self.start_epoch = checkpoint['epoch'] + 1
            self.global_step = checkpoint['global_step']

            # Verify learning rate restoration
            current_lr_g = self.scheduler_g.get_last_lr()[0]
            saved_lr_g = checkpoint.get('current_lr_g', None)

            print(f"\n{'='*60}")
            print(f"CHECKPOINT LOADED SUCCESSFULLY")
            print(f"{'='*60}")
            print(f"Resumed from epoch: {checkpoint['epoch']}")
            print(f"Global step: {self.global_step}")
            print(f"Scheduler last_epoch: {self.scheduler_g.last_epoch}")
            print(f"Current learning rate (generator): {current_lr_g:.9f}")
            if saved_lr_g is not None:
                print(f"Saved learning rate (generator): {saved_lr_g:.9f}")
                if abs(current_lr_g - saved_lr_g) > 1e-9:
                    print("⚠️ WARNING: Learning rate mismatch! This might indicate improper state restoration.")

            if self.discriminator is not None:
                current_lr_d = self.scheduler_d.get_last_lr()[0]
                saved_lr_d = checkpoint.get('current_lr_d', None)
                print(f"Current learning rate (discriminator): {current_lr_d:.9f}")
                if saved_lr_d is not None:
                    print(f"Saved learning rate (discriminator): {saved_lr_d:.9f}")
                print(f"Discriminator status: {'ACTIVE' if self.global_step >= self.args.discriminator_start_step else f'INACTIVE (starts at step {self.args.discriminator_start_step})'}")

            print(f"Next epoch: {self.start_epoch}")
            print(f"Next step checkpoint at: step {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}")
            print(f"{'='*60}\n")

            # Double-check by creating a fresh scheduler and comparing its LR
            # at the restored step against the restored scheduler's LR.
            if self.global_step > 0:
                temp_scheduler = CosineWarmupScheduler(
                    self.optimizer_g,
                    self.args.warmup_steps,
                    self.total_steps,
                    eta_min=1e-6,
                    last_epoch=-1
                )
                # Step it to the current global step
                for _ in range(self.global_step):
                    temp_scheduler.step()
                expected_lr = temp_scheduler.get_last_lr()[0]
                if abs(current_lr_g - expected_lr) > 1e-9:
                    print(f"⚠️ Learning rate verification failed!")
                    print(f"  Expected: {expected_lr:.9f}")
                    print(f"  Got: {current_lr_g:.9f}")
                    print("  The scheduler state might not be properly restored.")
        else:
            print(f"No checkpoint found at {checkpoint_path}, starting from scratch")
741
-
742
    def train(self):
        """Main training loop: per-epoch train + validate, checkpointing on
        every epoch (tracking the best validation loss), then final export
        of model weights and config.
        """
        best_val_loss = float('inf')

        # Print training configuration
        if self.is_main_process():
            print(f"\n{'='*50}")
            print(f"Training Configuration:")
            print(f"{'='*50}")
            print(f"Total epochs: {self.args.num_epochs}")
            print(f"Steps per epoch: {len(self.train_loader)}")
            print(f"Total steps: {self.total_steps}")
            print(f"Warmup steps: {self.args.warmup_steps}")
            print(f"Discriminator starts at step: {self.args.discriminator_start_step}")
            print(f"Checkpoint saving:")
            print(f"  - Every {self.args.save_interval} epochs")
            print(f"  - Every {self.args.save_step_interval} steps")
            print(f"  - Keep last {self.args.keep_last_n_steps} step checkpoints")
            if self.start_epoch > 0:
                print(f"RESUMING from epoch {self.start_epoch}, step {self.global_step}")
            print(f"{'='*50}\n")

        for epoch in range(self.start_epoch, self.args.num_epochs):
            # IMPORTANT: Set the epoch for distributed sampler when resuming
            # This ensures proper data shuffling across epochs
            if self.distributed and hasattr(self.train_loader.sampler, 'set_epoch'):
                self.train_loader.sampler.set_epoch(epoch)

            # Train
            train_metrics = self.train_epoch(epoch)

            # Validate
            val_metrics = self.validate(epoch)

            # Log epoch metrics
            if self.is_main_process():
                print(f"\nEpoch {epoch} Summary:")
                print(f"Train - Total: {train_metrics['total']:.4f}, Rec: {train_metrics['rec']:.4f}, "
                      f"STFT: {train_metrics['stft']:.4f}, Mel: {train_metrics['mel']:.4f}, "
                      f"Commit: {train_metrics['commit']:.4f}, Semantic: {train_metrics['semantic']:.4f}")
                if self.discriminator is not None:
                    print(f"  Gen: {train_metrics['gen']:.4f}, Feat: {train_metrics['feat']:.4f}, "
                          f"Disc: {train_metrics['disc']:.4f}")
                    print(f"  Discriminator Status: {'Active' if self.global_step >= self.args.discriminator_start_step else f'Starting at step {self.args.discriminator_start_step}'}")
                print(f"Val - Total: {val_metrics['total']:.4f}, Rec: {val_metrics['rec']:.4f}, "
                      f"STFT: {val_metrics['stft']:.4f}, Mel: {val_metrics['mel']:.4f}, "
                      f"Commit: {val_metrics['commit']:.4f}, Semantic: {val_metrics['semantic']:.4f}")
                print(f"Current Step: {self.global_step}, Next step checkpoint at: {((self.global_step // self.args.save_step_interval) + 1) * self.args.save_step_interval}")
                print(f"Current LR: {self.scheduler_g.get_last_lr()[0]:.9f}")

            # Save checkpoint (latest always; best when val improves)
            is_best = val_metrics['total'] < best_val_loss
            if is_best:
                best_val_loss = val_metrics['total']
            self.save_checkpoint(epoch, is_best)

        # Save final model
        if self.is_main_process():
            model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict()

            final_path = os.path.join(self.args.output_dir, 'checkpoints', 'final.pth')
            torch.save({
                'model_state_dict': model_state,
                'config': self.config
            }, final_path)

            # Also save just the model weights in the format expected by the original code
            model_only_path = os.path.join(self.args.output_dir, 'model.pth')
            torch.save(model_state, model_only_path)

            # Copy config alongside the exported weights for reproducibility.
            import shutil
            shutil.copy(self.args.config, os.path.join(self.args.output_dir, 'config.json'))

        # Cleanup
        if self.is_main_process():
            self.writer.close()
        if self.distributed:
            dist.destroy_process_group()
821
-
822
-
823
def main():
    """Parse CLI arguments, create the output directory, and run training."""
    parser = argparse.ArgumentParser(description='Train Boson Audio Codec')

    # Data arguments
    parser.add_argument('--data_csv', type=str, required=True,
                        help='Path to CSV file containing audio file paths')
    parser.add_argument('--config', type=str, default='config.json',
                        help='Path to config JSON file')

    # Training arguments
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size per GPU')
    parser.add_argument('--num_epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='Weight decay')
    parser.add_argument('--segment_duration', type=float, default=2.,
                        help='Audio segment duration in seconds')

    # Scheduler arguments
    parser.add_argument('--warmup_steps', type=int, default=5000,
                        help='Number of warmup steps for cosine scheduler')

    # Loss arguments
    parser.add_argument('--use_discriminator', action='store_true',
                        help='Use adversarial training with discriminator')
    parser.add_argument('--discriminator_start_step', type=int, default=24_000,
                        help='Start training discriminator after N steps')
    parser.add_argument('--disc_interval', type=int, default=1,
                        help='Train discriminator every N steps')

    # System arguments
    parser.add_argument('--output_dir', type=str, default='outputs',
                        help='Output directory for checkpoints and logs')
    parser.add_argument('--num_workers', type=int, default=16,
                        help='Number of data loading workers')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='Local rank for distributed training')

    # Logging arguments
    parser.add_argument('--log_interval', type=int, default=10,
                        help='Log every N steps')
    parser.add_argument('--save_interval', type=int, default=1,
                        help='Save checkpoint every N epochs')
    parser.add_argument('--save_step_interval', type=int, default=1000,
                        help='Save checkpoint every N steps')
    parser.add_argument('--keep_last_n_steps', type=int, default=5,
                        help='Keep only the last N step-based checkpoints (0 to keep all)')

    # Resume training
    parser.add_argument('--resume', action='store_true',
                        help='Resume training from latest checkpoint')

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Train
    trainer = BosonTrainer(args)
    trainer.train()


if __name__ == '__main__':
    main()