DSpark Implementation Notes

Architecture Summary

DSpark introduces two key innovations over existing speculative decoding methods:

1. Semi-Autoregressive Generation

  • Parallel backbone (DFlash-style): Processes all γ positions in a single forward pass, making drafting latency nearly independent of block size.
  • Sequential head: Two variants:
    • Markov head (VanillaMarkov/GatedMarkovHead): Low-rank transition bias B = W1[x_{k-1}] @ W2 (r=256).
    • RNN head (RNNHead): GRU-like recurrent state accumulating full prefix history.
  • The sequential stage fixes the "multi-modal collision" problem where parallel drafters produce incoherent combinations like "of problem" instead of "of course".
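The collision problem can be illustrated with a toy two-position example. This is a minimal sketch, not the repository's code: the four-word vocabulary, the marginal probabilities, and the bias value 4.0 are all made up for illustration. It compares the probability mass that lands on incoherent pairs ("of problem", "no course") with and without a transition bias conditioned on the previous token.

```python
import math

# Hypothetical marginals for two draft positions with two valid continuations:
# "of course" and "no problem".
first = {"of": 0.5, "no": 0.5}
second_logits = {"course": 0.0, "problem": 0.0}
# Hypothetical transition bias: favors the token that matches the previous one.
bias = {("of", "course"): 4.0, ("no", "problem"): 4.0}
coherent = {("of", "course"), ("no", "problem")}

def second_given(prev, use_bias):
    """Distribution of the second token, optionally conditioned on prev."""
    logits = {
        v: u + (bias.get((prev, v), 0.0) if use_bias else 0.0)
        for v, u in second_logits.items()
    }
    z = sum(math.exp(l) for l in logits.values())
    return {v: math.exp(l) / z for v, l in logits.items()}

def incoherent_mass(use_bias):
    """Total probability of sampling a pair that mixes the two modes."""
    return sum(
        first[a] * second_given(a, use_bias)[b]
        for a in first
        for b in second_logits
        if (a, b) not in coherent
    )

print("parallel only:", incoherent_mass(False))
print("with transition bias:", incoherent_mass(True))
```

Without the bias, the two positions are sampled independently and half of the mass falls on mixed pairs. With the bias, almost all of it moves to the coherent pairs.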

2. Confidence-Scheduled Verification

  • Confidence head: Predicts per-position conditional acceptance probability c_k = σ(w^T[h_k; W1[x_{k-1}]]).
  • Sequential Temperature Scaling (STS): Calibrates cumulative prefix survival probabilities left-to-right via 1D grid search minimizing ECE.
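  • Hardware-Aware Prefix Scheduler (Algorithm 1): Maximizes Θ = τ · SPS(B), where τ is the expected number of accepted tokens (the sum of admitted survival probabilities) and SPS(B) is the profiled engine throughput when B tokens are submitted for verification. It greedily admits the highest-survival-probability tokens and stops early once Θ stops improving.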
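The confidence formula and the prefix survival probabilities it feeds can be sketched as follows. This is an illustrative sketch in plain Python, not the code in models/confidence_head.py: the function names and the list-based vectors are assumptions, and `prev_embed` stands in for the W1[x_{k-1}] row.

```python
import math

def confidence(h_k, prev_embed, w):
    """c_k = sigmoid(w^T [h_k; prev_embed]) for one draft position."""
    feats = list(h_k) + list(prev_embed)  # concatenation [h_k; W1[x_{k-1}]]
    z = sum(wi * fi for wi, fi in zip(w, feats))
    return 1.0 / (1.0 + math.exp(-z))

def prefix_survival(conf):
    """Cumulative products a_j = c_1 * ... * c_j.

    a_j estimates the probability that the first j draft tokens are all
    accepted, which is what the scheduler ranks candidates by.
    """
    out, a = [], 1.0
    for c in conf:
        a *= c
        out.append(a)
    return out

# Hypothetical per-position confidences for one request.
print(prefix_survival([0.9, 0.8, 0.5]))
```

Because each c_k lies in (0, 1), the survival sequence is non-increasing in j, so a longer prefix is never ranked above a shorter prefix of the same request.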

Algorithm Details

Acceptance Rule (Lossless)

P(accept token x_k) = min(1, p^t_k(x_k) / p^d_k(x_k))
  • First rejection at position k discards tokens k+1..γ
  • One bonus token sampled from the residual distribution, max(0, p^t_k − p^d_k) renormalized, at the rejection position
  • Preserves exact target distribution
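The rule above can be sketched for a single request in plain Python. This is a hedged illustration rather than the contents of decoding/acceptance.py: the function name `verify_block` and the list-of-lists probability layout are assumptions, and the extra token that is usually sampled from the target when every draft token is accepted is left out.

```python
import random

def verify_block(draft_tokens, p_draft, p_target, rng):
    """Return the accepted prefix, plus one resampled token on rejection.

    draft_tokens: token ids proposed by the drafter (one per position)
    p_draft, p_target: per-position probability lists over the vocabulary
    rng: a random.Random instance
    """
    out = []
    for k, x in enumerate(draft_tokens):
        # Accept with probability min(1, p_t(x) / p_d(x)).
        ratio = p_target[k][x] / p_draft[k][x]
        if rng.random() < min(1.0, ratio):
            out.append(x)
            continue
        # First rejection: sample from the normalized residual max(0, p_t - p_d)
        # and discard every later draft token.
        residual = [max(0.0, t - d) for t, d in zip(p_target[k], p_draft[k])]
        total = sum(residual)
        weights = [r / total for r in residual]
        out.append(rng.choices(range(len(weights)), weights=weights)[0])
        return out
    return out

rng = random.Random(0)
same = [[0.2, 0.8], [0.5, 0.5]]
print(verify_block([1, 0], same, same, rng))  # identical distributions: all accepted
```

A rejection implies p_t(x) < p_d(x) at that position, so the residual always has positive mass and the normalization is safe.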

Draft Distribution

p_k(v | x_0, x_<k) = exp(U_k(v) + B_k(x_0, x_<k, v)) / Σ exp(U_k(u) + B_k(...))

where U_k are base logits from parallel backbone and B_k is the sequential transition bias.
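For the Markov head, the bias depends only on the previous token, so the formula reduces to a softmax over base logits plus one row of a low-rank product. The sketch below assumes that simplification and uses tiny hand-written matrices; `markov_bias` and `draft_distribution` are illustrative names, and the real head uses r = 256 rather than r = 2.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def markov_bias(prev_token, W1, W2):
    """B = W1[prev_token] @ W2: one bias value per vocabulary entry.

    W1 has shape (vocab, r) and W2 has shape (r, vocab).
    """
    row = W1[prev_token]
    vocab = len(W2[0])
    return [sum(row[i] * W2[i][v] for i in range(len(row))) for v in range(vocab)]

def draft_distribution(base_logits, prev_token, W1, W2):
    """p_k(v) proportional to exp(U_k(v) + B_k(prev_token, v))."""
    bias = markov_bias(prev_token, W1, W2)
    return softmax([u + b for u, b in zip(base_logits, bias)])

# Toy setup (made-up numbers): vocab = 3, r = 2.
W1 = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
W2 = [[0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
print(draft_distribution([0.0, 0.0, 0.0], prev_token=0, W1=W1, W2=W2))
```

Storing W1 and W2 instead of a full vocab × vocab transition table keeps the parameter count at 2 · vocab · r.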

Training Objective (Eq. 12)

L = 0.1 · L_ce + 0.9 · L_tv + 1.0 · L_conf
  • L_ce: Cross-entropy for next-token prediction
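  • L_tv: Total variation term ||p_d - p_t||_1, which is twice the total variation distance. The expected acceptance rate is 1 - ½||p_d - p_t||_1, so this term is a direct proxy for it.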
  • L_conf: Binary cross-entropy on confidence predictions
  • All position-weighted by w_k = exp(-(k-1)/γ)
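The weighted sum can be sketched for one drafted block as below. Several details are assumptions and may differ from training/losses.py and from the paper: the terms are summed over positions without normalization, the cross-entropy target is a given token id per position, and the confidence label is a 0/1 acceptance indicator. The name `dspark_loss` is illustrative.

```python
import math

def position_weight(k, gamma):
    """w_k = exp(-(k - 1) / gamma) for 1-indexed position k."""
    return math.exp(-(k - 1) / gamma)

def dspark_loss(p_draft, p_target, target_ids, conf, accepted):
    """L = 0.1 * L_ce + 0.9 * L_tv + 1.0 * L_conf for one block.

    p_draft, p_target: per-position probability lists over the vocabulary
    target_ids: token id used as the cross-entropy label at each position
    conf: predicted confidences c_k in (0, 1)
    accepted: 0/1 labels for the confidence head (assumed label definition)
    """
    gamma = len(p_draft)
    ce = tv = bce = 0.0
    for k in range(1, gamma + 1):
        w = position_weight(k, gamma)
        pd, pt = p_draft[k - 1], p_target[k - 1]
        ce += w * -math.log(pd[target_ids[k - 1]])
        # L1 distance as written in the notes (twice the TV distance).
        tv += w * sum(abs(a - b) for a, b in zip(pd, pt))
        c, y = conf[k - 1], accepted[k - 1]
        bce += w * -(y * math.log(c) + (1 - y) * math.log(1 - c))
    return 0.1 * ce + 0.9 * tv + 1.0 * bce

uniform = [[0.5, 0.5]]
print(dspark_loss(uniform, uniform, target_ids=[0], conf=[0.5], accepted=[1]))
```

The weights decay from 1 at k = 1 to roughly exp(-1) near k = γ, so errors at early positions, which gate every later token, count the most.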

STS Calibration

For each position k = 1..γ:

  • Compute cumulative product of calibrated scores up to k
  • Find temperature T_k minimizing ECE of cumulative product
  • Apply T_k to k-th position score
  • Order-preserving: doesn't disrupt relative rankings
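A minimal sketch of the left-to-right grid search follows. The notes do not say how the temperature is applied to a score, so this sketch assumes c' = sigmoid(logit(c) / T), which is monotone in c for T > 0 and therefore order-preserving. The equal-width binning in `ece`, the grid values, and the function names are also assumptions rather than the contents of calibration/sts.py.

```python
import math

def _sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def _logit(p):
    p = min(max(p, 1e-6), 1.0 - 1e-6)  # clamp to avoid log(0)
    return math.log(p / (1.0 - p))

def ece(probs, labels, n_bins=10):
    """Expected calibration error with equal-width bins."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        bins[min(int(p * n_bins), n_bins - 1)].append((p, y))
    err = 0.0
    for b in bins:
        if b:
            mean_p = sum(p for p, _ in b) / len(b)
            mean_y = sum(y for _, y in b) / len(b)
            err += len(b) / len(probs) * abs(mean_p - mean_y)
    return err

def fit_sts(conf, survived, grid):
    """Pick one temperature per position, left to right.

    conf[i][k]: raw confidence for sample i at position k
    survived[i][k]: 1 if the first k+1 draft tokens of sample i were all accepted
    """
    n, gamma = len(conf), len(conf[0])
    prefix = [1.0] * n  # calibrated cumulative product so far
    temps = []
    for k in range(gamma):
        labels = [survived[i][k] for i in range(n)]

        def cumulative(T):
            return [prefix[i] * _sigmoid(_logit(conf[i][k]) / T) for i in range(n)]

        best_T = min(grid, key=lambda T: ece(cumulative(T), labels))
        prefix = cumulative(best_T)
        temps.append(best_T)
    return temps

# Overconfident toy data: scores of 0.9 but only 3 of 4 prefixes survive.
print(fit_sts([[0.9]] * 4, [[1], [1], [1], [0]], grid=[0.5, 1.0, 2.0]))
```

Each T_k is fitted against the cumulative product, not the individual score, so miscalibration at earlier positions is already accounted for when later positions are tuned.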

Algorithm 1: Hardware-Aware Prefix Scheduler

1. Compute a_{r,j} = ∏_{i≤j} c_{r,i} for each request r, position j
2. Create candidate set E = {(r,j) | a_{r,j} > 0}
3. Sort E descending by a_{r,j}
4. Greedily add candidates:
   - Update ℓ_r = j, B += 1, τ += a_{r,j}
   - Compute Θ = τ * SPS(B)
   - If Θ > best, update best lengths; else break
5. Return per-request verification lengths
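The steps above can be sketched in plain Python. This is an illustration under stated assumptions, not the code in decoding/scheduler.py: `schedule_prefixes` is a hypothetical name, B counts only the admitted draft tokens, ties are broken by shorter prefix first, and the `sps` profile in the usage example is synthetic.

```python
def schedule_prefixes(conf, sps):
    """Greedy prefix scheduler.

    conf[r][j-1]: calibrated confidence of request r at position j
    sps(B): throughput when B tokens are submitted for verification
    Returns one verification length per request.
    """
    # Steps 1-2: survival probabilities a_{r,j} and the candidate set.
    candidates = []
    for r, row in enumerate(conf):
        a = 1.0
        for j, c in enumerate(row, start=1):
            a *= c
            if a > 0:
                candidates.append((a, r, j))
    # Step 3: sort by survival, descending; ties go to the shorter prefix.
    candidates.sort(key=lambda t: (-t[0], t[2]))

    lengths = [0] * len(conf)
    best_lengths = list(lengths)
    best, batch_tokens, tau = 0.0, 0, 0.0
    # Step 4: admit candidates while the objective keeps improving.
    for a, r, j in candidates:
        lengths[r] = j
        batch_tokens += 1
        tau += a
        theta = tau * sps(batch_tokens)
        if theta > best:
            best, best_lengths = theta, list(lengths)
        else:
            break  # early stop: adding this token lowered the objective
    return best_lengths

# Synthetic profile: flat throughput up to 4 tokens, then inversely proportional.
def toy_sps(batch_tokens):
    return 100.0 if batch_tokens <= 4 else 400.0 / batch_tokens

print(schedule_prefixes([[0.9, 0.8, 0.5], [0.6, 0.5, 0.1]], toy_sps))
```

Since a_{r,j} never increases with j, the sorted order always admits position j of a request before position j+1, so every returned length describes a contiguous prefix.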

What's Implemented Now

| Component | Location |
|---|---|
| Markov head (vanilla, gated) | models/markov_head.py |
| RNN head | models/rnn_head.py |
| Confidence head | models/confidence_head.py |
| Acceptance rule | decoding/acceptance.py |
| Hardware-aware scheduler | decoding/scheduler.py |
| Static scheduler (fallback) | decoding/scheduler.py |
| Speculative decoding loop | decoding/speculative.py |
| Training dataset | training/dataset.py |
| Loss functions (CE + TV + Conf) | training/losses.py |
| Training loop | training/train_drafter.py |
| Target cache generation | training/cache_targets.py |
| STS calibration | calibration/sts.py |
| Acceptance evaluation | evaluation/eval_acceptance.py |
| Latency benchmarking | evaluation/benchmark_latency.py |
| Unit tests (80) | tests/ |
| Smoke train script | scripts/smoke_train.py |
| Smoke eval script | scripts/smoke_eval.py |
| Benchmark script | scripts/run_benchmark.py |

Deviations from Paper

  1. Parallel backbone: Uses nn.TransformerEncoder instead of the full DFlash-style backbone with target model KV injection.
  2. Training efficiency: No multi-GPU training, no 38TB target cache. Target logits computed on-the-fly.
  3. SPS profiling: Uses a synthetic throughput profile instead of real engine profiling.
  4. Mask token: Uses token 0 as placeholder; the paper uses learned mask embeddings.
  5. Anchor modification: Paper treats anchor as first prediction position (γ inputs → γ outputs). We follow the original DFlash style (γ inputs → γ outputs) where anchor is part of the input.

Future Work

  1. Implement full DFlash backbone with target model KV injection
  2. Multi-GPU distributed training (DeepSpeed/FSDP)
  3. Real engine throughput profiling
  4. Tree-based verification for autoregressive drafters
  5. Integration with vLLM or similar serving frameworks
  6. Qwen3-specific model classes with proper configuration