DistilGPT-2 ONNX (with attention outputs)
A custom ONNX export of distilbert/distilgpt2
that includes attention weights as graph outputs, quantized to int8 for efficient
in-browser inference via Transformers.js.
Used by forwardpass.dev — an interactive visual guide to LLMs — to power live next-token logits, sampling, generation, and real attention heatmap visualizations computed entirely in the browser.
Why this export
Standard ONNX exports of GPT-2 strip out attention weights (the attentions tuple
returned by model(..., output_attentions=True)) because they aren't needed for
inference. This export preserves them as named graph outputs so the model can be used
for educational visualizations of how attention works in causal language models.
Model details
- Base model: distilbert/distilgpt2 (82M parameters)
- Architecture: Causal LM (decoder-only, GPT-2 style)
- Layers: 6 transformer blocks
- Heads: 12 attention heads per layer
- Hidden size: 768
- Vocabulary: 50,257 tokens (BPE)
- Quantization: int8 dynamic
- File size: ~84 MB
- Format: Single-file ONNX (no external data)
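The 82M parameter figure can be roughly reproduced from the dimensions listed above. The following is a back-of-the-envelope sketch, not part of the export script. It assumes the standard GPT-2 layout, which this card does not state: 1024 learned position embeddings, a 4x MLP expansion, biases on every linear layer, and an output head tied to the token embedding.

```python
# Values taken from the model details above
vocab_size = 50257
hidden = 768
n_layers = 6

# Assumptions (standard GPT-2 configuration, not listed in this card)
n_positions = 1024
mlp_hidden = 4 * hidden


def linear(n_in, n_out):
    """Parameter count of a linear layer with bias."""
    return n_in * n_out + n_out


layer_norm = 2 * hidden  # scale + shift

per_block = (
    2 * layer_norm                 # pre-attention and pre-MLP layer norms
    + linear(hidden, 3 * hidden)   # fused Q, K, V projection
    + linear(hidden, hidden)       # attention output projection
    + linear(hidden, mlp_hidden)   # MLP up-projection
    + linear(mlp_hidden, hidden)   # MLP down-projection
)

embeddings = vocab_size * hidden + n_positions * hidden

# The output head is assumed to share weights with the token embedding,
# so it is not counted a second time.
total = embeddings + n_layers * per_block + layer_norm  # + final layer norm

print(f"{total:,}")  # 81,912,576
```

Nearly half of the total sits in the token embedding matrix, which is why a 6-layer model still lands at roughly 82M parameters.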
ONNX inputs / outputs
| Name | Shape | Description |
|---|---|---|
| input_ids | [batch, seq] | Token IDs |
| attention_mask | [batch, seq] | All ones for non-padded inputs |
| position_ids | [batch, seq] | [0, 1, ..., seq-1] |
| logits | [batch, seq, 50257] | Next-token logits |
| attentions.0 ... attentions.5 | [batch, 12, seq, seq] | Attention weights per layer |
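To make the shape and meaning of the attention outputs concrete, here is a minimal NumPy sketch of causal softmax attention on random data. It is illustrative only and does not run the model. The helper names `causal_attention` and `head_averaged_heatmap` are hypothetical, and the head dimension of 64 is assumed from 768 / 12. The properties it demonstrates (each row sums to 1, and no weight is placed on future tokens) are what one would expect from each `attentions.N` output.

```python
import numpy as np


def causal_attention(q, k):
    """softmax(q @ k^T / sqrt(d)) with a causal mask.

    q, k: arrays of shape [heads, seq, d]. Returns [heads, seq, seq].
    """
    d = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    seq = scores.shape[-1]
    # True above the diagonal: positions a token is not allowed to attend to
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


def head_averaged_heatmap(attn):
    """Collapse [batch, heads, seq, seq] to a [seq, seq] heatmap for batch item 0."""
    return attn[0].mean(axis=0)


rng = np.random.default_rng(0)
heads, seq, d = 12, 5, 64  # d = 768 / 12 is an assumption
q = rng.normal(size=(heads, seq, d))
k = rng.normal(size=(heads, seq, d))

attn = causal_attention(q, k)[None]    # add a batch axis: [1, 12, seq, seq]
heatmap = head_averaged_heatmap(attn)  # [seq, seq]

print(attn.shape)
print(np.allclose(attn.sum(axis=-1), 1.0))      # rows are probability distributions
print(np.allclose(np.triu(heatmap, k=1), 0.0))  # nothing attends to the future
```

With real model outputs, row `i` of the heatmap shows how much token `i` attends to each earlier token. Averaging over heads is just one possible design choice; plotting heads individually preserves more detail.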
Usage
Transformers.js
import { AutoTokenizer, AutoModel, Tensor } from "@huggingface/transformers";
const tokenizer = await AutoTokenizer.from_pretrained("dbernsohn/distilgpt2-onnx");
const model = await AutoModel.from_pretrained("dbernsohn/distilgpt2-onnx", {
  // "fp32" adds no filename suffix, so this loads onnx/model.onnx (the int8 file)
  dtype: "fp32",
  model_file_name: "model",
});
const inputs = tokenizer("The capital of France is");
const seqLen = inputs.input_ids.dims[1];
const result = await model.forward({
  input_ids: inputs.input_ids,
  attention_mask: new Tensor(
    "int64",
    new BigInt64Array(seqLen).fill(1n),
    [1, seqLen]
  ),
  position_ids: new Tensor(
    "int64",
    BigInt64Array.from({ length: seqLen }, (_, i) => BigInt(i)),
    [1, seqLen]
  ),
});
// result.logits: [1, seq, 50257]
// result["attentions.0"] ... result["attentions.5"]: [1, 12, seq, seq]
Python (onnxruntime)
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("dbernsohn/distilgpt2-onnx")
session = ort.InferenceSession("onnx/model.onnx")
ids = tokenizer("The capital of France is", return_tensors="np")
seq_len = ids["input_ids"].shape[1]
outputs = session.run(None, {
    "input_ids": ids["input_ids"].astype(np.int64),
    "attention_mask": np.ones((1, seq_len), dtype=np.int64),
    "position_ids": np.arange(seq_len, dtype=np.int64).reshape(1, -1),
})
logits = outputs[0] # [1, seq, 50257]
attentions = outputs[1:] # 6 tensors of [1, 12, seq, seq]
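The logits can be turned into a next-token choice in a few lines. The sketch below shows one common approach, temperature scaling followed by top-k sampling. It is not necessarily what forwardpass.dev does. The helper names `next_token_probs` and `sample_next_token` are hypothetical, and random logits stand in for `outputs[0]` so that the snippet runs without the model.

```python
import numpy as np


def next_token_probs(logits, temperature=1.0, top_k=None):
    """Next-token probabilities for batch item 0 from logits of shape [batch, seq, vocab]."""
    last = logits[0, -1].astype(np.float64) / temperature
    if top_k is not None:
        # Keep the top_k largest logits and mask out everything else
        cutoff = np.sort(last)[-top_k]
        last = np.where(last >= cutoff, last, -np.inf)
    last = last - last.max()  # numerical stability
    probs = np.exp(last)
    return probs / probs.sum()


def sample_next_token(logits, temperature=1.0, top_k=50, rng=None):
    """Draw one token ID from the temperature-scaled, top-k-filtered distribution."""
    if rng is None:
        rng = np.random.default_rng()
    probs = next_token_probs(logits, temperature, top_k)
    return int(rng.choice(len(probs), p=probs))


# Random logits standing in for outputs[0] from the session above
fake_logits = np.random.default_rng(0).normal(size=(1, 5, 50257)).astype(np.float32)

probs = next_token_probs(fake_logits, temperature=0.8, top_k=50)
greedy_id = int(probs.argmax())  # greedy decoding: pick the most likely token
sampled_id = sample_next_token(
    fake_logits, temperature=0.8, top_k=50, rng=np.random.default_rng(1)
)
```

With the real session, passing `outputs[0]` in place of `fake_logits` and decoding the result with `tokenizer.decode([sampled_id])` would give the sampled token as text.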
Reproducing this export
Install dependencies:
pip install torch transformers "optimum[onnxruntime]" onnx onnxscript
Then run the script below (also available at scripts/export_distilgpt2.py in the forwardpass.dev repo):
import shutil
from pathlib import Path
import torch
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "distilbert/distilgpt2"
OUTPUT_DIR = Path("exported-distilgpt2")
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
OUTPUT_DIR.mkdir()
(OUTPUT_DIR / "onnx").mkdir()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, attn_implementation="eager")
model.eval()
# Wrapper that returns logits + attention layers as separate outputs
class AttentionWrapper(torch.nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model

    def forward(self, input_ids, attention_mask, position_ids):
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=True,
            use_cache=False,
        )
        return (out.logits,) + out.attentions
wrapper = AttentionWrapper(model)
wrapper.eval()
dummy = tokenizer("Hello world", return_tensors="pt")
input_ids = dummy["input_ids"]
attention_mask = dummy["attention_mask"]
seq_len = input_ids.shape[1]
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
with torch.no_grad():
    test_out = wrapper(input_ids, attention_mask, position_ids)
num_attn_layers = len(test_out) - 1
output_names = ["logits"]
dynamic_axes = {
    "input_ids": {0: "batch", 1: "seq"},
    "attention_mask": {0: "batch", 1: "seq"},
    "position_ids": {0: "batch", 1: "seq"},
    "logits": {0: "batch", 1: "seq"},
}
for i in range(num_attn_layers):
    name = f"attentions.{i}"
    output_names.append(name)
    dynamic_axes[name] = {0: "batch", 2: "seq", 3: "seq"}
onnx_path = OUTPUT_DIR / "onnx" / "model_fp32.onnx"
torch.onnx.export(
    wrapper,
    (input_ids, attention_mask, position_ids),
    str(onnx_path),
    input_names=["input_ids", "attention_mask", "position_ids"],
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=17,
    do_constant_folding=True,
    dynamo=True,
    optimize=False,  # skip optimization that fails on GPT-2 + attention outputs
)
# Merge external data into single file
data_file = onnx_path.parent / (onnx_path.name + ".data")
if data_file.exists():
    m = onnx.load(str(onnx_path), load_external_data=True)
    onnx.save_model(m, str(onnx_path), save_as_external_data=False)
    data_file.unlink()
# Quantize to int8
quantized_path = OUTPUT_DIR / "onnx" / "model.onnx"
quantize_dynamic(
    str(onnx_path),
    str(quantized_path),
    weight_type=QuantType.QInt8,
)
onnx_path.unlink()
# Save tokenizer + config
tokenizer.save_pretrained(str(OUTPUT_DIR))
model.config.save_pretrained(str(OUTPUT_DIR))
Credits
- Base model: distilbert/distilgpt2 by the Hugging Face team (Victor Sanh et al.)
- Original GPT-2: OpenAI
- Distillation paper: DistilBERT, a distilled version of BERT
- Export & hosting: Dor Bernsohn for forwardpass.dev
License
Apache 2.0 (inherited from the base model).