---
library_name: pytorch
tags:
- time-series
- anomaly-detection
- transformer
datasets:
- MSL
- SMAP
- SWaT
- WADI
pipeline_tag: other
---
# AnomalyBERT
Pre-trained checkpoints for AnomalyBERT, a self-supervised Transformer model for time series anomaly detection based on a data degradation scheme.
**Paper:** [Self-supervised Transformer for Time Series Anomaly Detection using Data Degradation Scheme](https://arxiv.org/abs/2305.04468)

**Original code:** [Jhryu30/AnomalyBERT](https://github.com/Jhryu30/AnomalyBERT)
## Model Architecture
AnomalyBERT uses a Transformer encoder architecture with:
- Linear patch embedding
- Relative position embedding
- Pre-norm encoder layers (LayerNorm → Attention/FFN)
- MLP head for reconstruction
The model learns normal patterns via masked data degradation during training, and detects anomalies by measuring reconstruction error at inference time.
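The paper defines several synthetic degradation types applied during training. As a minimal illustrative sketch (not the original implementation), a replacement-style degradation swaps a random contiguous window of the series for uniform noise drawn from each feature's observed range, producing a synthetic anomaly and its label:

```python
import torch


def degrade(x: torch.Tensor, min_len: int = 16, max_len: int = 128):
    """Replace a random contiguous window of a (seq_len, n_features) series
    with uniform noise drawn from each feature's observed range; return the
    degraded series and a binary label marking the degraded region."""
    seq_len, n_features = x.shape
    length = int(torch.randint(min_len, max_len + 1, (1,)))
    start = int(torch.randint(0, seq_len - length + 1, (1,)))

    lo = x.min(dim=0).values  # per-feature minimum
    hi = x.max(dim=0).values  # per-feature maximum

    degraded = x.clone()
    degraded[start:start + length] = lo + (hi - lo) * torch.rand(length, n_features)

    label = torch.zeros(seq_len)
    label[start:start + length] = 1.0  # 1 inside the degraded window
    return degraded, label


series = torch.randn(1024, 55)  # e.g. an MSL-sized window
degraded, label = degrade(series)
```

During training the model is asked to flag the degraded timesteps, which is what lets it learn normal temporal structure without labeled anomalies.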
## Checkpoints
Each dataset directory contains `config.json` (hyperparameters) and `model.safetensors` (weights).
| Dataset | input_d_data | patch_size | d_embed | n_layer | n_head | max_seq_len | Parameters |
|---|---|---|---|---|---|---|---|
| MSL | 55 | 2 | 512 | 6 | 8 | 512 | ~19M |
| SMAP | 25 | 4 | 512 | 6 | 8 | 512 | ~19M |
| SWaT | 50 | 14 | 512 | 6 | 8 | 512 | ~19M |
| WADI | 122 | 8 | 512 | 6 | 8 | 512 | ~19M |
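Each checkpoint consumes fixed-length windows of `patch_size * max_seq_len` timesteps (e.g. 2 × 512 = 1024 for MSL). As a sketch of preparing such inputs, assuming a simple non-overlapping stride (the original pipeline may window differently):

```python
import torch


def make_windows(series: torch.Tensor, patch_size: int, max_seq_len: int) -> torch.Tensor:
    """Slice a (T, n_features) series into (n_windows, window_len, n_features)
    non-overlapping windows, dropping the incomplete tail."""
    window_len = patch_size * max_seq_len
    n_windows = series.shape[0] // window_len
    return series[: n_windows * window_len].reshape(n_windows, window_len, series.shape[1])


# MSL geometry: patch_size=2, max_seq_len=512 -> 1024-timestep windows of 55 features
series = torch.randn(5000, 55)
windows = make_windows(series, patch_size=2, max_seq_len=512)
# windows.shape == (4, 1024, 55)
```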
## Usage
```python
import json
from pathlib import Path

import torch
from safetensors.torch import load_file

from models.anomaly_transformer import get_anomaly_transformer


def load_model(dataset_dir: str) -> torch.nn.Module:
    """Load an AnomalyBERT model from config + safetensors."""
    dataset_path = Path(dataset_dir)
    with open(dataset_path / 'config.json') as f:
        config = json.load(f)

    model = get_anomaly_transformer(
        input_d_data=config['input_d_data'],
        output_d_data=config['output_d_data'],
        patch_size=config['patch_size'],
        d_embed=config['d_embed'],
        hidden_dim_rate=config['hidden_dim_rate'],
        max_seq_len=config['max_seq_len'],
        positional_encoding=config['positional_encoding'],
        relative_position_embedding=config['relative_position_embedding'],
        transformer_n_layer=config['transformer_n_layer'],
        transformer_n_head=config['transformer_n_head'],
        dropout=config['dropout'],
    )

    state_dict = load_file(str(dataset_path / 'model.safetensors'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Example: load the MSL model
model = load_model('MSL')

# Inference
# x shape: (batch, patch_size * max_seq_len, input_d_data)
x = torch.randn(1, 1024, 55)
with torch.no_grad():
    output = model(x)
# output shape: (batch, patch_size * max_seq_len, output_d_data)
```
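If a checkpoint's `output_d_data` is 1 (a per-timestep anomaly score, as in the original AnomalyBERT setup; check each `config.json` to confirm), scores can be post-processed by squashing the output and thresholding. This is a sketch on a dummy tensor standing in for `output` above, not the exact upstream post-processing:

```python
import torch

# Dummy model output, assuming output_d_data == 1
output = torch.randn(1, 1024, 1)

scores = torch.sigmoid(output).squeeze(-1)  # (batch, seq_len), scores in [0, 1]
threshold = 0.9                             # tune on held-out data
anomalies = scores > threshold              # boolean mask per timestep
```

The threshold is dataset-dependent; in practice it is chosen on a validation split rather than fixed in advance.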
## File Structure
```
├── MSL/
│   ├── config.json
│   └── model.safetensors
├── SMAP/
│   ├── config.json
│   └── model.safetensors
├── SWaT/
│   ├── config.json
│   └── model.safetensors
├── WADI/
│   ├── config.json
│   └── model.safetensors
├── convert_to_hf.py       # Conversion script (.pt -> safetensors)
├── inspect_pt.py          # Checkpoint inspection script
└── verify_conversion.py   # Conversion verification script
```
## Citation
```bibtex
@article{jeong2023anomalybert,
  title={AnomalyBERT: Self-Supervised Transformer for Time Series Anomaly Detection using Data Degradation Scheme},
  author={Jeong, Yungi and Yang, Eunseok and Ryu, Jung Hyun and Park, Imseong and Kang, Myungjoo},
  journal={arXiv preprint arXiv:2305.04468},
  year={2023}
}
```