| | """ |
| | Utility to export a training checkpoint to a lightweight release checkpoint. |
| | """ |
| |
|
| | from pathlib import Path |
| | import typing as tp |
| |
|
| | from omegaconf import OmegaConf, DictConfig |
| | import torch |
| |
|
| |
|
| | def _clean_lm_cfg(cfg: DictConfig): |
| | OmegaConf.set_struct(cfg, False) |
| | |
| | |
| | cfg['transformer_lm']['card'] = 2048 |
| | cfg['transformer_lm']['n_q'] = 4 |
| | |
| | bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', |
| | 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] |
| | for name in bad_params: |
| | del cfg['transformer_lm'][name] |
| | OmegaConf.set_struct(cfg, True) |
| | return cfg |
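# A minimal sketch of what `_clean_lm_cfg` does, shown on a toy config (a real
# `xp.cfg` carries many more fields than this):
#
#     cfg = OmegaConf.create({'transformer_lm': {
#         'card': 1024, 'n_q': 8, 'layer_drop': 0.1,
#         'spectral_norm_attn_iters': 3, 'spectral_norm_ff_iters': 3,
#         'residual_balancer_attn': 0.1, 'residual_balancer_ff': 0.1}})
#     _clean_lm_cfg(cfg)
#     # -> {'transformer_lm': {'card': 2048, 'n_q': 4}}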
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export an EnCodec training checkpoint to a release checkpoint."""
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, map_location='cpu')
    new_pkg = {
        # Keep only the EMA weights and the experiment config.
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
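# Hypothetical usage (the paths are illustrative; the checkpoint must sit in a
# Dora XP folder named by its 8-character signature):
#
#     export_encodec('/xps/1a2b3c4d/checkpoint.th', 'releases/')
#     # -> releases/1a2b3c4d.th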
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a language model training checkpoint to a release checkpoint."""
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, map_location='cpu')
    new_pkg = {
        # LM checkpoints keep their best weights under 'fsdp_best_state'
        # (the model is trained with FSDP); the config is cleaned before export.
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
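# A minimal CLI wrapper, sketched here as a convenience; it is an assumption on
# top of the original module, and the positional argument names are hypothetical.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('kind', choices=['encodec', 'lm'],
                        help='Which exporter to run.')
    parser.add_argument('checkpoint_path', type=Path,
                        help='Training checkpoint inside a Dora XP folder.')
    parser.add_argument('out_folder', type=Path,
                        help='Destination folder for the release checkpoint.')
    args = parser.parse_args()
    export = export_encodec if args.kind == 'encodec' else export_lm
    print(export(args.checkpoint_path, args.out_folder))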