| |
| |
| |
| |
| |
|
|
import json
import math
import os
import sys
import time
import typing as tp
from dataclasses import dataclass, field

import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel

from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
from .compressed import get_compressed_datasets
from .model import Demucs
from .parser import get_name, get_parser
from .pretrained import load_pretrained, SOURCES
from .raw import Rawset
from .repitch import RepitchedWrapper
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import (human_seconds, load_model, save_model, get_state,
                    save_state, sizeof_fmt, get_quantizer)
from .wav import get_wav_datasets, get_musdb_wav_datasets
|
|
|
|
| @dataclass |
| class SavedState: |
| metrics: list = field(default_factory=list) |
| last_state: dict = None |
| best_state: dict = None |
| optimizer: dict = None |
|
|
|
|
| def main(): |
| parser = get_parser() |
| args = parser.parse_args() |
| name = get_name(parser, args) |
| print(f"Experiment {name}") |
|
|
| if args.musdb is None and args.rank == 0: |
| print( |
| "You must provide the path to the MusDB dataset with the --musdb flag. " |
| "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.", |
| file=sys.stderr) |
| sys.exit(1) |
|
|
| eval_folder = args.evals / name |
| eval_folder.mkdir(exist_ok=True, parents=True) |
| args.logs.mkdir(exist_ok=True) |
| metrics_path = args.logs / f"{name}.json" |
| eval_folder.mkdir(exist_ok=True, parents=True) |
| args.checkpoints.mkdir(exist_ok=True, parents=True) |
| args.models.mkdir(exist_ok=True, parents=True) |
|
|
| if args.device is None: |
| device = "cpu" |
| if th.cuda.is_available(): |
| device = "cuda" |
| else: |
| device = args.device |
|
|
| th.manual_seed(args.seed) |
| |
| |
| os.environ["OMP_NUM_THREADS"] = "1" |
| os.environ["MKL_NUM_THREADS"] = "1" |
|
|
| if args.world_size > 1: |
| if device != "cuda" and args.rank == 0: |
| print("Error: distributed training is only available with cuda device", file=sys.stderr) |
| sys.exit(1) |
| th.cuda.set_device(args.rank % th.cuda.device_count()) |
| distributed.init_process_group(backend="nccl", |
| init_method="tcp://" + args.master, |
| rank=args.rank, |
| world_size=args.world_size) |
|
|
| checkpoint = args.checkpoints / f"{name}.th" |
| checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" |
| if args.restart and checkpoint.exists() and args.rank == 0: |
| checkpoint.unlink() |
|
|
| if args.test or args.test_pretrained: |
| args.epochs = 1 |
| args.repeat = 0 |
| if args.test: |
| model = load_model(args.models / args.test) |
| else: |
| model = load_pretrained(args.test_pretrained) |
| elif args.tasnet: |
| model = ConvTasNet(audio_channels=args.audio_channels, |
| samplerate=args.samplerate, X=args.X, |
| segment_length=4 * args.samples, |
| sources=SOURCES) |
| else: |
| model = Demucs( |
| audio_channels=args.audio_channels, |
| channels=args.channels, |
| context=args.context, |
| depth=args.depth, |
| glu=args.glu, |
| growth=args.growth, |
| kernel_size=args.kernel_size, |
| lstm_layers=args.lstm_layers, |
| rescale=args.rescale, |
| rewrite=args.rewrite, |
| stride=args.conv_stride, |
| resample=args.resample, |
| normalize=args.normalize, |
| samplerate=args.samplerate, |
| segment_length=4 * args.samples, |
| sources=SOURCES, |
| ) |
| model.to(device) |
| if args.init: |
| model.load_state_dict(load_pretrained(args.init).state_dict()) |
|
|
| if args.show: |
| print(model) |
| size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) |
| print(f"Model size {size}") |
| return |
|
|
| try: |
| saved = th.load(checkpoint, map_location='cpu') |
| except IOError: |
| saved = SavedState() |
|
|
| optimizer = th.optim.Adam(model.parameters(), lr=args.lr) |
|
|
| quantizer = None |
| quantizer = get_quantizer(model, args, optimizer) |
|
|
| if saved.last_state is not None: |
| model.load_state_dict(saved.last_state, strict=False) |
| if saved.optimizer is not None: |
| optimizer.load_state_dict(saved.optimizer) |
|
|
| model_name = f"{name}.th" |
| if args.save_model: |
| if args.rank == 0: |
| model.to("cpu") |
| model.load_state_dict(saved.best_state) |
| save_model(model, quantizer, args, args.models / model_name) |
| return |
| elif args.save_state: |
| model_name = f"{args.save_state}.th" |
| if args.rank == 0: |
| model.to("cpu") |
| model.load_state_dict(saved.best_state) |
| state = get_state(model, quantizer) |
| save_state(state, args.models / model_name) |
| return |
|
|
| if args.rank == 0: |
| done = args.logs / f"{name}.done" |
| if done.exists(): |
| done.unlink() |
|
|
| augment = [Shift(args.data_stride)] |
| if args.augment: |
| augment += [FlipSign(), FlipChannels(), Scale(), |
| Remix(group_size=args.remix_group_size)] |
| augment = nn.Sequential(*augment).to(device) |
| print("Agumentation pipeline:", augment) |
|
|
| if args.mse: |
| criterion = nn.MSELoss() |
| else: |
| criterion = nn.L1Loss() |
|
|
| |
| |
| |
| samples = model.valid_length(args.samples) |
| print(f"Number of training samples adjusted to {samples}") |
| samples = samples + args.data_stride |
| if args.repitch: |
| |
| |
| samples = math.ceil(samples / (1 - 0.01 * args.max_tempo)) |
|
|
| args.metadata.mkdir(exist_ok=True, parents=True) |
| if args.raw: |
| train_set = Rawset(args.raw / "train", |
| samples=samples, |
| channels=args.audio_channels, |
| streams=range(1, len(model.sources) + 1), |
| stride=args.data_stride) |
|
|
| valid_set = Rawset(args.raw / "valid", channels=args.audio_channels) |
| elif args.wav: |
| train_set, valid_set = get_wav_datasets(args, samples, model.sources) |
| elif args.is_wav: |
| train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources) |
| else: |
| train_set, valid_set = get_compressed_datasets(args, samples) |
|
|
| if args.repitch: |
| train_set = RepitchedWrapper( |
| train_set, |
| proba=args.repitch, |
| max_tempo=args.max_tempo) |
|
|
| best_loss = float("inf") |
| for epoch, metrics in enumerate(saved.metrics): |
| print(f"Epoch {epoch:03d}: " |
| f"train={metrics['train']:.8f} " |
| f"valid={metrics['valid']:.8f} " |
| f"best={metrics['best']:.4f} " |
| f"ms={metrics.get('true_model_size', 0):.2f}MB " |
| f"cms={metrics.get('compressed_model_size', 0):.2f}MB " |
| f"duration={human_seconds(metrics['duration'])}") |
| best_loss = metrics['best'] |
|
|
| if args.world_size > 1: |
| dmodel = DistributedDataParallel(model, |
| device_ids=[th.cuda.current_device()], |
| output_device=th.cuda.current_device()) |
| else: |
| dmodel = model |
|
|
| for epoch in range(len(saved.metrics), args.epochs): |
| begin = time.time() |
| model.train() |
| train_loss, model_size = train_model( |
| epoch, train_set, dmodel, criterion, optimizer, augment, |
| quantizer=quantizer, |
| batch_size=args.batch_size, |
| device=device, |
| repeat=args.repeat, |
| seed=args.seed, |
| diffq=args.diffq, |
| workers=args.workers, |
| world_size=args.world_size) |
| model.eval() |
| valid_loss = validate_model( |
| epoch, valid_set, model, criterion, |
| device=device, |
| rank=args.rank, |
| split=args.split_valid, |
| overlap=args.overlap, |
| world_size=args.world_size) |
|
|
| ms = 0 |
| cms = 0 |
| if quantizer and args.rank == 0: |
| ms = quantizer.true_model_size() |
| cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10)) |
|
|
| duration = time.time() - begin |
| if valid_loss < best_loss and ms <= args.ms_target: |
| best_loss = valid_loss |
| saved.best_state = { |
| key: value.to("cpu").clone() |
| for key, value in model.state_dict().items() |
| } |
|
|
| saved.metrics.append({ |
| "train": train_loss, |
| "valid": valid_loss, |
| "best": best_loss, |
| "duration": duration, |
| "model_size": model_size, |
| "true_model_size": ms, |
| "compressed_model_size": cms, |
| }) |
| if args.rank == 0: |
| json.dump(saved.metrics, open(metrics_path, "w")) |
|
|
| saved.last_state = model.state_dict() |
| saved.optimizer = optimizer.state_dict() |
| if args.rank == 0 and not args.test: |
| th.save(saved, checkpoint_tmp) |
| checkpoint_tmp.rename(checkpoint) |
|
|
| print(f"Epoch {epoch:03d}: " |
| f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB " |
| f"cms={cms:.2f}MB " |
| f"duration={human_seconds(duration)}") |
|
|
| if args.world_size > 1: |
| distributed.barrier() |
|
|
| del dmodel |
| model.load_state_dict(saved.best_state) |
| if args.eval_cpu: |
| device = "cpu" |
| model.to(device) |
| model.eval() |
| evaluate(model, args.musdb, eval_folder, |
| is_wav=args.is_wav, |
| rank=args.rank, |
| world_size=args.world_size, |
| device=device, |
| save=args.save, |
| split=args.split_valid, |
| shifts=args.shifts, |
| overlap=args.overlap, |
| workers=args.eval_workers) |
| model.to("cpu") |
| if args.rank == 0: |
| if not (args.test or args.test_pretrained): |
| save_model(model, quantizer, args, args.models / model_name) |
| print("done") |
| done.write_text("done") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|