Respair commited on
Commit
6c92e76
·
verified ·
1 Parent(s): 04a2208

Update train_boson_mixed_precision.py

Browse files
Files changed (1) hide show
  1. train_boson_mixed_precision.py +4 -4
train_boson_mixed_precision.py CHANGED
@@ -54,7 +54,7 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
54
 
55
  class AudioDataset(Dataset):
56
  """Dataset for loading audio files from CSV"""
57
- def __init__(self, csv_path, sample_rate=44100, segment_duration=2.0, is_train=True):
58
  self.df = pd.read_csv(csv_path)
59
  self.sample_rate = sample_rate
60
  self.segment_duration = segment_duration
@@ -241,7 +241,7 @@ class BosonTrainer:
241
  rates=[], # No multi-rate discriminator
242
  periods=[2, 3, 5, 7, 11],
243
  fft_sizes=[2048, 1024, 512],
244
- sample_rate=self.config['sample_rate'], # 44100
245
  ).to(self.device)
246
 
247
  if self.distributed:
@@ -926,7 +926,7 @@ def main():
926
  # Loss arguments
927
  parser.add_argument('--use_discriminator', action='store_true',
928
  help='Use adversarial training with discriminator')
929
- parser.add_argument('--discriminator_start_step', type=int, default=25_000,
930
  help='Start training discriminator after N steps')
931
  parser.add_argument('--disc_interval', type=int, default=1,
932
  help='Train discriminator every N steps')
@@ -953,7 +953,7 @@ def main():
953
 
954
  # Resume training
955
  parser.add_argument('--resume', action='store_true',
956
- help='Resume training from latest checkpoint')
957
 
958
  args = parser.parse_args()
959
 
 
54
 
55
  class AudioDataset(Dataset):
56
  """Dataset for loading audio files from CSV"""
57
+ def __init__(self, csv_path, sample_rate=24000, segment_duration=2.0, is_train=True):
58
  self.df = pd.read_csv(csv_path)
59
  self.sample_rate = sample_rate
60
  self.segment_duration = segment_duration
 
241
  rates=[], # No multi-rate discriminator
242
  periods=[2, 3, 5, 7, 11],
243
  fft_sizes=[2048, 1024, 512],
244
+ sample_rate=self.config['sample_rate'], # 24000
245
  ).to(self.device)
246
 
247
  if self.distributed:
 
926
  # Loss arguments
927
  parser.add_argument('--use_discriminator', action='store_true',
928
  help='Use adversarial training with discriminator')
929
+ parser.add_argument('--discriminator_start_step', type=int, default=30_000,
930
  help='Start training discriminator after N steps')
931
  parser.add_argument('--disc_interval', type=int, default=1,
932
  help='Train discriminator every N steps')
 
953
 
954
  # Resume training
955
  parser.add_argument('--resume', action='store_true',
956
+ help='Resume training from latest checkpoint') # NOTE: you gotta change your desired checkpoint's name to latest.pth
957
 
958
  args = parser.parse_args()
959