DistilGPT-2 ONNX (with attention outputs)

A custom ONNX export of distilbert/distilgpt2 that includes attention weights as graph outputs, quantized to int8 for efficient in-browser inference via Transformers.js.

Used by forwardpass.dev — an interactive visual guide to LLMs — to power live next-token logits, sampling, generation, and real attention heatmap visualizations computed entirely in the browser.

Why this export

Standard ONNX exports of GPT-2 omit the attention weights (the `attentions` tuple returned by `model(..., output_attentions=True)`) because they aren't needed for inference. This export preserves them as named graph outputs so the model can be used for educational visualizations of how attention works in causal language models.

Model details

  • Base model: distilbert/distilgpt2 (82M parameters)
  • Architecture: Causal LM (decoder-only, GPT-2 style)
  • Layers: 6 transformer blocks
  • Heads: 12 attention heads per layer
  • Hidden size: 768
  • Vocabulary: 50,257 tokens (BPE)
  • Quantization: int8 dynamic
  • File size: ~84 MB
  • Format: Single-file ONNX (no external data)

ONNX inputs / outputs

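| Name | Role | Shape | Description |
|------|------|-------|-------------|
| `input_ids` | input | `[batch, seq]` | Token IDs |
| `attention_mask` | input | `[batch, seq]` | All ones for non-padded inputs |
| `position_ids` | input | `[batch, seq]` | `[0, 1, ..., seq-1]` |
| `logits` | output | `[batch, seq, 50257]` | Next-token logits |
| `attentions.0` ... `attentions.5` | output | `[batch, 12, seq, seq]` | Attention weights, one tensor per layer |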

Usage

Transformers.js

import { AutoTokenizer, AutoModel, Tensor } from "@huggingface/transformers";

// Load the tokenizer and the ONNX graph published in this repository.
const tokenizer = await AutoTokenizer.from_pretrained("dbernsohn/distilgpt2-onnx");
const model = await AutoModel.from_pretrained("dbernsohn/distilgpt2-onnx", {
  // NOTE: onnx/model.onnx holds the int8-quantized weights even though dtype
  // says "fp32". Presumably "fp32" is used only so that Transformers.js picks
  // the un-suffixed file name together with model_file_name "model".
  // Confirm against the dtype -> file-name mapping of your Transformers.js version.
  dtype: "fp32",
  model_file_name: "model",
});

const inputs = tokenizer("The capital of France is");
// input_ids has dims [batch, seq]; take the sequence length of this single prompt.
const seqLen = inputs.input_ids.dims[1];

// The exported graph takes attention_mask and position_ids as explicit inputs,
// so both are built by hand as int64 tensors of shape [1, seqLen]:
// an all-ones mask (no padding) and positions 0..seqLen-1.
const result = await model.forward({
  input_ids: inputs.input_ids,
  attention_mask: new Tensor(
    "int64",
    new BigInt64Array(seqLen).fill(1n),
    [1, seqLen]
  ),
  position_ids: new Tensor(
    "int64",
    BigInt64Array.from({ length: seqLen }, (_, i) => BigInt(i)),
    [1, seqLen]
  ),
});

// result.logits: [1, seq, 50257]
// result["attentions.0"] ... result["attentions.5"]: [1, 12, seq, seq]

Python (onnxruntime)

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# The tokenizer is fetched from the Hub, but the ONNX file is opened from a
# local relative path: download/clone this repository first, or point the
# path at your own copy of onnx/model.onnx.
tokenizer = AutoTokenizer.from_pretrained("dbernsohn/distilgpt2-onnx")
session = ort.InferenceSession("onnx/model.onnx")

ids = tokenizer("The capital of France is", return_tensors="np")
input_ids = ids["input_ids"].astype(np.int64)
# Derive both dims from the encoded prompt instead of hard-coding batch=1.
batch_size, seq_len = input_ids.shape

outputs = session.run(None, {
    "input_ids": input_ids,
    # No padding in this example, so every position is attended to.
    "attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
    # Positions 0..seq_len-1, repeated for every row of the batch.
    "position_ids": np.tile(np.arange(seq_len, dtype=np.int64), (batch_size, 1)),
})

# Look outputs up by their graph names rather than relying on positional order.
named = dict(zip((o.name for o in session.get_outputs()), outputs))
logits = named["logits"]  # [batch, seq, 50257]
# One [batch, 12, seq, seq] tensor per layer, in graph output order.
attentions = [value for name, value in named.items() if name.startswith("attentions.")]

Reproducing this export

Install dependencies:

pip install torch transformers optimum[onnxruntime] onnx onnxscript

Then run the script below (also available at scripts/export_distilgpt2.py in the forwardpass.dev repo):

import shutil
from pathlib import Path
import torch
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "distilbert/distilgpt2"
OUTPUT_DIR = Path("exported-distilgpt2")

# Start every run from an empty output tree so artifacts from a previous run
# can never be mixed into the published files.
if OUTPUT_DIR.exists():
    shutil.rmtree(OUTPUT_DIR)
(OUTPUT_DIR / "onnx").mkdir(parents=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Eager attention is requested explicitly. Presumably this is because fused
# attention kernels do not materialise the attention weights that
# output_attentions=True must return; confirm for your transformers version.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="eager",
).eval()

class AttentionWrapper(torch.nn.Module):
    """Flatten a causal LM's output into ``(logits, attn_0, attn_1, ...)``.

    Returning a plain tuple lets each tensor become its own named graph
    output when the module is passed to ``torch.onnx.export``.
    """

    def __init__(self, base_model):
        super().__init__()
        # Keep the attribute name ``base``: it prefixes every parameter name in
        # this module's state dict.
        self.base = base_model

    def forward(self, input_ids, attention_mask, position_ids):
        """Run the wrapped model once, requesting attention weights.

        ``use_cache=False`` is passed so no key/value cache is requested; only
        the logits and the per-layer attention tensors are returned.
        """
        call_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=True,
            use_cache=False,
        )
        result = self.base(**call_kwargs)
        return (result.logits, *result.attentions)

wrapper = AttentionWrapper(model)
wrapper.eval()

# --- Example inputs used to trace the graph --------------------------------
# NOTE(review): the traced batch size is 1. torch.export tends to specialise
# size-1 dimensions, so confirm the exported "batch" axis is really dynamic.
dummy = tokenizer("Hello world", return_tensors="pt")
input_ids = dummy["input_ids"]
attention_mask = dummy["attention_mask"]
seq_len = input_ids.shape[1]
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)

# One eager forward pass reveals how many attention tensors follow the logits.
with torch.no_grad():
    test_out = wrapper(input_ids, attention_mask, position_ids)
num_attn_layers = len(test_out) - 1

# --- Graph I/O names and symbolic axes -------------------------------------
input_names = ["input_ids", "attention_mask", "position_ids"]
attn_names = [f"attentions.{i}" for i in range(num_attn_layers)]
output_names = ["logits", *attn_names]

# Token-level tensors: axis 0 is the batch, axis 1 the sequence.
dynamic_axes = {name: {0: "batch", 1: "seq"} for name in (*input_names, "logits")}
# Attention maps: with no KV cache, both of the last two axes have length seq.
dynamic_axes.update({name: {0: "batch", 2: "seq", 3: "seq"} for name in attn_names})

onnx_dir = OUTPUT_DIR / "onnx"
onnx_path = onnx_dir / "model_fp32.onnx"

torch.onnx.export(
    wrapper,
    (input_ids, attention_mask, position_ids),
    str(onnx_path),
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=17,
    do_constant_folding=True,
    dynamo=True,
    optimize=False,  # skip optimization that fails on GPT-2 + attention outputs
)

# The exporter may write weights to a side-car "<name>.data" file. Fold them
# back in so the published model is a single self-contained .onnx file.
data_file = onnx_path.with_name(f"{onnx_path.name}.data")
if data_file.exists():
    merged = onnx.load(str(onnx_path), load_external_data=True)
    onnx.save_model(merged, str(onnx_path), save_as_external_data=False)
    data_file.unlink()

# int8 dynamic quantization (weight_type=QInt8); the fp32 intermediate is then
# discarded so only model.onnx remains.
quantized_path = onnx_dir / "model.onnx"
quantize_dynamic(
    str(onnx_path),
    str(quantized_path),
    weight_type=QuantType.QInt8,
)
onnx_path.unlink()

# Tokenizer files and config.json go to OUTPUT_DIR, next to the onnx/ directory.
tokenizer.save_pretrained(str(OUTPUT_DIR))
model.config.save_pretrained(str(OUTPUT_DIR))

Credits

License

Apache 2.0 (inherited from the base model).

Downloads last month
376
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for dbernsohn/distilgpt2-onnx

Quantized
(22)
this model

Paper for dbernsohn/distilgpt2-onnx