Update train_boson_mixed_precision.py
Browse files
train_boson_mixed_precision.py
CHANGED
|
@@ -54,7 +54,7 @@ class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
|
|
| 54 |
|
| 55 |
class AudioDataset(Dataset):
|
| 56 |
"""Dataset for loading audio files from CSV"""
|
| 57 |
-
def __init__(self, csv_path, sample_rate=
|
| 58 |
self.df = pd.read_csv(csv_path)
|
| 59 |
self.sample_rate = sample_rate
|
| 60 |
self.segment_duration = segment_duration
|
|
@@ -241,7 +241,7 @@ class BosonTrainer:
|
|
| 241 |
rates=[], # No multi-rate discriminator
|
| 242 |
periods=[2, 3, 5, 7, 11],
|
| 243 |
fft_sizes=[2048, 1024, 512],
|
| 244 |
-
sample_rate=self.config['sample_rate'], #
|
| 245 |
).to(self.device)
|
| 246 |
|
| 247 |
if self.distributed:
|
|
@@ -926,7 +926,7 @@ def main():
|
|
| 926 |
# Loss arguments
|
| 927 |
parser.add_argument('--use_discriminator', action='store_true',
|
| 928 |
help='Use adversarial training with discriminator')
|
| 929 |
-
parser.add_argument('--discriminator_start_step', type=int, default=
|
| 930 |
help='Start training discriminator after N steps')
|
| 931 |
parser.add_argument('--disc_interval', type=int, default=1,
|
| 932 |
help='Train discriminator every N steps')
|
|
@@ -953,7 +953,7 @@ def main():
|
|
| 953 |
|
| 954 |
# Resume training
|
| 955 |
parser.add_argument('--resume', action='store_true',
|
| 956 |
-
help='Resume training from latest checkpoint')
|
| 957 |
|
| 958 |
args = parser.parse_args()
|
| 959 |
|
|
|
|
| 54 |
|
| 55 |
class AudioDataset(Dataset):
|
| 56 |
"""Dataset for loading audio files from CSV"""
|
| 57 |
+
def __init__(self, csv_path, sample_rate=24000, segment_duration=2.0, is_train=True):
|
| 58 |
self.df = pd.read_csv(csv_path)
|
| 59 |
self.sample_rate = sample_rate
|
| 60 |
self.segment_duration = segment_duration
|
|
|
|
| 241 |
rates=[], # No multi-rate discriminator
|
| 242 |
periods=[2, 3, 5, 7, 11],
|
| 243 |
fft_sizes=[2048, 1024, 512],
|
| 244 |
+
sample_rate=self.config['sample_rate'], # 24000
|
| 245 |
).to(self.device)
|
| 246 |
|
| 247 |
if self.distributed:
|
|
|
|
| 926 |
# Loss arguments
|
| 927 |
parser.add_argument('--use_discriminator', action='store_true',
|
| 928 |
help='Use adversarial training with discriminator')
|
| 929 |
+
parser.add_argument('--discriminator_start_step', type=int, default=30_000,
|
| 930 |
help='Start training discriminator after N steps')
|
| 931 |
parser.add_argument('--disc_interval', type=int, default=1,
|
| 932 |
help='Train discriminator every N steps')
|
|
|
|
| 953 |
|
| 954 |
# Resume training
|
| 955 |
parser.add_argument('--resume', action='store_true',
|
| 956 |
+
help='Resume training from latest checkpoint') # NOTE: you gotta change your desired checkpoint's name to latest.pth
|
| 957 |
|
| 958 |
args = parser.parse_args()
|
| 959 |
|