t5gemma-2-1b-1b.dflash-dev

Currently for testing only.

A DFlash speculator for google/t5gemma-2-1b-1b, trained with love by me on UltraChat for 6 epochs on an RTX 5070 Ti using the Speculators repo.

Run

It can now be run with the vLLM plugin.

Metrics

{
    "loss_epoch": 0.3060084866100722,

    "full_acc_epoch": 0.7698322311721257,

    "position_1_acc_epoch": 0.7811536156278055,
    "position_2_acc_epoch": 0.7764394806932667,
    "position_3_acc_epoch": 0.7726681094470357,
    "position_4_acc_epoch": 0.7686484105084982,
    "position_5_acc_epoch": 0.766445727777341,
    "position_6_acc_epoch": 0.7653757631631541,
    "position_7_acc_epoch": 0.7580444155462153
}
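The per-position accuracies can be turned into a rough tokens-per-step figure. This is only an idealized sketch under loud assumptions that the card does not make: each position_i_acc_epoch is treated as an independent probability p_i that draft token i matches the verifier, and verification stops at the first mismatch, so the expected number of tokens per verification step is E = 1 + sum over k = 1..K of (p_1 * ... * p_k). The helper name expected_len is hypothetical.

```shell
#!/bin/sh
# Idealized tokens-per-step estimate from the per-position training accuracies.
# ASSUMPTION: each position_i_acc_epoch is used as an independent probability p_i
# that draft token i matches the verifier, and verification stops at the first
# mismatch:  E = 1 + sum_{k=1..K} prod_{i=1..k} p_i
set -eu

# Hypothetical helper. $1 = number of draft tokens K; remaining args = p_1, p_2, ...
expected_len() {
    k="$1"
    shift
    awk -v k="$k" 'BEGIN {
        prod = 1
        total = 1                  # the verifier always contributes one token per step
        for (i = 1; i <= k; i++) {
            prod *= ARGV[i]        # probability that positions 1..i all match
            total += prod
        }
        printf "%.2f\n", total
    }' "$@"
}

# position_1 .. position_7 accuracies from the Metrics block.
P="0.7811536156278055 0.7764394806932667 0.7726681094470357 0.7686484105084982 0.766445727777341 0.7653757631631541 0.7580444155462153"

expected_len 7 $P   # k7 setting
expected_len 3 $P   # k3 setting
```

It prints 3.86 for K = 7 and 2.86 for K = 3, far above the measured acc len of 1.38 and 1.32 in the Performance table, so under this setup the training accuracies are at most a loose proxy for serving-time acceptance.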

Performance

| run | requests | output tokens | output tok/s | req/s | latency mean | latency p50 | latency p95 | TTFT p50 | TTFT p95 | ITL p50 | ITL p95 | TPOT p50 | acceptance rate | acceptance length | accepted/drafted |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| raw bs1/baseline | 1303 | 77687 | 12.71 | 0.213 | 4.69s | 4.85s | 6.10s | 154.8ms | 251.4ms | 75.4ms | 94.8ms | 75.5ms | - | - | - |
| k7 bs1/dflash | 1303 | 78256 | 16.14 | 0.269 | 3.72s | 4.65s | 6.28s | 167.1ms | 263.2ms | 80.0ms | 104.7ms | 72.8ms | 5.46% | 1.38 | 21850/399980 |
| k3 bs1/dflash | 1303 | 78255 | 15.19 | 0.253 | 3.95s | 4.80s | 6.39s | 173.1ms | 285.9ms | 82.2ms | 104.7ms | 75.1ms | 10.73% | 1.32 | 19086/177891 |
| raw bs2/baseline | 1303 | 77123 | 21.10 | 0.357 | 5.61s | 5.82s | 7.25s | 237.9ms | 380.4ms | 90.1ms | 112.1ms | 90.1ms | - | - | - |
| k7 bs2/dflash | 1303 | 77901 | 26.72 | 0.447 | 4.47s | 5.60s | 7.34s | 269.4ms | 419.0ms | 94.1ms | 120.7ms | 87.1ms | 5.28% | 1.37 | 21213/401527 |
| k3 bs2/dflash | 1303 | 78122 | 24.61 | 0.410 | 4.87s | 6.00s | 7.60s | 284.2ms | 447.7ms | 100.2ms | 123.6ms | 92.5ms | 10.70% | 1.32 | 19004/177627 |
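The derived columns can be cross-checked from the raw counts. This sketch rests on an assumption the card does not state: that "kN" in the run name means N draft tokens per verification step, so steps = drafted / N, and that acceptance length counts the verifier's own token, i.e. acc len = 1 + accepted / steps. The drafted counts are exact multiples of 7 and 3, which is consistent with that reading. The sketch also computes output-throughput speedup over the baseline at the same batch size. The helper name derive is hypothetical.

```shell
#!/bin/sh
# Cross-check the derived Performance columns from the raw counts.
# ASSUMPTION (not stated in the card): "kN" = N draft tokens per verification step,
# so steps = drafted / N, and acc len = 1 + accepted / steps (one verifier token per step).
set -eu

# Hypothetical helper.
# $1 = run label, $2 = k, $3 = accepted, $4 = drafted,
# $5 = dflash output tok/s, $6 = baseline output tok/s at the same batch size
derive() {
    awk -v run="$1" -v k="$2" -v acc="$3" -v drafted="$4" -v tps="$5" -v base="$6" 'BEGIN {
        steps = drafted / k
        printf "%s acc_rate=%.2f%% acc_len=%.2f speedup=%.2fx\n",
            run, 100 * acc / drafted, 1 + acc / steps, tps / base
    }'
}

derive "k7 bs1" 7 21850 399980 16.14 12.71
derive "k3 bs1" 3 19086 177891 15.19 12.71
derive "k7 bs2" 7 21213 401527 26.72 21.10
derive "k3 bs2" 3 19004 177627 24.61 21.10
```

Under this reading all four dflash rows reproduce the reported acceptance rate and acceptance length, and the output-throughput gain over the matching baseline is about 1.27x for k7 and 1.17x to 1.20x for k3.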

Training

Training uses a custom branch; the adapter is still TODO.

#!/bin/bash
set -euo pipefail

MODEL="${MODEL:-google/t5gemma-2-1b-1b}"
DATASET="${DATASET:-ultrachat}"
OUTPUT_DIR="${OUTPUT_DIR:-$HOME/dflash-output/dflash_t5gemma2_online}"
MAX_SAMPLES="${MAX_SAMPLES:-5000}"
# Use 5 verifier layers spread across the decoder (not strictly evenly spaced),
# matching the DFlash paper's setup (5 target hidden states between the early and late layers).
TARGET_LAYER_IDS=(2 7 13 18 21)

# The Arrow dataset is small: it stores only encoder/decoder token IDs and masks.
if [[ ! -f "$OUTPUT_DIR/dataset_info.json" ]]; then
    python scripts/prepare_t5gemma_data.py \
        --model "$MODEL" \
        --data "$DATASET" \
        --output "$OUTPUT_DIR" \
        --max-samples "$MAX_SAMPLES" \
        --encoder-seq-length 2048 \
        --decoder-seq-length 1024
else
    echo "Prepared dataset already exists at $OUTPUT_DIR; reusing it."
fi

# The verifier is kept frozen on the same GPU as the drafter. num_workers must
# remain zero: worker processes cannot share this in-process CUDA model.
PYTORCH_ALLOC_CONF=expandable_segments:True \
TORCHDYNAMO_DISABLE=1 \
python scripts/train_t5gemma_online.py \
    --verifier-name-or-path "$MODEL" \
    --data-path "$OUTPUT_DIR" \
    --save-path "$OUTPUT_DIR/checkpoints" \
    --draft-vocab-size 32000 \
    --speculator-type dflash \
    --draft-arch llama \
    --draft-hidden-act silu \
    --draft-attn-impl eager \
    --block-size 8 \
    --max-anchors 256 \
    --num-layers 5 \
    --target-layer-ids "${TARGET_LAYER_IDS[@]}" \
    --total-seq-len 2048 \
    --epochs 6 \
    --lr 6e-4 \
    --loss-fn kl_div \
    --on-missing raise \
    --num-workers 0 \
    --prefetch-factor 1

Safetensors: 0.5B params; tensor types I64, BF16, BOOL.

Model repo: d0rj/t5gemma-2-1b-1b.dflash-dev