# /// script
# requires-python = ">=3.13"
# dependencies = ["torch", "timm", "safetensors"]
# ///
"""Verify converted safetensors match original .pt checkpoints."""

import json
import sys
from pathlib import Path

# The original model code lives in ./original; it must be importable before
# the `models` import below (hence the imports after this statement).
sys.path.insert(0, 'original')

import torch  # noqa: E402
from safetensors.torch import load_file  # noqa: E402

from models.anomaly_transformer import get_anomaly_transformer  # noqa: E402

# Datasets checked by default when `main` is called without arguments.
DATASETS: tuple[str, ...] = ('MSL', 'SMAP', 'SWaT', 'WADI')

# config.json keys forwarded verbatim as keyword arguments to
# `get_anomaly_transformer`. A missing key raises KeyError.
_MODEL_KWARGS: tuple[str, ...] = (
    'input_d_data',
    'output_d_data',
    'patch_size',
    'd_embed',
    'hidden_dim_rate',
    'max_seq_len',
    'positional_encoding',
    'relative_position_embedding',
    'transformer_n_layer',
    'transformer_n_head',
    'dropout',
)


def _load_original_state_dict(pt_path: Path) -> dict[str, torch.Tensor]:
    """Unpickle the original checkpoint onto the CPU and return its state dict."""
    # weights_only=False is needed because the .pt file holds a whole model
    # object (we call .state_dict() on it), not a bare tensor mapping.
    # Unpickling can execute arbitrary code: only use trusted checkpoints.
    original = torch.load(pt_path, map_location='cpu', weights_only=False)
    return original.state_dict()


def _build_model(config_path: Path) -> torch.nn.Module:
    """Rebuild an uninitialised model from the converted ``config.json``."""
    config = json.loads(config_path.read_text(encoding='utf-8'))
    return get_anomaly_transformer(**{name: config[name] for name in _MODEL_KWARGS})


def _describe_mismatch(expected: torch.Tensor, actual: torch.Tensor) -> str:
    """Return a short description of how two unequal tensors differ.

    Never raises for shape or dtype differences, unlike a plain ``a - b``.
    """
    if expected.shape != actual.shape:
        return f'shape {tuple(expected.shape)} != {tuple(actual.shape)}'
    if expected.dtype != actual.dtype:
        return f'dtype {expected.dtype} != {actual.dtype}'
    if expected.numel() == 0:
        # Same shape and dtype but torch.equal still said no: nothing to diff.
        return 'tensors differ'
    if expected.is_complex() or actual.is_complex():
        delta = (expected - actual).abs()
    else:
        # Widen first: `-` is unsupported for bool tensors, and unsigned
        # integer subtraction would wrap around instead of going negative.
        delta = (expected.double() - actual.double()).abs()
    return f'max diff={delta.max().item()}'


def _compare_state_dicts(
    expected: dict[str, torch.Tensor],
    actual: dict[str, torch.Tensor],
) -> bool:
    """Print every key/value discrepancy and return True only if none exist."""
    ok = True
    for key, expected_tensor in expected.items():
        actual_tensor = actual.get(key)
        if actual_tensor is None:
            print(f'  MISSING: {key}')
            ok = False
        elif not torch.equal(expected_tensor, actual_tensor):
            detail = _describe_mismatch(expected_tensor, actual_tensor)
            print(f'  MISMATCH: {key} ({detail})')
            ok = False

    extra = sorted(set(actual) - set(expected))  # sorted for stable output
    if extra:
        print(f'  EXTRA keys: {extra}')
        ok = False
    return ok


def verify(dataset: str) -> bool:
    """Verify a single converted checkpoint.

    Loads ``{dataset}_parameters.pt`` (the original pickled model). Rebuilds
    the model from ``{dataset}/config.json`` and loads
    ``{dataset}/model.safetensors`` into it. Then compares the rebuilt
    model's state dict against the original, key by key, with ``torch.equal``.

    Args:
        dataset: Dataset name; also the directory holding the converted files.

    Returns:
        True if both state dicts have the same keys and every tensor is equal.
        Otherwise False. Each discrepancy is printed as it is found.

    Raises:
        Errors from loading or building are propagated. For example, a missing
        file raises FileNotFoundError. A safetensors file whose keys or shapes
        do not fit the rebuilt model makes ``load_state_dict`` (strict by
        default) raise RuntimeError.
    """
    dataset_dir = Path(dataset)
    pt_path = Path(f'{dataset}_parameters.pt')
    config_path = dataset_dir / 'config.json'
    safetensors_path = dataset_dir / 'model.safetensors'

    original_sd = _load_original_state_dict(pt_path)

    model = _build_model(config_path)
    # Loading through the rebuilt model checks config.json and the weights
    # together. Strict loading rejects missing/unexpected keys up front.
    model.load_state_dict(load_file(str(safetensors_path)))

    ok = _compare_state_dicts(original_sd, model.state_dict())
    status = 'OK' if ok else 'FAIL'
    print(f'{dataset}: {status}')
    return ok


def main(datasets: tuple[str, ...] = DATASETS) -> None:
    """Verify all converted checkpoints; exit with status 1 if any fail.

    Args:
        datasets: Dataset names to check. Defaults to all known datasets.
    """
    results: dict[str, bool] = {}
    for dataset in datasets:
        try:
            results[dataset] = verify(dataset)
        except Exception as exc:  # top-level boundary: report, keep checking
            print(f'{dataset}: FAIL ({type(exc).__name__}: {exc})')
            results[dataset] = False

    all_ok = all(results.values())
    print(f'\nAll passed: {all_ok}')
    if not all_ok:
        sys.exit(1)


if __name__ == '__main__':
    main()