Initial public release: UraionSpec v0.1.0, faithful DSpark-style speculative decoding

DSpark Implementation Notes

Architecture Summary

DSpark introduces two key innovations over existing speculative decoding methods:

1. Semi-Autoregressive Generation

  • Parallel backbone (DFlash-style): Processes all γ positions in a single forward pass, making drafting latency nearly independent of block size.
  • Sequential head: Two variants:
    • Markov head (VanillaMarkov/GatedMarkovHead): Low-rank transition bias B = W1[x_{k-1}] @ W2 (r=256).
    • RNN head (RNNHead): GRU-like recurrent state accumulating full prefix history.
  • The sequential stage fixes the "multi-modal collision" problem. A purely parallel drafter chooses each position independently, so it can splice tokens from different plausible continuations into an incoherent draft such as "of problem" instead of "of course".
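Here is a toy sketch that makes the collision concrete. It applies the vanilla Markov head's low-rank bias B = W1[x_{k-1}] @ W2 on top of parallel base logits. Everything in it is hypothetical: the five-word vocabulary, rank r=1 instead of 256, the hand-picked weights, and greedy argmax in place of sampling (used only so the output is deterministic). It does not model the gated variant or the RNN head.

```python
VOCAB = ["said", "of", "no", "course", "problem"]
IDX = {tok: i for i, tok in enumerate(VOCAB)}

# Hypothetical rank-1 factors (the real head uses r=256).
W1 = [[0.0], [1.0], [-1.0], [0.0], [0.0]]  # V x r: one row per previous token
W2 = [[0.0, 0.0, 0.0, 2.0, -2.0]]          # r x V: projects the row to vocab logits


def transition_bias(prev_id):
    """B = W1[x_{k-1}] @ W2, a length-V bias vector."""
    row = W1[prev_id]
    return [sum(row[i] * W2[i][v] for i in range(len(row))) for v in range(len(VOCAB))]


def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)


def parallel_draft(U):
    """Each position is decided independently from its base logits."""
    return [VOCAB[argmax(u)] for u in U]


def semi_ar_draft(U, anchor):
    """Base logits come from one parallel pass; only the bias is applied sequentially."""
    prev, out = IDX[anchor], []
    for u in U:
        b = transition_bias(prev)
        tok = argmax([x + y for x, y in zip(u, b)])
        out.append(VOCAB[tok])
        prev = tok
    return out


# Hypothetical base logits for gamma=2 positions, with near-ties at both positions.
U = [
    [0.0, 1.0, 0.9, 0.0, 0.0],  # position 1: "of" slightly ahead of "no"
    [0.0, 0.0, 0.0, 0.9, 1.0],  # position 2: "problem" slightly ahead of "course"
]

print(parallel_draft(U))         # ['of', 'problem']
print(semi_ar_draft(U, "said"))  # ['of', 'course']
```

Each bias depends on the token chosen at the previous position, so only this small head runs sequentially. The base logits for all γ positions still come from a single backbone pass.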

2. Confidence-Scheduled Verification

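  • Confidence head: Predicts the per-position conditional acceptance probability c_k = σ(w^T[h_k; W1[x_{k-1}]]), i.e. the probability that x_k is accepted given that x_1..x_{k-1} were accepted.
  • Sequential Temperature Scaling (STS): Calibrates cumulative prefix survival probabilities left to right, fitting one temperature per position by a 1D grid search that minimizes expected calibration error (ECE).
  • Hardware-Aware Prefix Scheduler (Algorithm 1): Maximizes Θ = τ · SPS(B), where τ is the expected number of accepted tokens and SPS(B) is the engine's step throughput when verifying B tokens. It greedily admits the tokens with the highest survival probability and stops early once Θ no longer improves.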
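The following is a minimal sketch of the confidence score and the cumulative prefix survival it feeds into, written with plain Python lists. The function and argument names are hypothetical. `prev_emb` stands for the row W1[x_{k-1}] from the same W1 that appears in the Markov bias. The scores here are uncalibrated; in the pipeline described above, STS would be applied before taking the products.

```python
import math


def confidence(w, h_k, prev_emb):
    """c_k = sigmoid(w^T [h_k ; W1[x_{k-1}]]); prev_emb stands in for W1[x_{k-1}]."""
    features = list(h_k) + list(prev_emb)  # concatenation [h_k ; W1[x_{k-1}]]
    if len(w) != len(features):
        raise ValueError("w must have len(h_k) + len(prev_emb) entries")
    z = sum(wi * fi for wi, fi in zip(w, features))
    return 1.0 / (1.0 + math.exp(-z))


def prefix_survival(scores):
    """a_j = prod_{i<=j} c_i: probability that positions 1..j are all accepted."""
    out, a = [], 1.0
    for c in scores:
        a *= c
        out.append(a)
    return out
```

Because each c_k is conditional on the earlier positions surviving, the product a_j is the probability that the whole prefix through j is accepted. That prefix probability is the quantity the scheduler ranks.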

Algorithm Details

Acceptance Rule (Lossless)

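P(accept token x_k) = min(1, p^t_k(x_k) / p^d_k(x_k)), where p^t_k and p^d_k are the target and draft distributions at position k.
  • The first rejection, at position k, discards x_k and all later drafts x_{k+1}..x_γ
  • One bonus token is sampled at the rejection position from the residual distribution norm(max(0, p^t_k - p^d_k))
  • The emitted tokens follow the target distribution exactly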
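The rule above is standard speculative sampling. The sketch below implements it for one request with plain Python lists. `verify` and its arguments are hypothetical names, not the API of decoding/acceptance.py. The sketch also assumes the usual convention for the case where all γ drafts are accepted: the bonus token is then drawn from the target distribution at position γ+1. The notes do not spell that case out.

```python
import random


def verify(draft, p_d, p_t, rng):
    """draft: gamma token ids sampled from the drafter.
    p_d[k], p_t[k]: drafter / target distributions (lists over the vocab) at position k.
    p_t has gamma + 1 rows; the extra row is used for the bonus token when all drafts pass.
    Returns (accepted drafts + one bonus token, number of accepted drafts)."""
    out = []
    for k, x in enumerate(draft):
        # p_d[k][x] > 0 because x was sampled from p_d[k].
        ratio = p_t[k][x] / p_d[k][x]
        if rng.random() < min(1.0, ratio):
            out.append(x)
            continue
        # First rejection: drop x_k and every later draft, then resample from the residual.
        resid = [max(0.0, t - d) for t, d in zip(p_t[k], p_d[k])]
        bonus = rng.choices(range(len(resid)), weights=resid)[0]
        return out + [bonus], k
    # All gamma drafts accepted (assumed convention): bonus token from the target's next row.
    bonus = rng.choices(range(len(p_t[-1])), weights=p_t[-1])[0]
    return out + [bonus], len(draft)
```

Passing a seeded `random.Random` makes runs reproducible. When p_d equals p_t, every ratio is 1 and all drafts are accepted.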

Draft Distribution

p_k(v | x_0, x_{<k}) = exp(U_k(v) + B_k(x_0, x_{<k}, v)) / Σ_u exp(U_k(u) + B_k(x_0, x_{<k}, u))

where U_k are the base logits from the parallel backbone and B_k is the sequential transition bias. Because B_k depends on the drafted prefix, it must be computed left to right.
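Here is a direct transcription of the draft distribution for one position: a numerically stable softmax over U_k + B_k. `draft_distribution` is a hypothetical helper name. In the model this would be a batched tensor operation.

```python
import math


def draft_distribution(base_logits, bias):
    """p_k(v) = exp(U_k(v) + B_k(v)) / sum_u exp(U_k(u) + B_k(u))."""
    z = [u + b for u, b in zip(base_logits, bias)]
    m = max(z)  # subtracting the max avoids overflow and leaves the softmax unchanged
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]
```

With a zero bias this reduces to the purely parallel drafter's softmax over U_k.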

Training Objective (Eq. 12)

L = 0.1 · L_ce + 0.9 · L_tv + 1.0 · L_conf
  • L_ce: Cross-entropy for next-token prediction
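  • L_tv: Total variation distance TV(p_d, p_t) = 0.5 · ||p_d - p_t||_1, a direct proxy for the acceptance rate: under the acceptance rule, the expected acceptance probability at a position is 1 - TV(p_d, p_t)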
  • L_conf: Binary cross-entropy on confidence predictions
  • All terms are position-weighted by w_k = exp(-(k-1)/γ)
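The following is a per-sequence sketch of the Eq. 12 objective with the position weights w_k, written with plain lists for readability. Two details are assumptions not stated in the notes. First, the confidence targets are passed in as given, since the notes do not say how they are derived. Second, the weighted sum is normalized by Σ w_k. All function names are hypothetical.

```python
import math


def position_weights(gamma):
    """w_k = exp(-(k-1)/gamma) for k = 1..gamma."""
    return [math.exp(-(k - 1) / gamma) for k in range(1, gamma + 1)]


def tv_distance(p_d, p_t):
    """TV(p_d, p_t) = 0.5 * ||p_d - p_t||_1."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p_d, p_t))


def bce(c, target, eps=1e-7):
    """Binary cross-entropy with c clipped away from 0 and 1."""
    c = min(max(c, eps), 1.0 - eps)
    return -(target * math.log(c) + (1.0 - target) * math.log(1.0 - c))


def dspark_loss(p_d_rows, p_t_rows, labels, conf_scores, conf_targets):
    """One sequence of gamma positions; p_*_rows[k] are full-vocabulary distributions,
    labels[k] is the next-token id used by the cross-entropy term."""
    gamma = len(p_d_rows)
    w = position_weights(gamma)
    total = 0.0
    for k in range(gamma):
        l_ce = -math.log(max(p_d_rows[k][labels[k]], 1e-12))
        l_tv = tv_distance(p_d_rows[k], p_t_rows[k])
        l_conf = bce(conf_scores[k], conf_targets[k])
        total += w[k] * (0.1 * l_ce + 0.9 * l_tv + 1.0 * l_conf)
    # Assumption: normalize by the total weight so the scale does not grow with gamma.
    return total / sum(w)
```

The identity Σ_v min(p_d(v), p_t(v)) = 1 - TV(p_d, p_t) explains why L_tv tracks acceptance. For example, p_d = [0.5, 0.3, 0.2] and p_t = [0.2, 0.3, 0.5] give TV = 0.3 and an overlap of 0.7.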

STS Calibration

For each position k = 1..Ξ³:

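  • Compute the cumulative product of the already-calibrated scores for positions 1..k-1 and the k-th score under a candidate temperature
  • Choose the temperature T_k that minimizes the ECE of that cumulative product against the observed prefix survival
  • Apply T_k to the k-th position score before moving on to position k+1
  • Order-preserving: T_k > 0 only rescales each score's logit, so the relative ranking of scores at a position is unchanged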
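The sketch below shows the sequential calibration loop. It assumes that each score is the sigmoid of a confidence logit z and that temperature T_k acts as σ(z / T_k). The equal-width ECE bins, the bin count, the temperature grid and all function names are hypothetical choices, not taken from calibration/sts.py.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def ece(probs, labels, n_bins=10):
    """Expected calibration error with equal-width bins over [0, 1]."""
    n = len(probs)
    total = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        # The last bin is closed on the right so a probability of exactly 1.0 is counted.
        idx = [i for i, p in enumerate(probs)
               if lo <= p < hi or (b == n_bins - 1 and p == hi)]
        if not idx:
            continue
        mean_conf = sum(probs[i] for i in idx) / len(idx)
        mean_acc = sum(labels[i] for i in idx) / len(idx)
        total += abs(mean_conf - mean_acc) * len(idx) / n
    return total


def sts_calibrate(logits, survived, grid):
    """logits[n][k]: raw confidence logit for sample n at position k.
    survived[n][k]: 1 if positions 1..k of sample n were all accepted, else 0.
    Returns one temperature per position, fitted left to right."""
    n, gamma = len(logits), len(logits[0])
    prefix = [1.0] * n  # calibrated cumulative product up to position k-1
    temps = []
    for k in range(gamma):
        labels = [survived[i][k] for i in range(n)]
        best_t, best_e = None, float("inf")
        for t in grid:
            cum = [prefix[i] * sigmoid(logits[i][k] / t) for i in range(n)]
            e = ece(cum, labels)
            if e < best_e:
                best_t, best_e = t, e
        temps.append(best_t)
        # Freeze T_k and extend the calibrated prefix before fitting position k+1.
        prefix = [prefix[i] * sigmoid(logits[i][k] / best_t) for i in range(n)]
    return temps
```

If the grid contains 1.0, the fitted temperature at each position can only lower that position's ECE, given the temperatures already chosen for earlier positions.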

Algorithm 1: Hardware-Aware Prefix Scheduler

1. Compute a_{r,j} = ∏_{i≤j} c_{r,i} for each request r and position j
2. Create candidate set E = {(r,j) | a_{r,j} > 0}
3. Sort E descending by a_{r,j}
4. Greedily add candidates:
   - Update ℓ_r = j, B += 1, τ += a_{r,j}
   - Compute Θ = Ο„ * SPS(B)
   - If Θ > best, update best lengths; else break
5. Return per-request verification lengths
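Algorithm 1 can be sketched in plain Python as follows. This sketch rests on several assumptions:
  • Positions are 0-based, so ℓ_r = j + 1.
  • The search starts from τ = 0 and B = 0. The notes do not say whether anchors or bonus tokens are counted.
  • `sps` is any callable standing in for SPS(B). It is not UraionSpec's synthetic profile.
  • Ties in a_{r,j} are broken by position, so each request's admitted positions stay a contiguous prefix. Since every c ≤ 1, a_{r,j} never increases along j.

```python
def prefix_scheduler(conf, sps):
    """conf[r][j]: calibrated conditional acceptance c_{r,j} (0-based position j).
    sps(B): throughput when B draft tokens are verified in one step (caller-supplied).
    Returns per-request verification lengths ell[r]."""
    # Steps 1-2: cumulative prefix survival a_{r,j}, keeping only positive candidates.
    cands = []
    for r, row in enumerate(conf):
        a = 1.0
        for j, c in enumerate(row):
            a *= c
            if a > 0:
                cands.append((a, r, j))
    # Step 3: sort by a descending; ties go to the earlier position.
    cands.sort(key=lambda t: (-t[0], t[2]))
    # Step 4: greedy admission with early stopping on Theta = tau * SPS(B).
    ell = [0] * len(conf)
    best_ell = list(ell)
    tau, B, best = 0.0, 0, 0.0
    for a, r, j in cands:
        ell[r] = j + 1  # verify positions 0..j of request r
        B += 1
        tau += a
        theta = tau * sps(B)
        if theta > best:
            best, best_ell = theta, list(ell)
        else:
            break
    # Step 5: per-request verification lengths at the best Theta seen.
    return best_ell
```

For example, take conf = [[0.9, 0.8, 0.5], [0.6, 0.3]] and a toy profile sps(B) = 1 / (1 + 0.5·B). Θ rises over the first three admissions (0.6, 0.81, 0.888) and drops at the fourth (0.86). The function therefore returns [2, 1]. With a flat profile, every positive candidate is admitted.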

What's Implemented Now

| Component | Status | Location |
|---|---|---|
| Markov head (vanilla, gated) | ✅ | models/markov_head.py |
| RNN head | ✅ | models/rnn_head.py |
| Confidence head | ✅ | models/confidence_head.py |
| Acceptance rule | ✅ | decoding/acceptance.py |
| Hardware-aware scheduler | ✅ | decoding/scheduler.py |
| Static scheduler (fallback) | ✅ | decoding/scheduler.py |
| Speculative decoding loop | ✅ | decoding/speculative.py |
| Training dataset | ✅ | training/dataset.py |
| Loss functions (CE + TV + Conf) | ✅ | training/losses.py |
| Training loop | ✅ | training/train_drafter.py |
| Target cache generation | ✅ | training/cache_targets.py |
| STS calibration | ✅ | calibration/sts.py |
| Acceptance evaluation | ✅ | evaluation/eval_acceptance.py |
| Latency benchmarking | ✅ | evaluation/benchmark_latency.py |
| Unit tests (55) | ✅ | tests/ |
| Smoke train script | ✅ | scripts/smoke_train.py |
| Smoke eval script | ✅ | scripts/smoke_eval.py |
| Benchmark script | ✅ | scripts/run_benchmark.py |

Deviations from Paper

  1. Parallel backbone: Uses nn.TransformerEncoder instead of the full DFlash-style backbone with target model KV injection.
  2. Training efficiency: No multi-GPU training and no 38 TB precomputed target cache; target logits are computed on the fly.
  3. SPS profiling: Uses a synthetic throughput profile instead of real engine profiling.
  4. Mask token: Uses token 0 as a placeholder; the paper uses learned mask embeddings.
  5. Anchor modification: The paper treats the anchor as the first prediction position (γ inputs → γ outputs). We follow the original DFlash style (γ inputs → γ outputs), in which the anchor is part of the input.

Future Work

  1. Implement full DFlash backbone with target model KV injection
  2. Multi-GPU distributed training (DeepSpeed/FSDP)
  3. Real engine throughput profiling
  4. Tree-based verification for autoregressive drafters
  5. Integration with vLLM or similar serving frameworks
  6. Qwen3-specific model classes with proper configuration