File size: 1,506 Bytes
# /// script
# requires-python = ">=3.13"
# dependencies = ["torch", "timm"]
# ///
"""Inspect the structure of .pt checkpoint files."""
import sys
sys.path.insert(0, 'original')
import torch
from pathlib import Path
def _describe_value(val: object) -> str:
    """Return a one-line summary of a checkpoint entry."""
    if isinstance(val, torch.Tensor):
        return f'Tensor {val.shape} {val.dtype}'
    if isinstance(val, dict):
        # Nested mappings (e.g. a state_dict stored under a 'model' key) are
        # summarised by size rather than dumped.
        return f'{type(val).__name__} ({len(val)} keys)'
    return type(val).__name__


def inspect_checkpoint(
    path: Path, *, max_keys: int = 10, max_items: int = 5
) -> None:
    """Print the structure of a checkpoint file.

    Loads *path* onto the CPU and prints its size and top-level type, followed
    by a summary that depends on what was loaded:

    * a dict: the total key count, the first *max_keys* key names, and a
      one-line description of the first *max_items* entries;
    * an object with a ``state_dict()`` method (e.g. a pickled model): every
      state_dict entry, plus the ``max_seq_len`` / ``patch_size`` /
      ``data_seq_len`` attributes when the object has them;
    * anything else: its type.

    Args:
        path: Checkpoint file to inspect.
        max_keys: Maximum number of dict key names to list.
        max_items: Maximum number of dict entries to describe.
    """
    print(f'\n{"=" * 60}')
    print(f'File: {path.name} ({path.stat().st_size / 1024 / 1024:.1f} MB)')
    print('=' * 60)
    # SECURITY: weights_only=False unpickles arbitrary objects and can run code
    # embedded in the file. It is required to load full model objects, so only
    # use this script on checkpoints from a trusted source.
    data = torch.load(path, map_location='cpu', weights_only=False)
    print(f'Top-level type: {type(data).__name__}')

    if isinstance(data, dict):
        keys = list(data)
        # Report the total so truncated listings are not mistaken for complete.
        print(f'Keys ({len(keys)}): {keys[:max_keys]}')
        for key in keys[:max_items]:
            print(f' {key}: {_describe_value(data[key])}')
        if len(keys) > max_items:
            print(f' ... {len(keys) - max_items} more')
    elif hasattr(data, 'state_dict'):
        print('This is a full model object.')
        sd = data.state_dict()
        print(f'state_dict keys ({len(sd)}):')
        for k, v in sd.items():
            # Entries are usually tensors, but '_extra_state' entries may hold
            # arbitrary objects, so use the shared formatter instead of v.shape.
            print(f' {k}: {_describe_value(v)}')
        # Print model attributes, when the loaded object defines them.
        for attr in ('max_seq_len', 'patch_size', 'data_seq_len'):
            if hasattr(data, attr):
                print(f' model.{attr} = {getattr(data, attr)}')
    else:
        print(f'Unexpected type: {type(data)}')
if __name__ == '__main__':
    # Inspect the checkpoints named on the command line; with no arguments,
    # fall back to every '*_parameters.pt' file in the current directory.
    pt_files = [Path(arg) for arg in sys.argv[1:]] or sorted(
        Path('.').glob('*_parameters.pt')
    )
    if not pt_files:
        # Previously this case produced no output at all, which looks like a hang/no-op.
        print(
            'No *_parameters.pt files found in the current directory.',
            file=sys.stderr,
        )
    for pt_file in pt_files:
        inspect_checkpoint(pt_file)
|