| |
| |
| |
| |
| """Inspect the structure of .pt checkpoint files.""" |
|
|
| import sys |
|
|
| sys.path.insert(0, 'original') |
|
|
| import torch |
| from pathlib import Path |
|
|
|
|
def inspect_checkpoint(
    path: Path, *, max_keys: int = 10, max_items: int = 5
) -> None:
    """Print the structure of a checkpoint file.

    Loads ``path`` with ``torch.load`` (mapped to the CPU) and prints a
    summary to stdout: a banner with the file name and size in MB, the type
    name of the loaded object, and then one of:

    * for a ``dict``: the first ``max_keys`` keys, followed by a one-line
      description of the first ``max_items`` entries (shape and dtype for
      tensors, the type name for anything else) and, when entries were left
      out, a line saying how many and the total;
    * for any other object exposing ``state_dict``: every state-dict entry
      with its shape and dtype, plus the ``max_seq_len``, ``patch_size`` and
      ``data_seq_len`` attributes when the object has them;
    * otherwise: the unexpected type.

    Args:
        path: Checkpoint file to load.
        max_keys: How many dict keys to list on the ``Keys:`` line.
        max_items: How many dict entries to describe individually.

    Raises:
        Nothing is caught here: errors from ``path.stat()`` or ``torch.load``
        (e.g. a missing or unreadable file) propagate to the caller.
    """
    size_mb = path.stat().st_size / 1024 / 1024
    print(f'\n{"=" * 60}')
    print(f'File: {path.name} ({size_mb:.1f} MB)')
    print('=' * 60)

    # SECURITY: weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code embedded in the file. It is kept because
    # the full-model branch below needs real objects, not just tensors -- only
    # run this on checkpoints from a trusted source.
    data = torch.load(path, map_location='cpu', weights_only=False)
    print(f'Top-level type: {type(data).__name__}')

    if isinstance(data, dict):
        print(f'Keys: {list(data)[:max_keys]}')
        described = list(data.items())[:max_items]
        for key, val in described:
            if isinstance(val, torch.Tensor):
                print(f'  {key}: Tensor {val.shape} {val.dtype}')
            else:
                print(f'  {key}: {type(val).__name__}')
        # Make the truncation visible; otherwise a state dict with hundreds
        # of entries is indistinguishable from one with a handful.
        hidden = len(data) - len(described)
        if hidden > 0:
            print(f'  ... {hidden} more entries not shown ({len(data)} total)')
    elif hasattr(data, 'state_dict'):
        print('This is a full model object.')
        sd = data.state_dict()
        print(f'state_dict keys ({len(sd)}):')
        for k, v in sd.items():
            print(f'  {k}: {v.shape} {v.dtype}')

        # Optional attributes; each is reported only if the object has it.
        for attr in ['max_seq_len', 'patch_size', 'data_seq_len']:
            if hasattr(data, attr):
                print(f'  model.{attr} = {getattr(data, attr)}')
    else:
        print(f'Unexpected type: {type(data)}')
|
|
|
|
def main() -> None:
    """Inspect every ``*_parameters.pt`` file in the current directory.

    Files are processed in sorted name order so the report is deterministic.
    Prints a short notice instead of exiting silently when nothing matches.
    """
    pt_files = sorted(Path('.').glob('*_parameters.pt'))
    if not pt_files:
        # Without this the script produces no output at all, which looks
        # the same as being run from the wrong directory.
        print('No *_parameters.pt files found in the current directory.')
        return
    for pt_file in pt_files:
        inspect_checkpoint(pt_file)


if __name__ == '__main__':
    main()
|
|