| | """Utility for loading the models from HF.""" |
| | from pathlib import Path |
| | import typing as tp |
| |
|
| | from omegaconf import OmegaConf |
| | from huggingface_hub import hf_hub_download |
| | import torch |
| |
|
| | from audiocraft.models import builders, MusicGen |
| |
|
# Short model names (as accepted by `get_pretrained`) mapped to the
# Hugging Face Hub repository id that the checkpoints are downloaded from.
MODEL_CHECKPOINTS_MAP = dict(
    small="facebook/musicgen-small",
    medium="facebook/musicgen-medium",
    large="facebook/musicgen-large",
    melody="facebook/musicgen-melody",
)
| |
|
| |
|
def _get_state_dict(file_or_url: tp.Union[Path, str],
                    filename="state_dict.bin", device='cpu'):
    """Download a checkpoint file from the Hugging Face Hub and load it.

    Despite its name, `file_or_url` is only ever used as the `repo_id`
    handed to `hf_hub_download`; it is never opened as a local file or URL.

    Args:
        file_or_url: Hub repository id; converted to `str` before use.
        filename: Name of the file to fetch from that repository.
        device: Forwarded to `torch.load` as `map_location`.

    Returns:
        The object `torch.load` deserializes from the downloaded file.
    """
    print("loading", file_or_url, filename)
    # A dedicated local replaces the old re-assignment followed by
    # `assert isinstance(..., str)`, which could never fail after `str()`.
    repo_id = str(file_or_url)
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    # NOTE(security): torch.load unpickles the downloaded file, which can
    # execute arbitrary code. Only point this at repositories you trust.
    return torch.load(checkpoint_path, map_location=device)
| |
|
| |
|
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the compression model and restore its weights from a checkpoint.

    The checkpoint is read through `_get_state_dict` from the file
    "compression_state_dict.bin". Its 'xp.cfg' entry supplies the config
    and its 'best_state' entry supplies the weights.

    Args:
        file_or_url: Checkpoint location, forwarded to `_get_state_dict`.
        device: Written (as a string) to the config's `device` field before
            the model is built. It is not forwarded to `_get_state_dict`,
            so the checkpoint is loaded with that helper's default device.

    Returns:
        The built model after `eval()` has been called on it, with the
        config attached as `model.cfg`.
    """
    checkpoint = _get_state_dict(
        file_or_url, filename="compression_state_dict.bin")

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    # Keep the config reachable from the model object itself.
    compression_model.cfg = config
    return compression_model
| |
|
| |
|
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the language model and restore its weights from a checkpoint.

    The checkpoint is read through `_get_state_dict` using its default
    filename. Its 'xp.cfg' entry supplies the config and its 'best_state'
    entry supplies the weights.

    Args:
        file_or_url: Checkpoint location, forwarded to `_get_state_dict`.
        device: Written (as a string) to the config's `device` field. When
            that string is exactly 'cpu', the transformer is configured
            with `memory_efficient=False` and `custom=True` and the dtype
            is 'float32'; for any other device the dtype is 'float16'.

    Returns:
        The built model after `eval()` has been called on it, with the
        config attached as `model.cfg`.
    """
    checkpoint = _get_state_dict(file_or_url)

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    runs_on_cpu = config.device == 'cpu'
    if runs_on_cpu:
        # NOTE(review): presumably the memory-efficient attention path is
        # not usable on CPU, hence the custom implementation; confirm
        # against the builders.
        config.transformer_lm.memory_efficient = False
        config.transformer_lm.custom = True
    config.dtype = 'float32' if runs_on_cpu else 'float16'

    lm = builders.get_lm_model(config)
    lm.load_state_dict(checkpoint['best_state'])
    lm.eval()
    # Keep the config reachable from the model object itself.
    lm.cfg = config
    return lm
| |
|
| |
|
def get_pretrained(name: str = 'small', device='cuda'):
    """Load a pretrained MusicGen model by its short name.

    Args:
        name: A key of `MODEL_CHECKPOINTS_MAP`.
        device: Device passed to both `load_compression_model` and
            `load_lm_model`.

    Returns:
        A `MusicGen` constructed from `name`, the loaded compression model
        and the loaded language model.

    Raises:
        KeyError: If `name` is not a key of `MODEL_CHECKPOINTS_MAP`. The
            message lists the names that are available.
    """
    try:
        model_id = MODEL_CHECKPOINTS_MAP[name]
    except KeyError:
        # Same exception type as the plain dict lookup raised before, but
        # with a message that tells the caller which names are valid.
        known = ", ".join(MODEL_CHECKPOINTS_MAP)
        raise KeyError(
            f"unknown pretrained model {name!r}; expected one of: {known}"
        ) from None

    compression_model = load_compression_model(model_id, device=device)
    lm = load_lm_model(model_id, device=device)
    return MusicGen(name, compression_model, lm)
| |
|