| """ |
| Utility functions to load from the checkpoints. |
| Each checkpoint is a torch.saved dict with the following keys: |
| - 'xp.cfg': the hydra config as dumped during training. This should be used |
| to rebuild the object using the audiocraft.models.builders functions, |
| - 'model_best_state': a readily loadable best state for the model, including |
| the conditioner. The model obtained from `xp.cfg` should be compatible |
| with this state dict. In the case of a LM, the encodec model would not be |
| bundled along but instead provided separately. |
| |
| Those functions also support loading from a remote location with the Torch Hub API. |
| They also support overriding some parameters, in particular the device and dtype |
| of the returned model. |
| """ |
from pathlib import Path
from huggingface_hub import hf_hub_download
import typing as tp
import os

from omegaconf import OmegaConf
import torch

from . import builders


HF_MODEL_CHECKPOINTS_MAP = {
    "small": "facebook/musicgen-small",
    "medium": "facebook/musicgen-medium",
    "large": "facebook/musicgen-large",
    "melody": "facebook/musicgen-melody",
}


def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    # Return the state dict from a local file or directory, a URL, or one of the
    # known Hugging Face checkpoint names.
    file_or_url_or_id = str(file_or_url_or_id)
    assert isinstance(file_or_url_or_id, str)

    if os.path.isfile(file_or_url_or_id):
        return torch.load(file_or_url_or_id, map_location=device)

    elif os.path.isdir(file_or_url_or_id):
        file = f"{file_or_url_or_id}/{filename}"
        return torch.load(file, map_location=device)

    elif file_or_url_or_id.startswith('https://'):
        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)

    elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
        assert filename is not None, "filename needs to be defined if using HF checkpoints"
        repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
        file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        return torch.load(file, map_location=device)

    else:
        raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
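# Illustrative calls covering the branches above (paths and URL are hypothetical):
#     _get_state_dict('/checkpoints/my_xp/state_dict.bin')              # local file
#     _get_state_dict('/checkpoints/my_xp', filename='state_dict.bin')  # local directory
#     _get_state_dict('https://example.com/state_dict.bin')             # remote URL via torch.hub
#     _get_state_dict('small', filename='state_dict.bin')               # Hugging Face checkpoint name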


def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    # Load the EnCodec compression model bundled in 'compression_state_dict.bin'.
    pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    model = builders.get_compression_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    return model
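# Example (assumes the 'small' Hugging Face checkpoints are reachable; any local
# path or URL accepted by _get_state_dict works too):
#     compression_model = load_compression_model('small', device='cpu')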


def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    # Load the language model bundled in 'state_dict.bin'.
    pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    if cfg.device == 'cpu':
        # On CPU, disable the memory-efficient attention path, use the custom
        # transformer implementation, and run in full precision.
        cfg.transformer_lm.memory_efficient = False
        cfg.transformer_lm.custom = True
        cfg.dtype = 'float32'
    else:
        # On GPU, run the transformer in half precision.
        cfg.dtype = 'float16'
    model = builders.get_lm_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    model.cfg = cfg
    return model
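# Minimal end-to-end sketch (assumes the 'melody' checkpoints on the Hugging Face
# Hub and a CUDA device; adjust the name and device as needed):
#     compression_model = load_compression_model('melody', device='cuda')
#     lm = load_lm_model('melody', device='cuda')
# Both models are returned in eval mode; the LM additionally exposes the resolved
# config on `lm.cfg`, including the dtype/device overrides applied above.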