# AnomalyBERT / verify_conversion.py
# Author: Jongsu Liam Kim
# Commit: feat: convert .pt checkpoints to HuggingFace-style state_dict + config
# Revision: 32fa850
# /// script
# requires-python = ">=3.13"
# dependencies = ["torch", "timm", "safetensors"]
# ///
"""Verify converted safetensors match original .pt checkpoints."""
import json
import sys
from pathlib import Path
sys.path.insert(0, 'original')
import torch
from safetensors.torch import load_file
from models.anomaly_transformer import get_anomaly_transformer
def verify(dataset: str) -> bool:
    """Verify a single converted checkpoint against its original.

    Reads three files relative to the current working directory:
    ``{dataset}_parameters.pt`` (the original checkpoint),
    ``{dataset}/config.json`` (model hyper-parameters) and
    ``{dataset}/model.safetensors`` (the converted weights). A model is
    rebuilt from the config, the safetensors weights are loaded into it, and
    its state dict is compared key by key with the original's.

    Args:
        dataset: Dataset name used to derive the three file paths.

    Returns:
        True when every key of the original state dict is present in the
        rebuilt model with an identical tensor and the rebuilt model has no
        additional keys; False otherwise. Each discrepancy is printed.

    Raises:
        Nothing is caught here: errors raised by the loaders (for example a
        missing file, a missing config key, or ``load_state_dict`` rejecting
        the weights) propagate to the caller.
    """
    pt_path = Path(f'{dataset}_parameters.pt')
    config_path = Path(dataset) / 'config.json'
    safetensors_path = Path(dataset) / 'model.safetensors'

    # Load original.
    # SECURITY: weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code. It is kept because the checkpoint is
    # used as a full module object (``.state_dict()`` is called on it), so
    # only run this script on checkpoints from a trusted source.
    original = torch.load(pt_path, map_location='cpu', weights_only=False)
    original_sd = original.state_dict()

    # Load config and rebuild model. JSON is read as UTF-8 explicitly so the
    # result does not depend on the platform's default locale encoding.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    model = get_anomaly_transformer(
        input_d_data=config['input_d_data'],
        output_d_data=config['output_d_data'],
        patch_size=config['patch_size'],
        d_embed=config['d_embed'],
        hidden_dim_rate=config['hidden_dim_rate'],
        max_seq_len=config['max_seq_len'],
        positional_encoding=config['positional_encoding'],
        relative_position_embedding=config['relative_position_embedding'],
        transformer_n_layer=config['transformer_n_layer'],
        transformer_n_head=config['transformer_n_head'],
        dropout=config['dropout'],
    )

    # Load safetensors weights. load_state_dict is called with its default
    # arguments, so key mismatches between file and model raise there.
    saved_sd = load_file(str(safetensors_path))
    model.load_state_dict(saved_sd)
    loaded_sd = model.state_dict()

    # Compare every original tensor with its counterpart in the rebuilt model.
    ok = True
    for key, expected in original_sd.items():
        if key not in loaded_sd:
            print(f' MISSING: {key}')
            ok = False
            continue
        actual = loaded_sd[key]
        if expected.shape != actual.shape:
            # An element-wise diff is undefined (or silently broadcast) for
            # differing shapes, so report the shapes instead of subtracting.
            print(
                f' MISMATCH: {key} '
                f'(shape {tuple(expected.shape)} != {tuple(actual.shape)})'
            )
            ok = False
            continue
        if not torch.equal(expected, actual):
            # Cast to float64 before subtracting: the ``-`` operator is not
            # supported for bool tensors, and the value is only diagnostic.
            delta = expected.to(torch.float64) - actual.to(torch.float64)
            diff = delta.abs().max().item()
            print(f' MISMATCH: {key} (max diff={diff})')
            ok = False

    # Keys the rebuilt model has that the original checkpoint does not.
    extra = set(loaded_sd.keys()) - set(original_sd.keys())
    if extra:
        # Sorted so the report is reproducible between runs.
        print(f' EXTRA keys: {sorted(extra)}')
        ok = False

    status = 'OK' if ok else 'FAIL'
    print(f'{dataset}: {status}')
    return ok
def main() -> None:
    """Run ``verify`` for every known dataset and report the overall result.

    All datasets are checked even if an earlier one fails. A summary line is
    printed, and the process exits with status 1 when any check failed.
    """
    outcomes = {}
    for name in ('MSL', 'SMAP', 'SWaT', 'WADI'):
        outcomes[name] = verify(name)

    passed = all(outcomes.values())
    print(f'\nAll passed: {passed}')
    if passed:
        return
    # Non-zero exit status signals the failure to the calling shell or CI job.
    sys.exit(1)


# Only run the verification when executed as a script, not when imported.
if __name__ == '__main__':
    main()