Training dataset: HuggingFaceH4/ultrachat_200k
For testing only, for now.
A DFlash speculator for google/t5gemma-2-1b-1b, trained with love by me for 6 epochs on UltraChat on an RTX 5070 Ti, using the Speculators repo.
It can now be used with the vLLM plugin.
Epoch-level training metrics:

```json
{
  "loss_epoch": 0.3060084866100722,
  "full_acc_epoch": 0.7698322311721257,
  "position_1_acc_epoch": 0.7811536156278055,
  "position_2_acc_epoch": 0.7764394806932667,
  "position_3_acc_epoch": 0.7726681094470357,
  "position_4_acc_epoch": 0.7686484105084982,
  "position_5_acc_epoch": 0.766445727777341,
  "position_6_acc_epoch": 0.7653757631631541,
  "position_7_acc_epoch": 0.7580444155462153
}
```
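As a rough sanity check, the per-position accuracies can be turned into an expected acceptance length. This sketch assumes (1) `position_k_acc_epoch` behaves like the probability of accepting draft token k given that tokens 1..k-1 were accepted, and (2) positions are independent. Neither assumption is confirmed by the metrics themselves, and they are measured on training data, so treat the result as optimistic. The function name `expected_accept_length` is hypothetical.

```python
# Per-position accuracies copied from the training metrics above.
POSITION_ACC = [
    0.7811536156278055,
    0.7764394806932667,
    0.7726681094470357,
    0.7686484105084982,
    0.766445727777341,
    0.7653757631631541,
    0.7580444155462153,
]


def expected_accept_length(position_acc: list[float]) -> float:
    """Hypothetical estimate of tokens emitted per verification step.

    ASSUMPTION: position k is accepted with probability position_acc[k-1],
    independently, and acceptance stops at the first rejection. Then
    E[accepted drafts] = sum over k of P(tokens 1..k all accepted),
    plus 1 for the token the verifier always produces itself.
    """
    expected = 1.0  # the verifier's own token
    prefix = 1.0  # P(all positions so far accepted)
    for p in position_acc:
        prefix *= p
        expected += prefix
    return expected


if __name__ == "__main__":
    print(f"all 7 positions: {expected_accept_length(POSITION_ACC):.2f}")
    print(f"first 3 positions: {expected_accept_length(POSITION_ACC[:3]):.2f}")
```

Under these assumptions, all seven positions give roughly 3.9 tokens per step and the first three roughly 2.9, well above the measured acceptance lengths of 1.32 to 1.38 in the benchmark table, so training-time accuracy and serving-time acceptance differ substantially here.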
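Benchmark results (1303 requests per run; `kN` is the number of draft tokens per speculative step, `bsN` the batch size):

| run | requests | output tokens | output tok/s | req/s | mean latency | p50 latency | p95 latency | TTFT p50 | TTFT p95 | ITL p50 | ITL p95 | TPOT p50 | acceptance rate | acceptance length | accepted / drafted tokens |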
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| raw bs1/baseline | 1303 | 77687 | 12.71 | 0.213 | 4.69s | 4.85s | 6.10s | 154.8ms | 251.4ms | 75.4ms | 94.8ms | 75.5ms | - | - | - |
| k7 bs1/dflash | 1303 | 78256 | 16.14 | 0.269 | 3.72s | 4.65s | 6.28s | 167.1ms | 263.2ms | 80.0ms | 104.7ms | 72.8ms | 5.46% | 1.38 | 21850/399980 |
| k3 bs1/dflash | 1303 | 78255 | 15.19 | 0.253 | 3.95s | 4.80s | 6.39s | 173.1ms | 285.9ms | 82.2ms | 104.7ms | 75.1ms | 10.73% | 1.32 | 19086/177891 |
| raw bs2/baseline | 1303 | 77123 | 21.10 | 0.357 | 5.61s | 5.82s | 7.25s | 237.9ms | 380.4ms | 90.1ms | 112.1ms | 90.1ms | - | - | - |
| k7 bs2/dflash | 1303 | 77901 | 26.72 | 0.447 | 4.47s | 5.60s | 7.34s | 269.4ms | 419.0ms | 94.1ms | 120.7ms | 87.1ms | 5.28% | 1.37 | 21213/401527 |
| k3 bs2/dflash | 1303 | 78122 | 24.61 | 0.410 | 4.87s | 6.00s | 7.60s | 284.2ms | 447.7ms | 100.2ms | 123.6ms | 92.5ms | 10.70% | 1.32 | 19004/177627 |
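The derived columns follow from the raw counts: acceptance rate is accepted / drafted tokens, and acceptance length matches 1 + accepted tokens per draft step if each step proposes k tokens (an assumption read from the run names, which the numbers are consistent with). A small sketch that recomputes them, plus the output-throughput speedup over the raw baseline at the same batch size; the helper names are hypothetical.

```python
"""Recompute the derived columns of the benchmark table from its raw counts.

ASSUMPTION: k (the number after "k" in the run name) draft tokens are
proposed per speculative step, so drafted / k is the number of draft steps.
"""

# (run, k, accepted tokens, drafted tokens, output tok/s) copied from the table.
DFLASH_RUNS = [
    ("k7 bs1/dflash", 7, 21850, 399980, 16.14),
    ("k3 bs1/dflash", 3, 19086, 177891, 15.19),
    ("k7 bs2/dflash", 7, 21213, 401527, 26.72),
    ("k3 bs2/dflash", 3, 19004, 177627, 24.61),
]

# Baseline output tok/s per batch size, copied from the "raw" rows.
BASELINE_TOK_S = {"bs1": 12.71, "bs2": 21.10}


def acceptance_rate(accepted: int, drafted: int) -> float:
    """Fraction of drafted tokens that the verifier accepted."""
    return accepted / drafted


def acceptance_length(k: int, accepted: int, drafted: int) -> float:
    """Mean tokens emitted per verification step: accepted drafts per step
    plus the one token the verifier always produces itself."""
    steps = drafted / k
    return 1.0 + accepted / steps


def speedup(run: str, tok_s: float) -> float:
    """Output-throughput ratio against the raw baseline at the same batch size."""
    batch = run.split()[1].split("/")[0]  # "k7 bs1/dflash" -> "bs1"
    return tok_s / BASELINE_TOK_S[batch]


if __name__ == "__main__":
    for run, k, accepted, drafted, tok_s in DFLASH_RUNS:
        print(
            f"{run}: acc rate {acceptance_rate(accepted, drafted):.2%}, "
            f"acc len {acceptance_length(k, accepted, drafted):.2f}, "
            f"speedup {speedup(run, tok_s):.2f}x"
        )
```

For example, k7 bs1 gives an acceptance length of about 1.38 and roughly a 1.27x output-throughput gain over the bs1 baseline.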
Training script (uses a custom branch of the Speculators repo; the adapter is still TODO):
```bash
#!/bin/bash
set -euo pipefail

MODEL="${MODEL:-google/t5gemma-2-1b-1b}"
DATASET="${DATASET:-ultrachat}"
OUTPUT_DIR="${OUTPUT_DIR:-$HOME/dflash-output/dflash_t5gemma2_online}"
MAX_SAMPLES="${MAX_SAMPLES:-5000}"

# Use 5 verifier layers spread across the decoder (hand-picked, not exactly
# evenly spaced), matching the DFlash paper's setup (5 target hidden states
# taken between the early and late layers).
TARGET_LAYER_IDS=(2 7 13 18 21)

# Prepare the dataset once and reuse it. The Arrow dataset is small: it stores
# only encoder/decoder token IDs and masks.
if [[ ! -f "$OUTPUT_DIR/dataset_info.json" ]]; then
  python scripts/prepare_t5gemma_data.py \
    --model "$MODEL" \
    --data "$DATASET" \
    --output "$OUTPUT_DIR" \
    --max-samples "$MAX_SAMPLES" \
    --encoder-seq-length 2048 \
    --decoder-seq-length 1024
else
  echo "Prepared dataset already exists at $OUTPUT_DIR; reusing it."
fi

# The verifier is kept frozen on the same GPU as the drafter. num_workers must
# remain zero: worker processes cannot share this in-process CUDA model.
PYTORCH_ALLOC_CONF=expandable_segments:True \
TORCHDYNAMO_DISABLE=1 \
python scripts/train_t5gemma_online.py \
  --verifier-name-or-path "$MODEL" \
  --data-path "$OUTPUT_DIR" \
  --save-path "$OUTPUT_DIR/checkpoints" \
  --draft-vocab-size 32000 \
  --speculator-type dflash \
  --draft-arch llama \
  --draft-hidden-act silu \
  --draft-attn-impl eager \
  --block-size 8 \
  --max-anchors 256 \
  --num-layers 5 \
  --target-layer-ids "${TARGET_LAYER_IDS[@]}" \
  --total-seq-len 2048 \
  --epochs 6 \
  --lr 6e-4 \
  --loss-fn kl_div \
  --on-missing raise \
  --num-workers 0 \
  --prefetch-factor 1
```
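`TARGET_LAYER_IDS` is hand-picked in the script. To choose a comparable spread for a verifier of a different depth, a hypothetical helper like `spread_layer_ids` below places `count` layer indices evenly between a first and last layer. The end points are assumptions you would pick from the model's config; this is not part of the Speculators repo.

```python
def spread_layer_ids(first: int, last: int, count: int) -> list[int]:
    """Hypothetical helper: `count` layer indices evenly spaced from `first`
    to `last` (inclusive), rounded half up to the nearest integer."""
    if count < 2:
        raise ValueError("count must be at least 2")
    if last < first:
        raise ValueError("last must be >= first")
    step = (last - first) / (count - 1)
    # int(x + 0.5) rounds half up for non-negative x, avoiding Python's
    # round-half-to-even behaviour in round().
    return [int(first + i * step + 0.5) for i in range(count)]


if __name__ == "__main__":
    # Same end points and count as the script's TARGET_LAYER_IDS.
    print(" ".join(map(str, spread_layer_ids(2, 21, 5))))
```

For the range 2 to 21 this yields `2 7 12 16 21`, so the script's `(2 7 13 18 21)` is close to, but not exactly, an even spread.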
Base model: google/t5gemma-2-1b-1b