# AnomalyBERT / inspect_pt.py
# Author: Jongsu Liam Kim
# feat: convert .pt checkpoints to HuggingFace-style state_dict + config
# Commit: 32fa850
# /// script
# requires-python = ">=3.13"
# dependencies = ["torch", "timm"]
# ///
"""Inspect the structure of .pt checkpoint files."""
import sys
sys.path.insert(0, 'original')
import torch
from pathlib import Path
def inspect_checkpoint(path: Path) -> None:
    """Dump a human-readable summary of one ``.pt`` checkpoint.

    Three shapes are handled: a plain ``dict`` (state_dict-style file),
    a pickled full model object (anything exposing ``state_dict``), and
    everything else, which is reported as unexpected. Output goes to
    stdout; nothing is returned.
    """
    banner = '=' * 60
    size_mb = path.stat().st_size / 1024 / 1024
    print(f'\n{banner}')
    print(f'File: {path.name} ({size_mb:.1f} MB)')
    print(banner)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects and can
    # execute code; it is required here to load full model objects, so only
    # point this script at checkpoints you trust.
    data = torch.load(path, map_location='cpu', weights_only=False)
    print(f'Top-level type: {type(data).__name__}')
    if isinstance(data, dict):
        # Show at most the first 10 key names, and detail the first 5 entries.
        print(f'Keys: {list(data.keys())[:10]}')
        for key, val in list(data.items())[:5]:
            if isinstance(val, torch.Tensor):
                print(f' {key}: Tensor {val.shape} {val.dtype}')
            else:
                print(f' {key}: {type(val).__name__}')
        return
    if hasattr(data, 'state_dict'):
        # A whole pickled model: list every parameter, then a few
        # configuration attributes if the model carries them.
        print('This is a full model object.')
        sd = data.state_dict()
        print(f'state_dict keys ({len(sd)}):')
        for k, v in sd.items():
            print(f' {k}: {v.shape} {v.dtype}')
        for attr in ['max_seq_len', 'patch_size', 'data_seq_len']:
            if hasattr(data, attr):
                print(f' model.{attr} = {getattr(data, attr)}')
        return
    print(f'Unexpected type: {type(data)}')
if __name__ == '__main__':
    # Summarise every '*_parameters.pt' checkpoint found in the current
    # working directory, in deterministic (sorted) order.
    checkpoints = sorted(Path('.').glob('*_parameters.pt'))
    for ckpt in checkpoints:
        inspect_checkpoint(ckpt)