| |
| |
| |
| |
| |
|
|
| """ |
| Utility to export a training checkpoint to a lightweight release checkpoint. |
| """ |
|
|
| from pathlib import Path |
| import typing as tp |
|
|
| from omegaconf import OmegaConf |
| import torch |
|
|
| from audiocraft import __version__ |
|
|
|
|
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Write a release checkpoint holding only the best EnCodec weights.

    Use this for an EnCodec model you trained yourself. The training checkpoint
    is loaded onto the CPU. Only the ``'model'`` entry of its ``best_state`` is
    kept, together with a YAML dump of its experiment config, the library
    version, and an ``exported`` marker.

    Args:
        checkpoint_path: Training checkpoint to read.
        out_file: Where to write the exported checkpoint. Missing parent
            directories are created.

    Returns:
        ``out_file``, exactly as passed in.
    """
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which may refuse to unpickle a non-tensor config
    # object stored under 'xp.cfg' -- confirm against the installed version.
    checkpoint = torch.load(checkpoint_path, 'cpu')
    weights = checkpoint['best_state']['model']
    config_yaml = OmegaConf.to_yaml(checkpoint['xp.cfg'])

    exported = {'best_state': weights, 'xp.cfg': config_yaml}
    exported.update(version=__version__, exported=True)

    destination_dir = Path(out_file).parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    torch.save(exported, out_file)
    return out_file
|
|
|
|
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.

    This is required for packaging the audio tokenizer alongside a MusicGen or
    AudioGen model. ``pretrained_encodec`` is interpreted in one of two ways:

    * If it is the path of an existing file, that file must already be an
      exported checkpoint. It must contain the ``best_state``, ``xp.cfg``,
      ``version`` and ``exported`` keys, as written by ``export_encodec``.
      It is re-saved unchanged to ``out_file``.
    * Otherwise it is taken as the name of a pretrained model. Do not include
      the ``//pretrained/`` prefix. For instance, if you trained a model with
      ``facebook/encodec_32khz``, just pass that name; same for ``dac_44khz``.
      In that case the output does not contain a copy of the model, only a
      reference to the model used.

    Args:
        pretrained_encodec: Path to an exported checkpoint, or a pretrained
            model name.
        out_file: Where to write the result. Missing parent directories are
            created.

    Returns:
        ``out_file``, exactly as passed in.

    Raises:
        ValueError: If ``pretrained_encodec`` is an existing file that lacks
            any of the required keys.
    """
    if Path(pretrained_encodec).exists():
        # Map to CPU, like the other exporters, so a checkpoint saved with GPU
        # tensors can still be loaded on a machine without that GPU.
        pkg = torch.load(pretrained_encodec, 'cpu')
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and this also reports which keys are missing.
        required_keys = ('best_state', 'xp.cfg', 'version', 'exported')
        missing = [key for key in required_keys if key not in pkg]
        if missing:
            raise ValueError(
                f"{pretrained_encodec} is not an exported checkpoint; "
                f"missing keys: {missing}")
    else:
        # Only store a reference to the pretrained model, not its weights.
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
    return out_file
|
|
|
|
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    The weights written are the ``'model'`` entry of
    ``pkg['fsdp_best_state']`` when that entry is present and non-empty.
    Otherwise they come from ``pkg['best_state']``. The experiment config is
    stored as a YAML string, alongside the library version and an ``exported``
    marker.

    Args:
        checkpoint_path: Training checkpoint to read (loaded onto the CPU).
        out_file: Where to write the exported checkpoint. Missing parent
            directories are created.

    Returns:
        ``out_file``, exactly as passed in.

    Raises:
        ValueError: If the checkpoint has neither a non-empty
            ``fsdp_best_state`` nor a non-empty ``best_state``.
    """
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which may refuse to unpickle a non-tensor config
    # object stored under 'xp.cfg' -- confirm against the installed version.
    pkg = torch.load(checkpoint_path, 'cpu')
    # `.get` so checkpoints that never recorded an FSDP best state fall back
    # to `best_state` instead of failing with a KeyError.
    fsdp_best_state = pkg.get('fsdp_best_state')
    if fsdp_best_state:
        best_state = fsdp_best_state['model']
    else:
        plain_best_state = pkg.get('best_state')
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if not plain_best_state:
            raise ValueError(
                f"{checkpoint_path} contains no best state to export "
                "(neither 'fsdp_best_state' nor 'best_state' is set)")
        best_state = plain_best_state['model']
    new_pkg = {
        # Export the selected best weights, not the latest training weights.
        'best_state': best_state,
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
        'version': __version__,
        'exported': True,
    }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file
|
|