| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Utility to export a training checkpoint to a lightweight release checkpoint. |
| | """ |
| |
|
| | from pathlib import Path |
| | import typing as tp |
| |
|
| | from omegaconf import OmegaConf |
| | import torch |
| |
|
| | from audiocraft import __version__ |
| |
|
| |
|
| | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): |
| | """Export only the best state from the given EnCodec checkpoint. This |
| | should be used if you trained your own EnCodec model. |
| | """ |
| | pkg = torch.load(checkpoint_path, 'cpu') |
| | new_pkg = { |
| | 'best_state': pkg['best_state']['model'], |
| | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), |
| | 'version': __version__, |
| | 'exported': True, |
| | } |
| | Path(out_file).parent.mkdir(exist_ok=True, parents=True) |
| | torch.save(new_pkg, out_file) |
| | return out_file |
| |
|
| |
|
| | def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): |
| | """Export a compression model (potentially EnCodec) from a pretrained model. |
| | This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. |
| | Do not include the //pretrained/ prefix. For instance if you trained a model |
| | with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. |
| | |
| | In that case, this will not actually include a copy of the model, simply the reference |
| | to the model used. |
| | """ |
| | if Path(pretrained_encodec).exists(): |
| | pkg = torch.load(pretrained_encodec) |
| | assert 'best_state' in pkg |
| | assert 'xp.cfg' in pkg |
| | assert 'version' in pkg |
| | assert 'exported' in pkg |
| | else: |
| | pkg = { |
| | 'pretrained': pretrained_encodec, |
| | 'exported': True, |
| | 'version': __version__, |
| | } |
| | Path(out_file).parent.mkdir(exist_ok=True, parents=True) |
| | torch.save(pkg, out_file) |
| |
|
| |
|
| | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): |
| | """Export only the best state from the given MusicGen or AudioGen checkpoint. |
| | """ |
| | pkg = torch.load(checkpoint_path, 'cpu') |
| | if pkg['fsdp_best_state']: |
| | best_state = pkg['fsdp_best_state']['model'] |
| | else: |
| | assert pkg['best_state'] |
| | best_state = pkg['best_state']['model'] |
| | new_pkg = { |
| | 'best_state': best_state, |
| | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), |
| | 'version': __version__, |
| | 'exported': True, |
| | } |
| |
|
| | Path(out_file).parent.mkdir(exist_ok=True, parents=True) |
| | torch.save(new_pkg, out_file) |
| | return out_file |
| |
|