# /// script
# requires-python = ">=3.13"
# dependencies = ["torch", "timm"]
# ///
"""Inspect the structure of .pt checkpoint files."""

import itertools
import sys

# NOTE(review): presumably makes model classes defined under ./original
# importable, so that pickled full-model checkpoints can be unpickled by
# torch.load(weights_only=False) below — confirm.
sys.path.insert(0, 'original')

import torch
from pathlib import Path

# How many top-level keys / entries to show for dict checkpoints.
_MAX_KEYS_SHOWN = 10
_MAX_ITEMS_SHOWN = 5

# Model attributes printed for full-model checkpoints, when present.
_MODEL_ATTRS = ('max_seq_len', 'patch_size', 'data_seq_len')


def _print_dict_summary(data: dict) -> None:
    """Print a truncated key listing and the first few entries of a dict checkpoint."""
    keys = list(data)
    print(f'Keys: {keys[:_MAX_KEYS_SHOWN]}')
    if len(keys) > _MAX_KEYS_SHOWN:
        # Make truncation visible instead of silently dropping keys.
        print(f'  ... ({len(keys)} keys total)')
    for key, val in itertools.islice(data.items(), _MAX_ITEMS_SHOWN):
        if isinstance(val, torch.Tensor):
            print(f'  {key}: Tensor {val.shape} {val.dtype}')
        else:
            print(f'  {key}: {type(val).__name__}')


def _print_module_summary(model: torch.nn.Module) -> None:
    """Print every state_dict entry and selected attributes of a full model object."""
    print('This is a full model object.')
    sd = model.state_dict()
    print(f'state_dict keys ({len(sd)}):')
    for key, val in sd.items():
        if isinstance(val, torch.Tensor):
            print(f'  {key}: {val.shape} {val.dtype}')
        else:
            # state_dict() can hold non-tensor values (e.g. the `_extra_state`
            # entry from Module.get_extra_state()); these have no .shape.
            print(f'  {key}: {type(val).__name__}')
    # Print model attributes
    for attr in _MODEL_ATTRS:
        if hasattr(model, attr):
            print(f'  model.{attr} = {getattr(model, attr)}')


def inspect_checkpoint(path: Path) -> None:
    """Print the structure of a checkpoint file.

    Handles three layouts: a dict (for example a plain state_dict), a pickled
    object exposing ``state_dict()`` (a full model), and a bare tensor.
    Anything else is reported as unexpected.

    Args:
        path: Path to the ``.pt`` file.

    Warning:
        Loads with ``weights_only=False``, which unpickles arbitrary objects
        and can execute code; only inspect checkpoints from trusted sources.
    """
    print(f'\n{"=" * 60}')
    print(f'File: {path.name} ({path.stat().st_size / 1024 / 1024:.1f} MB)')
    print('=' * 60)

    data = torch.load(path, map_location='cpu', weights_only=False)
    print(f'Top-level type: {type(data).__name__}')

    if isinstance(data, dict):
        _print_dict_summary(data)
    elif hasattr(data, 'state_dict'):
        _print_module_summary(data)
    elif isinstance(data, torch.Tensor):
        print(f'Tensor {data.shape} {data.dtype}')
    else:
        print(f'Unexpected type: {type(data)}')


if __name__ == '__main__':
    checkpoints = sorted(Path('.').glob('*_parameters.pt'))
    if not checkpoints:
        print('No *_parameters.pt files found in the current directory.')
    for pt_file in checkpoints:
        inspect_checkpoint(pt_file)