| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
| from pathlib import Path |
|
|
|
|
def get_parser():
    """
    Build the command line parser for the `demucs` train/evaluate entry point.

    The defaults of `--raw` and `--musdb` are read from the `DEMUCS_RAW` and
    `DEMUCS_MUSDB` environment variables when those are defined, and are
    `None` otherwise. All other defaults are hard-coded below.

    Returns:
        argparse.ArgumentParser: the fully populated parser.
    """
    parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
    add = parser.add_argument

    # Environment variables only provide defaults; explicit flags still win.
    raw_env = os.environ.get('DEMUCS_RAW')
    musdb_env = os.environ.get('DEMUCS_MUSDB')
    default_raw = Path(raw_env) if raw_env is not None else None
    default_musdb = Path(musdb_env) if musdb_env is not None else None

    # Dataset locations and data loading.
    add("--raw", type=Path, default=default_raw,
        help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
    # Resets `raw` to None, e.g. to override a DEMUCS_RAW derived default.
    add("--no_raw", action="store_const", const=None, dest="raw")
    add("-m", "--musdb", type=Path, default=default_musdb, help="Path to musdb root")
    add("--is_wav",
        action="store_true",
        help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
    add("--metadata",
        type=Path,
        default=Path("metadata/"),
        help="Folder where metadata information is stored.")
    add("--wav",
        type=Path,
        help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
             "subfolder.")
    add("--samplerate", type=int, default=44100)
    add("--audio_channels", type=int, default=2)
    add("--samples", type=int, default=44100 * 10, help="number of samples to feed in")
    add("--data_stride", type=int, default=44100,
        help="Stride for chunks, shorter = longer epochs")
    add("-w", "--workers", type=int, default=10, help="Loader workers")
    add("--eval_workers", type=int, default=2, help="Final evaluation workers")
    add("-d", "--device", help="Device to train on, default is cuda if available else cpu")
    add("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
    add("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
    add("--test",
        help="Just run the test pipeline + one validation. "
             "This should be a filename relative to the models/ folder.")
    add("--test_pretrained",
        help="Just run the test pipeline + one validation, "
             "on a pretrained model. ")

    # Distributed run settings.
    add("--rank", type=int, default=0)
    add("--world_size", type=int, default=1)
    add("--master")

    # Output folders and checkpoint handling.
    add("--checkpoints", type=Path, default=Path("checkpoints"),
        help="Folder where to store checkpoints etc")
    add("--evals", type=Path, default=Path("evals"),
        help="Folder where to store evals and waveforms")
    add("--save", action="store_true", help="Save estimated for the test set waveforms")
    add("--logs", type=Path, default=Path("logs"), help="Folder where to store logs")
    add("--models", type=Path, default=Path("models"),
        help="Folder where to store trained models")
    add("-R", "--restart", action='store_true',
        help='Restart training, ignoring previous run')

    # Optimization hyper-parameters.
    add("--seed", type=int, default=42)
    add("-e", "--epochs", type=int, default=180, help="Number of epochs")
    add("-r", "--repeat", type=int, default=2, help="Repeat the train set, longer epochs")
    add("-b", "--batch_size", type=int, default=64)
    add("--lr", type=float, default=3e-4)
    add("--mse", action="store_true", help="Use MSE instead of L1")
    add("--init", help="Initialize from a pre-trained model.")

    # Data augmentation.
    add("--no_augment", dest="augment", action="store_false", default=True,
        help="No basic data augmentation.")
    add("--repitch", type=float, default=0.2, help="Probability to do tempo/pitch change")
    # `%%` is required so that argparse renders a literal percent sign.
    add("--max_tempo", type=float, default=12,
        help="Maximum relative tempo change in %% when using repitch.")

    add("--remix_group_size", type=int, default=4,
        help="Shuffle sources using group of this size. Useful to somewhat "
             "replicate multi-gpu training "
             "on less GPUs.")
    add("--shifts", type=int, default=10,
        help="Number of random shifts used for the shift trick.")
    add("--overlap", type=float, default=0.25, help="Overlap when --split_valid is passed.")

    # Model architecture.
    add("--growth", type=float, default=2.,
        help="Number of channels between two layers will increase by this factor")
    add("--depth", type=int, default=6, help="Number of layers for the encoder and decoder")
    add("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
    add("--channels", type=int, default=64,
        help="Number of channels for the first encoder layer")
    add("--kernel_size", type=int, default=8,
        help="Kernel size for the (transposed) convolutions")
    add("--conv_stride", type=int, default=4,
        help="Stride for the (transposed) convolutions")
    add("--context", type=int, default=3,
        help="Context size for the decoder convolutions "
             "before the transposed convolutions")
    add("--rescale", type=float, default=0.1, help="Initial weight rescale reference")
    # The `--no_*` switches below store False into a positively named dest.
    add("--no_resample", dest="resample", action="store_false", default=True,
        help="No Resampling of the input/output x2")
    add("--no_glu", dest="glu", action="store_false", default=True,
        help="Replace all GLUs by ReLUs")
    add("--no_rewrite", dest="rewrite", action="store_false", default=True,
        help="No 1x1 rewrite convolutions")
    add("--normalize", action="store_true")
    add("--no_norm_wav", dest='norm_wav', action="store_false", default=True)

    # Tasnet related options.
    add("--tasnet", action="store_true")
    add("--split_valid", action="store_true",
        help="Predict chunks by chunks for valid and test. Required for tasnet")
    add("--X", type=int, default=8)

    # Actions that inspect or export a model instead of training.
    add("--show", action="store_true", help="Show model architecture, size and exit")
    add("--save_model", action="store_true",
        help="Skip traning, just save final model "
             "for the current checkpoint value.")
    add("--save_state",
        help="Skip training, just save state "
             "for the current checkpoint value. You should "
             "provide a model name as argument.")

    # Quantization.
    add("--q-min-size", type=float, default=1,
        help="Only quantize layers over this size (in MB)")
    add("--qat", type=int, help="If provided, use QAT training with that many bits.")

    add("--diffq", type=float, default=0)
    add("--ms-target", type=float, default=162,
        help="Model size target in MB, when using DiffQ. Best model will be kept "
             "only if it is smaller than this target.")

    return parser
|
|
|
|
def get_name(parser, args):
    """
    Build the experiment name from the parsed arguments.

    Every argument whose value differs from the parser default contributes a
    `name=value` fragment (for `Path` values only the final path component is
    used). Arguments listed in `skipped` below, e.g. `workers`, are left out.
    Fragments are joined with a single space; when nothing differs from the
    defaults the name is `"default"`.

    Args:
        parser: the `argparse.ArgumentParser` that provides the defaults.
        args: the namespace produced by that parser.

    Returns:
        str: the experiment name.
    """
    skipped = {
        "checkpoints",
        "deterministic",
        "eval",
        "evals",
        "eval_cpu",
        "eval_workers",
        "logs",
        "master",
        "rank",
        "restart",
        "save",
        "save_model",
        "save_state",
        "show",
        "workers",
        "world_size",
    }
    # Fragments follow the attribute order of the namespace.
    fragments = [
        f"{key}={val.name if isinstance(val, Path) else val}"
        for key, val in vars(args).items()
        if key not in skipped and val != parser.get_default(key)
    ]
    return " ".join(fragments) if fragments else "default"
|
|