| |
| |
| |
| |
| |
| """Loading pretrained models. |
| """ |
|
|
| import logging |
| from pathlib import Path |
| import typing as tp |
|
|
| |
|
|
| import logging |
|
|
| from diffq import DiffQuantizer |
| import torch.hub |
|
|
| from .model import Demucs |
| from .tasnet_v2 import ConvTasNet |
| from .utils import set_state |
|
|
| from .hdemucs import HDemucs |
| from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError |
|
|
# Logger for this module.
logger = logging.getLogger(__name__)
# Base URL prepended to every entry of the remote model listing.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
# Folder next to this file that holds the remote listing (`files.txt`).
REMOTE_ROOT = Path(__file__).parent.joinpath('remote')

# Stem names, in order, that the models separate.
SOURCES = "drums bass other vocals".split()
|
|
|
|
def demucs_unittest():
    """Return a tiny 4-channel HDemucs over ``SOURCES``; no weights are loaded.

    NOTE(review): ``demucs_unittest`` is defined a second time further down in
    this module, and that later definition rebinds the module-level name.
    Code that looks the name up at call time therefore gets the later one.
    """
    return HDemucs(channels=4, sources=SOURCES)
|
|
|
|
def add_model_flags(parser):
    """Register the model-selection options on an ``argparse`` parser.

    ``-s/--sig`` and ``-n/--name`` are mutually exclusive; ``-n`` defaults to
    ``mdx_extra_q``. ``--repo`` is an optional folder parsed as a ``Path``.
    """
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument(
        "-s", "--sig",
        help="Locally trained XP signature.",
    )
    exclusive.add_argument(
        "-n", "--name",
        default="mdx_extra_q",
        help="Pretrained model name or signature. Default is mdx_extra_q.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
|
|
|
|
| def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: |
| root: str = '' |
| models: tp.Dict[str, str] = {} |
| for line in remote_file_list.read_text().split('\n'): |
| line = line.strip() |
| if line.startswith('#'): |
| continue |
| elif line.startswith('root:'): |
| root = line.split(':', 1)[1].strip() |
| else: |
| sig = line.split('-', 1)[0] |
| assert sig not in models |
| models[sig] = ROOT_URL + root + line |
| return models |
|
|
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """Load a pretrained model or bag of models by name, in evaluation mode.

    `name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    Args:
        name: bag name or model signature. The special value
            ``'demucs_unittest'`` short-circuits to ``demucs_unittest()``.
        repo: optional local folder holding the models (``str`` or ``Path``).
            When ``None``, the listing at ``REMOTE_ROOT / 'files.txt'`` is used.

    Returns:
        The loaded model, after calling ``.eval()`` on it.

    Raises:
        ModelLoadingError: if `repo` is given but is not an existing directory.
    """
    if name == 'demucs_unittest':
        # NOTE(review): resolves to whichever `demucs_unittest` is bound at
        # module level at call time; this module defines that name twice.
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        repo = Path(repo)
        if not repo.is_dir():
            # `fatal` is not available in this module; raise the repo loading
            # error instead so callers can catch and report it.
            raise ModelLoadingError(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    model = any_repo.get_model(name)
    model.eval()
    return model
|
|
def get_model_from_args(args):
    """Load the model selected by parsed command-line `args`.

    Forwards ``args.name`` and ``args.repo`` (the options registered by
    ``add_model_flags``) to ``get_model``.

    NOTE(review): ``args.sig`` is not consulted here.
    """
    selected = args.name
    local_repo = args.repo
    return get_model(name=selected, repo=local_repo)
|
|
# Legacy (v3.0) checkpoint constants used by `get_url` / `load_pretrained`.
# `logger` and `SOURCES` are already defined near the top of this module, so
# they are not rebound here.

# Base URL of the legacy pretrained checkpoints.
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"

# Model name -> signature. `get_url` appends the first 8 characters to the
# file name; torch.hub verifies it against the file hash (check_hash=True).
PRETRAINED_MODELS = {
    'demucs': 'e07c671f',
    'demucs48_hq': '28a1282c',
    'demucs_extra': '3646af93',
    'demucs_quantized': '07afea75',
    'tasnet': 'beb46fac',
    'tasnet_extra': 'df3777b2',
    'demucs_unittest': '09ebc15f',
}
|
|
|
|
def get_url(name):
    """Return the download URL of the legacy checkpoint registered as `name`.

    The file is ``<name>-<first 8 chars of its signature>.th`` under ``ROOT``.

    Raises:
        KeyError: if `name` is not a key of ``PRETRAINED_MODELS``.
    """
    short_sig = PRETRAINED_MODELS[name][:8]
    return f"{ROOT}{name}-{short_sig}.th"
|
|
def is_pretrained(name):
    """Tell whether `name` is one of the legacy pretrained model names."""
    known_names = PRETRAINED_MODELS
    return name in known_names
|
|
|
|
def load_pretrained(name):
    """Build the legacy model registered as `name`, with pretrained weights.

    Each known name is matched (with ``==``, in order) to a factory call
    carrying that model's options.

    Raises:
        ValueError: if `name` is not a known pretrained model name.
    """
    factories = (
        ("demucs", lambda: demucs(pretrained=True)),
        ("demucs48_hq", lambda: demucs(pretrained=True, hq=True, channels=48)),
        ("demucs_extra", lambda: demucs(pretrained=True, extra=True)),
        ("demucs_quantized", lambda: demucs(pretrained=True, quantized=True)),
        ("demucs_unittest", lambda: demucs_unittest(pretrained=True)),
        ("tasnet", lambda: tasnet(pretrained=True)),
        ("tasnet_extra", lambda: tasnet(pretrained=True, extra=True)),
    )
    for known, factory in factories:
        if name == known:
            return factory()
    raise ValueError(f"Invalid pretrained name {name}")
|
|
|
|
def _load_state(name, model, quantizer=None):
    """Download the legacy checkpoint `name` and load it into `model`.

    The state dict is fetched from ``get_url(name)`` with torch.hub, mapped
    to CPU and hash-checked, then applied with ``set_state``. When a truthy
    `quantizer` is given, ``quantizer.detach()`` is called afterwards.
    """
    state = torch.hub.load_state_dict_from_url(
        get_url(name), map_location='cpu', check_hash=True)
    set_state(model, quantizer, state)
    if not quantizer:
        return
    quantizer.detach()
|
|
|
|
def demucs_unittest(pretrained=True):
    """Build the small 4-channel legacy Demucs over ``SOURCES``.

    When `pretrained` is true, the ``demucs_unittest`` checkpoint is loaded.

    NOTE(review): this definition rebinds the earlier HDemucs-based
    ``demucs_unittest`` in this module, so every call site uses this one.
    """
    model = Demucs(channels=4, sources=SOURCES)
    if not pretrained:
        return model
    _load_state('demucs_unittest', model)
    return model
|
|
|
|
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    """Build a legacy Demucs model, optionally loading pretrained weights.

    Args:
        pretrained: load the checkpoint whose name is derived from the
            options below.
        extra: select the ``_extra`` checkpoint variant.
        quantized: create ``DiffQuantizer(model, group_size=8, min_size=1)``
            and load the ``_quantized`` checkpoint through it.
        hq: select the ``_hq`` checkpoint variant.
        channels: channel count passed to ``Demucs``. A non-64 value is
            appended to the checkpoint name (e.g. ``demucs48``).

    Returns:
        The ``Demucs`` model.

    Raises:
        ValueError: if a variant flag is set without `pretrained`, or if more
            than one of `extra`, `quantized`, `hq` is set.
        KeyError: (from ``get_url``) if the derived checkpoint name is not in
            ``PRETRAINED_MODELS``.
    """
    if not pretrained and (extra or quantized or hq):
        raise ValueError("if extra, quantized or hq is True, pretrained must be True.")
    # Validate before building the network so invalid arguments fail fast.
    # (Without `pretrained`, any set flag was already rejected above.)
    if sum([extra, quantized, hq]) > 1:
        raise ValueError("Only one of extra, quantized, hq, can be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = 'demucs'
        if channels != 64:
            name += str(channels)
        quantizer = None
        if quantized:
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += '_quantized'
        if extra:
            name += '_extra'
        if hq:
            name += '_hq'
        _load_state(name, model, quantizer)
    return model
|
|
|
|
def tasnet(pretrained=True, extra=False):
    """Build a legacy ConvTasNet (``X=10``) over ``SOURCES``.

    When `pretrained` is true, the ``tasnet`` checkpoint is loaded, or
    ``tasnet_extra`` when `extra` is also set.

    Raises:
        ValueError: if `extra` is requested without `pretrained`.
    """
    if not pretrained and extra:
        raise ValueError("if extra is True, pretrained must be True.")
    model = ConvTasNet(X=10, sources=SOURCES)
    if pretrained:
        _load_state('tasnet_extra' if extra else 'tasnet', model)
    return model