# /// script
# requires-python = ">=3.13"
# dependencies = ["torch", "timm", "safetensors"]
# ///
"""Convert AnomalyBERT .pt checkpoint files to HuggingFace-style state_dict + config structure.
Each .pt file is converted to a directory containing:
- config.json: model hyperparameters
- model.safetensors: state_dict in safetensors format
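
e.g. foo_parameters.pt -> foo/config.json + foo/model.safetensors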
"""

import json
import sys
from pathlib import Path

# Make the original AnomalyBERT sources importable: torch.load below unpickles
# the full model object and needs its class definitions on sys.path.
sys.path.insert(0, 'original')

import torch
from safetensors.torch import save_file


def extract_config(model: torch.nn.Module) -> dict:
    """Extract model hyperparameters from a loaded AnomalyTransformer."""
    patch_size = model.patch_size
    max_seq_len = model.max_seq_len
    d_embed = model.linear_embedding.weight.shape[0]
    input_d_data = model.linear_embedding.weight.shape[1] // patch_size
    # output_d_data from mlp_layers[-1].out_features / patch_size
    mlp_last = model.mlp_layers[-1]
    output_d_data = mlp_last.out_features // patch_size
    # Count transformer layers
    n_layer = len(model.transformer_encoder.encoder_layers)
    # Get n_head from the first attention layer
    first_attn = model.transformer_encoder.encoder_layers[0].attention_layer
    n_head = first_attn.n_head
    # Get hidden_dim from the first feed-forward layer
    first_ff = model.transformer_encoder.encoder_layers[0].feed_forward_layer
    hidden_dim = first_ff.first_fc_layer.out_features
    hidden_dim_rate = hidden_dim / d_embed
    # Detect positional encoding type
    has_pe = model.transformer_encoder.positional_encoding
    if has_pe:
        pe_layer = model.transformer_encoder.positional_encoding_layer
        pe_type = type(pe_layer).__name__
        if 'Sinusoidal' in pe_type:
            positional_encoding = 'Sinusoidal'
        elif 'Absolute' in pe_type:
            positional_encoding = 'Absolute'
        else:
            positional_encoding = pe_type
    else:
        positional_encoding = None
    # Detect relative position embedding
    relative_position_embedding = first_attn.relative_position_embedding
    # Dropout rate from the first encoder layer
    dropout = model.transformer_encoder.encoder_layers[0].dropout_layer.p
    return {
        'model_type': 'AnomalyBERT',
        'input_d_data': input_d_data,
        'output_d_data': output_d_data,
        'patch_size': patch_size,
        'd_embed': d_embed,
        'hidden_dim_rate': hidden_dim_rate,
        'max_seq_len': max_seq_len,
        'positional_encoding': positional_encoding,
        'relative_position_embedding': relative_position_embedding,
        'transformer_n_layer': n_layer,
        'transformer_n_head': n_head,
        'dropout': dropout,
    }
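

# Note: hidden_dim itself is not stored in config.json; a consumer can recover
# it as int(hidden_dim_rate * d_embed), mirroring the derivation above.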


def convert_checkpoint(pt_path: Path, output_dir: Path) -> None:
    """Convert a single .pt checkpoint to state_dict + config."""
    print(f'Converting {pt_path.name}...')
    # weights_only=False: the checkpoint pickles the full model object, so the
    # original class definitions (put on sys.path above) are required to load it.
    model = torch.load(pt_path, map_location='cpu', weights_only=False)
    config = extract_config(model)
    state_dict = model.state_dict()
    output_dir.mkdir(parents=True, exist_ok=True)
    # Save config
    config_path = output_dir / 'config.json'
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f' config.json: {json.dumps(config, indent=2)}')
    # Save state_dict as safetensors
    safetensors_path = output_dir / 'model.safetensors'
    save_file(state_dict, str(safetensors_path))
    size_mb = safetensors_path.stat().st_size / 1024 / 1024
    print(f' model.safetensors: {size_mb:.1f} MB ({len(state_dict)} tensors)')
    print(f' -> {output_dir}/')
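

# Optional sanity check (a sketch, not part of the original script): reload the
# written safetensors file and compare every tensor bit-for-bit against the
# original checkpoint's state_dict.
def verify_roundtrip(pt_path: Path, output_dir: Path) -> bool:
    """Return True if model.safetensors matches the .pt checkpoint exactly."""
    from safetensors.torch import load_file

    model = torch.load(pt_path, map_location='cpu', weights_only=False)
    original = model.state_dict()
    reloaded = load_file(str(output_dir / 'model.safetensors'))
    if set(original) != set(reloaded):
        return False
    return all(torch.equal(original[key], reloaded[key]) for key in original)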


def main() -> None:
    """Convert all *_parameters.pt files in the current directory."""
    pt_files = sorted(Path('.').glob('*_parameters.pt'))
    if not pt_files:
        print('No *_parameters.pt files found.')
        return
    for pt_path in pt_files:
        # Output directory name = checkpoint stem minus the '_parameters' suffix.
        dataset_name = pt_path.stem.replace('_parameters', '')
        output_dir = Path(dataset_name)
        convert_checkpoint(pt_path, output_dir)
    print(f'\nDone. Converted {len(pt_files)} checkpoints.')


if __name__ == '__main__':
    main()
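

# Example consumer (a sketch): rebuilding a model from a converted directory.
# `get_anomaly_transformer` stands in for the original repo's model factory;
# the exact import path and signature are assumptions, not guaranteed.
#
#   from safetensors.torch import load_file
#   from models.anomaly_transformer import get_anomaly_transformer  # hypothetical
#
#   cfg = json.loads(Path('foo/config.json').read_text())
#   weights = load_file('foo/model.safetensors')
#   model = get_anomaly_transformer(**{k: v for k, v in cfg.items() if k != 'model_type'})
#   model.load_state_dict(weights)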