Florian valade
commited on
Commit
·
33efa44
1
Parent(s):
432ea6e
Track metrics during streaming, remove redundant generation re-runs
Browse files- Add StreamingResult dataclass to hold metrics from streaming generation
- Update StreamEvent with 'complete' event type and result field
- Extract _format_and_encode_prompt() helper to reduce code duplication
- Update generate_streaming() and generate_full_model_streaming() to
yield final 'complete' event with accumulated metrics
- Refactor app.py to use streaming results instead of re-running generation
- Remove unused code: set_full_cache(), create_cache_from_model(),
to_json() methods, get_threshold()
- Add jagged_cache.py and test files for KV cache operations
This fixes the bug where output would change after streaming completed
and metrics would appear after a delay.
- app.py +268 -118
- src/inference.py +88 -68
- src/jagged_cache.py +247 -0
- src/model_adapters.py +1 -5
- src/model_config.py +1 -14
- tests/__init__.py +1 -0
- tests/benchmark_kv_cache.py +358 -0
- tests/run_benchmark.py +239 -0
- tests/test_cache_integration.py +195 -0
- tests/test_cache_operations.py +495 -0
app.py
CHANGED
|
@@ -4,10 +4,12 @@ Showcases early exit inference with color-coded tokens showing which head genera
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import gradio as gr
|
|
|
|
| 7 |
from pathlib import Path
|
|
|
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
| 10 |
-
from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent
|
| 11 |
|
| 12 |
# Available models configuration
|
| 13 |
AVAILABLE_MODELS = {
|
|
@@ -33,6 +35,10 @@ HEAD_COLORS = [
|
|
| 33 |
]
|
| 34 |
FULL_MODEL_COLOR = "#95D5B2" # Light green - Full model
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Global decoder cache
|
| 37 |
_decoder_cache = {}
|
| 38 |
|
|
@@ -103,7 +109,7 @@ def tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str:
|
|
| 103 |
|
| 104 |
html_parts.append(
|
| 105 |
f'<span style="background-color: {color}; padding: 2px 4px; '
|
| 106 |
-
f'border-radius: 3px; margin: 1px; display: inline-block;" title="{title}">{text}</span>'
|
| 107 |
)
|
| 108 |
|
| 109 |
# Wrap in container with word-wrap to prevent overflow
|
|
@@ -121,8 +127,8 @@ def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> s
|
|
| 121 |
layer = head_layers[token.exit_head]
|
| 122 |
title = f"PENDING - Head {token.exit_head} (Layer {layer})"
|
| 123 |
else:
|
| 124 |
-
color =
|
| 125 |
-
title = "PENDING -
|
| 126 |
|
| 127 |
text = (
|
| 128 |
token.token_text.replace("&", "&")
|
|
@@ -134,7 +140,8 @@ def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> s
|
|
| 134 |
html_parts.append(
|
| 135 |
f'<span style="background-color: {color}; padding: 2px 4px; '
|
| 136 |
f"border-radius: 3px; margin: 1px; display: inline-block; "
|
| 137 |
-
f
|
|
|
|
| 138 |
)
|
| 139 |
|
| 140 |
return "".join(html_parts)
|
|
@@ -156,17 +163,78 @@ def create_legend(head_layers: list[int]) -> str:
|
|
| 156 |
return " ".join(legend_items)
|
| 157 |
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
def generate(
|
|
@@ -178,13 +246,29 @@ def generate(
|
|
| 178 |
compare_mode: bool,
|
| 179 |
):
|
| 180 |
"""Main generation function for Gradio interface with streaming."""
|
|
|
|
| 181 |
try:
|
| 182 |
decoder = get_decoder(model_key)
|
| 183 |
except Exception as e:
|
| 184 |
error_msg = f"<p style='color: red;'>Error loading model: {e}</p>"
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
return
|
| 187 |
|
|
|
|
| 188 |
head_layers = decoder.model_config.head_layer_indices
|
| 189 |
legend = create_legend(head_layers)
|
| 190 |
|
|
@@ -199,12 +283,21 @@ def generate(
|
|
| 199 |
# Compare mode with streaming for early exit
|
| 200 |
# First, stream the early exit generation
|
| 201 |
final_ee_tokens = []
|
|
|
|
|
|
|
| 202 |
for event in decoder.generate_streaming(
|
| 203 |
prompt=prompt,
|
| 204 |
max_tokens=int(max_tokens),
|
| 205 |
accuracy_level=closest_level,
|
| 206 |
use_chat_template=True,
|
| 207 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
validated_html = ""
|
| 209 |
if event.tokens:
|
| 210 |
validated_html = tokens_to_html(event.tokens, head_layers)
|
|
@@ -219,91 +312,97 @@ def generate(
|
|
| 219 |
|
| 220 |
combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
|
| 221 |
|
| 222 |
-
status =
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
yield (
|
| 229 |
combined_html,
|
| 230 |
-
"<p style='color:
|
| 231 |
status,
|
|
|
|
| 232 |
legend,
|
| 233 |
)
|
| 234 |
-
final_ee_tokens = event.tokens
|
| 235 |
|
| 236 |
# Now stream full model
|
| 237 |
final_full_tokens = []
|
|
|
|
|
|
|
| 238 |
for event in decoder.generate_full_model_streaming(
|
| 239 |
prompt=prompt,
|
| 240 |
max_tokens=int(max_tokens),
|
| 241 |
use_chat_template=True,
|
| 242 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
html_full = tokens_to_html(event.tokens, head_layers)
|
| 244 |
-
status =
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
yield (
|
| 250 |
tokens_to_html(final_ee_tokens, head_layers),
|
| 251 |
html_full,
|
| 252 |
status,
|
|
|
|
| 253 |
legend,
|
| 254 |
)
|
| 255 |
-
final_full_tokens = event.tokens
|
| 256 |
-
|
| 257 |
-
# Final stats
|
| 258 |
-
result_ee = decoder.generate(
|
| 259 |
-
prompt=prompt,
|
| 260 |
-
max_tokens=int(max_tokens),
|
| 261 |
-
use_early_exit=True,
|
| 262 |
-
accuracy_level=closest_level,
|
| 263 |
-
use_chat_template=True,
|
| 264 |
-
)
|
| 265 |
-
result_full = decoder.generate(
|
| 266 |
-
prompt=prompt,
|
| 267 |
-
max_tokens=int(max_tokens),
|
| 268 |
-
use_early_exit=False,
|
| 269 |
-
use_chat_template=True,
|
| 270 |
-
)
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
)
|
| 280 |
-
stats = f"""
|
| 281 |
-
<div style="padding: 15px; background: #e8f5e9; border-radius: 8px;">
|
| 282 |
-
<h3 style="margin: 0 0 10px 0;">🚀 Speedup: {speedup:.2f}x</h3>
|
| 283 |
-
<div style="display: flex; gap: 20px;">
|
| 284 |
-
<div style="flex: 1; padding: 10px; background: white; border-radius: 8px;">
|
| 285 |
-
<h4>Early Exit</h4>
|
| 286 |
-
<p><b>Time:</b> {result_ee.total_time:.2f}s | <b>Tokens/sec:</b> {result_ee.tokens_per_second:.2f}</p>
|
| 287 |
-
<p><b>Avg Exit Layer:</b> {result_ee.avg_exit_layer:.1f}</p>
|
| 288 |
-
</div>
|
| 289 |
-
<div style="flex: 1; padding: 10px; background: white; border-radius: 8px;">
|
| 290 |
-
<h4>Full Model</h4>
|
| 291 |
-
<p><b>Time:</b> {result_full.total_time:.2f}s | <b>Tokens/sec:</b> {result_full.tokens_per_second:.2f}</p>
|
| 292 |
-
<p><b>Avg Exit Layer:</b> {result_full.avg_exit_layer:.1f}</p>
|
| 293 |
-
</div>
|
| 294 |
-
</div>
|
| 295 |
-
</div>
|
| 296 |
-
"""
|
| 297 |
-
yield (html_ee, html_full, stats, legend)
|
| 298 |
|
| 299 |
elif use_early_exit:
|
| 300 |
# STREAMING mode for early exit - show draft/verify process
|
|
|
|
|
|
|
|
|
|
| 301 |
for event in decoder.generate_streaming(
|
| 302 |
prompt=prompt,
|
| 303 |
max_tokens=int(max_tokens),
|
| 304 |
accuracy_level=closest_level,
|
| 305 |
use_chat_template=True,
|
| 306 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
# Build HTML showing validated + drafted tokens
|
| 308 |
validated_html = ""
|
| 309 |
if event.tokens:
|
|
@@ -322,63 +421,86 @@ def generate(
|
|
| 322 |
combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
|
| 323 |
|
| 324 |
# Status message
|
| 325 |
-
status =
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
-
# Final
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
| 341 |
)
|
| 342 |
-
html = tokens_to_html(result.tokens, head_layers)
|
| 343 |
-
stats = f"""
|
| 344 |
-
<div style="padding: 15px; background: #f5f5f5; border-radius: 8px;">
|
| 345 |
-
<h4 style="margin: 0 0 10px 0;">Early Exit Statistics (Final)</h4>
|
| 346 |
-
<p><b>Tokens:</b> {len(result.tokens)} | <b>Tokens/sec:</b> {result.tokens_per_second:.2f} | <b>Avg Exit Layer:</b> {result.avg_exit_layer:.1f}</p>
|
| 347 |
-
<p><b>Exit Distribution:</b> {result.exit_distribution}</p>
|
| 348 |
-
</div>
|
| 349 |
-
"""
|
| 350 |
-
yield (html, "", stats, legend)
|
| 351 |
|
| 352 |
else:
|
| 353 |
# Full model mode (streaming)
|
|
|
|
|
|
|
|
|
|
| 354 |
for event in decoder.generate_full_model_streaming(
|
| 355 |
prompt=prompt,
|
| 356 |
max_tokens=int(max_tokens),
|
| 357 |
use_chat_template=True,
|
| 358 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
html = tokens_to_html(event.tokens, head_layers)
|
| 360 |
-
status =
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
)
|
| 374 |
-
html = tokens_to_html(result.tokens, head_layers)
|
| 375 |
-
stats = f"""
|
| 376 |
-
<div style="padding: 15px; background: #f5f5f5; border-radius: 8px;">
|
| 377 |
-
<h4 style="margin: 0 0 10px 0;">Full Model Statistics</h4>
|
| 378 |
-
<p><b>Tokens:</b> {len(result.tokens)} | <b>Time:</b> {result.total_time:.2f}s | <b>Tokens/sec:</b> {result.tokens_per_second:.2f}</p>
|
| 379 |
-
</div>
|
| 380 |
-
"""
|
| 381 |
-
yield (html, "", stats, legend)
|
| 382 |
|
| 383 |
|
| 384 |
def build_demo():
|
|
@@ -444,8 +566,22 @@ def build_demo():
|
|
| 444 |
gr.Markdown("### Full Model (Comparison)")
|
| 445 |
output_full = gr.HTML()
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
|
| 450 |
def update_visibility(compare):
|
| 451 |
return gr.update(visible=compare)
|
|
@@ -466,7 +602,21 @@ def build_demo():
|
|
| 466 |
max_tokens,
|
| 467 |
compare_mode,
|
| 468 |
],
|
| 469 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
)
|
| 471 |
|
| 472 |
return demo
|
|
@@ -474,4 +624,4 @@ def build_demo():
|
|
| 474 |
|
| 475 |
if __name__ == "__main__":
|
| 476 |
demo = build_demo()
|
| 477 |
-
demo.launch(share=False)
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
from pathlib import Path
|
| 9 |
+
import time
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
|
| 12 |
+
from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent, StreamingResult
|
| 13 |
|
| 14 |
# Available models configuration
|
| 15 |
AVAILABLE_MODELS = {
|
|
|
|
| 35 |
]
|
| 36 |
FULL_MODEL_COLOR = "#95D5B2" # Light green - Full model
|
| 37 |
|
| 38 |
+
PENDING_TOKEN_BORDER = "var(--border-color-primary)"
|
| 39 |
+
PENDING_TOKEN_TEXT = "var(--body-text-color)"
|
| 40 |
+
DRAFTED_FALLBACK_COLOR = "var(--neutral-200)"
|
| 41 |
+
|
| 42 |
# Global decoder cache
|
| 43 |
_decoder_cache = {}
|
| 44 |
|
|
|
|
| 109 |
|
| 110 |
html_parts.append(
|
| 111 |
f'<span style="background-color: {color}; padding: 2px 4px; '
|
| 112 |
+
f'border-radius: 3px; margin: 1px; display: inline-block; color: #111827;" title="{title}">{text}</span>'
|
| 113 |
)
|
| 114 |
|
| 115 |
# Wrap in container with word-wrap to prevent overflow
|
|
|
|
| 127 |
layer = head_layers[token.exit_head]
|
| 128 |
title = f"PENDING - Head {token.exit_head} (Layer {layer})"
|
| 129 |
else:
|
| 130 |
+
color = DRAFTED_FALLBACK_COLOR
|
| 131 |
+
title = "PENDING - Unassigned"
|
| 132 |
|
| 133 |
text = (
|
| 134 |
token.token_text.replace("&", "&")
|
|
|
|
| 140 |
html_parts.append(
|
| 141 |
f'<span style="background-color: {color}; padding: 2px 4px; '
|
| 142 |
f"border-radius: 3px; margin: 1px; display: inline-block; "
|
| 143 |
+
f"border: 2px dashed {PENDING_TOKEN_BORDER}; color: {PENDING_TOKEN_TEXT}; "
|
| 144 |
+
f'opacity: 0.75;" title="{title}">{text}</span>'
|
| 145 |
)
|
| 146 |
|
| 147 |
return "".join(html_parts)
|
|
|
|
| 163 |
return " ".join(legend_items)
|
| 164 |
|
| 165 |
|
| 166 |
+
|
| 167 |
+
@dataclass
|
| 168 |
+
class StatsPayload:
|
| 169 |
+
generated_at: float
|
| 170 |
+
speedup_text: str
|
| 171 |
+
ee_time: str | None
|
| 172 |
+
ee_tps: str | None
|
| 173 |
+
ee_avg: str | None
|
| 174 |
+
full_time: str | None
|
| 175 |
+
full_tps: str | None
|
| 176 |
+
full_avg: str | None
|
| 177 |
+
show_ee: bool
|
| 178 |
+
show_full: bool
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def build_stats_outputs(
|
| 182 |
+
result_ee,
|
| 183 |
+
result_full,
|
| 184 |
+
use_early_exit: bool,
|
| 185 |
+
compare_mode: bool,
|
| 186 |
+
generated_at: float | None = None,
|
| 187 |
+
):
|
| 188 |
+
speedup_text = ""
|
| 189 |
+
if result_ee and result_full and result_full.tokens_per_second > 0:
|
| 190 |
+
speedup = result_ee.tokens_per_second / result_full.tokens_per_second
|
| 191 |
+
speedup_text = f"**Speedup:** {speedup:.2f}x"
|
| 192 |
+
elif result_ee:
|
| 193 |
+
speedup_text = "**Speedup:** N/A (full model not run)"
|
| 194 |
+
elif result_full:
|
| 195 |
+
speedup_text = "**Speedup:** N/A (early exit disabled)"
|
| 196 |
+
|
| 197 |
+
if not speedup_text:
|
| 198 |
+
speedup_text = "**Speedup:** N/A"
|
| 199 |
+
|
| 200 |
+
ee_time = f"{result_ee.total_time:.2f}" if result_ee else None
|
| 201 |
+
ee_tps = f"{result_ee.tokens_per_second:.2f}" if result_ee else None
|
| 202 |
+
ee_avg = f"{result_ee.avg_exit_layer:.1f}" if result_ee else None
|
| 203 |
+
|
| 204 |
+
full_time = f"{result_full.total_time:.2f}" if result_full else None
|
| 205 |
+
full_tps = f"{result_full.tokens_per_second:.2f}" if result_full else None
|
| 206 |
+
full_avg = f"{result_full.avg_exit_layer:.1f}" if result_full else None
|
| 207 |
+
|
| 208 |
+
show_ee = compare_mode or use_early_exit
|
| 209 |
+
show_full = compare_mode or not use_early_exit
|
| 210 |
+
|
| 211 |
+
return StatsPayload(
|
| 212 |
+
generated_at=generated_at if generated_at is not None else time.time(),
|
| 213 |
+
speedup_text=speedup_text,
|
| 214 |
+
ee_time=ee_time,
|
| 215 |
+
ee_tps=ee_tps,
|
| 216 |
+
ee_avg=ee_avg,
|
| 217 |
+
full_time=full_time,
|
| 218 |
+
full_tps=full_tps,
|
| 219 |
+
full_avg=full_avg,
|
| 220 |
+
show_ee=show_ee,
|
| 221 |
+
show_full=show_full,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def stats_payload_to_outputs(payload: StatsPayload):
|
| 226 |
+
return (
|
| 227 |
+
payload.speedup_text,
|
| 228 |
+
payload.ee_time,
|
| 229 |
+
payload.ee_tps,
|
| 230 |
+
payload.ee_avg,
|
| 231 |
+
payload.full_time,
|
| 232 |
+
payload.full_tps,
|
| 233 |
+
payload.full_avg,
|
| 234 |
+
gr.update(visible=payload.show_ee),
|
| 235 |
+
gr.update(visible=payload.show_full),
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
|
| 239 |
|
| 240 |
def generate(
|
|
|
|
| 246 |
compare_mode: bool,
|
| 247 |
):
|
| 248 |
"""Main generation function for Gradio interface with streaming."""
|
| 249 |
+
initial_stats_timestamp = time.time()
|
| 250 |
try:
|
| 251 |
decoder = get_decoder(model_key)
|
| 252 |
except Exception as e:
|
| 253 |
error_msg = f"<p style='color: red;'>Error loading model: {e}</p>"
|
| 254 |
+
status_msg = f"**Error loading model:** {e}"
|
| 255 |
+
stats_payload = build_stats_outputs(
|
| 256 |
+
None,
|
| 257 |
+
None,
|
| 258 |
+
use_early_exit,
|
| 259 |
+
compare_mode,
|
| 260 |
+
generated_at=initial_stats_timestamp,
|
| 261 |
+
)
|
| 262 |
+
yield (
|
| 263 |
+
error_msg,
|
| 264 |
+
"",
|
| 265 |
+
status_msg,
|
| 266 |
+
*stats_payload_to_outputs(stats_payload),
|
| 267 |
+
"",
|
| 268 |
+
)
|
| 269 |
return
|
| 270 |
|
| 271 |
+
|
| 272 |
head_layers = decoder.model_config.head_layer_indices
|
| 273 |
legend = create_legend(head_layers)
|
| 274 |
|
|
|
|
| 283 |
# Compare mode with streaming for early exit
|
| 284 |
# First, stream the early exit generation
|
| 285 |
final_ee_tokens = []
|
| 286 |
+
ee_streaming_result = None
|
| 287 |
+
|
| 288 |
for event in decoder.generate_streaming(
|
| 289 |
prompt=prompt,
|
| 290 |
max_tokens=int(max_tokens),
|
| 291 |
accuracy_level=closest_level,
|
| 292 |
use_chat_template=True,
|
| 293 |
):
|
| 294 |
+
# Handle "complete" event - extract result and break
|
| 295 |
+
if event.event_type == "complete":
|
| 296 |
+
ee_streaming_result = event.result
|
| 297 |
+
final_ee_tokens = event.tokens
|
| 298 |
+
break
|
| 299 |
+
|
| 300 |
+
final_ee_tokens = event.tokens
|
| 301 |
validated_html = ""
|
| 302 |
if event.tokens:
|
| 303 |
validated_html = tokens_to_html(event.tokens, head_layers)
|
|
|
|
| 312 |
|
| 313 |
combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
|
| 314 |
|
| 315 |
+
status = (
|
| 316 |
+
"**Early Exit:** {message} \n"
|
| 317 |
+
"**Full Model:** Waiting..."
|
| 318 |
+
).format(
|
| 319 |
+
message=event.message,
|
| 320 |
+
)
|
| 321 |
|
| 322 |
+
stats_payload = build_stats_outputs(
|
| 323 |
+
None,
|
| 324 |
+
None,
|
| 325 |
+
use_early_exit,
|
| 326 |
+
compare_mode,
|
| 327 |
+
generated_at=initial_stats_timestamp,
|
| 328 |
+
)
|
| 329 |
yield (
|
| 330 |
combined_html,
|
| 331 |
+
"<p style='color: var(--body-text-color-subdued);'>Waiting for early exit to complete...</p>",
|
| 332 |
status,
|
| 333 |
+
*stats_payload_to_outputs(stats_payload),
|
| 334 |
legend,
|
| 335 |
)
|
|
|
|
| 336 |
|
| 337 |
# Now stream full model
|
| 338 |
final_full_tokens = []
|
| 339 |
+
full_streaming_result = None
|
| 340 |
+
|
| 341 |
for event in decoder.generate_full_model_streaming(
|
| 342 |
prompt=prompt,
|
| 343 |
max_tokens=int(max_tokens),
|
| 344 |
use_chat_template=True,
|
| 345 |
):
|
| 346 |
+
# Handle "complete" event - extract result and break
|
| 347 |
+
if event.event_type == "complete":
|
| 348 |
+
full_streaming_result = event.result
|
| 349 |
+
final_full_tokens = event.tokens
|
| 350 |
+
break
|
| 351 |
+
|
| 352 |
+
final_full_tokens = event.tokens
|
| 353 |
html_full = tokens_to_html(event.tokens, head_layers)
|
| 354 |
+
status = (
|
| 355 |
+
"**Full Model:** {message}"
|
| 356 |
+
).format(
|
| 357 |
+
message=event.message,
|
| 358 |
+
)
|
| 359 |
+
stats_payload = build_stats_outputs(
|
| 360 |
+
None,
|
| 361 |
+
None,
|
| 362 |
+
use_early_exit,
|
| 363 |
+
compare_mode,
|
| 364 |
+
generated_at=initial_stats_timestamp,
|
| 365 |
+
)
|
| 366 |
yield (
|
| 367 |
tokens_to_html(final_ee_tokens, head_layers),
|
| 368 |
html_full,
|
| 369 |
status,
|
| 370 |
+
*stats_payload_to_outputs(stats_payload),
|
| 371 |
legend,
|
| 372 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
+
# Final output with metrics from streaming results (no re-run needed)
|
| 375 |
+
html_ee = tokens_to_html(final_ee_tokens, head_layers)
|
| 376 |
+
html_full = tokens_to_html(final_full_tokens, head_layers)
|
| 377 |
+
|
| 378 |
+
stats_payload = build_stats_outputs(ee_streaming_result, full_streaming_result, use_early_exit, compare_mode)
|
| 379 |
+
yield (
|
| 380 |
+
html_ee,
|
| 381 |
+
html_full,
|
| 382 |
+
"",
|
| 383 |
+
*stats_payload_to_outputs(stats_payload),
|
| 384 |
+
legend,
|
| 385 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
elif use_early_exit:
|
| 388 |
# STREAMING mode for early exit - show draft/verify process
|
| 389 |
+
streaming_result = None
|
| 390 |
+
final_tokens = []
|
| 391 |
+
|
| 392 |
for event in decoder.generate_streaming(
|
| 393 |
prompt=prompt,
|
| 394 |
max_tokens=int(max_tokens),
|
| 395 |
accuracy_level=closest_level,
|
| 396 |
use_chat_template=True,
|
| 397 |
):
|
| 398 |
+
# Handle "complete" event - extract result and break
|
| 399 |
+
if event.event_type == "complete":
|
| 400 |
+
streaming_result = event.result
|
| 401 |
+
final_tokens = event.tokens
|
| 402 |
+
break
|
| 403 |
+
|
| 404 |
+
final_tokens = event.tokens
|
| 405 |
+
|
| 406 |
# Build HTML showing validated + drafted tokens
|
| 407 |
validated_html = ""
|
| 408 |
if event.tokens:
|
|
|
|
| 421 |
combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
|
| 422 |
|
| 423 |
# Status message
|
| 424 |
+
status = (
|
| 425 |
+
"**Status:** {message}"
|
| 426 |
+
).format(
|
| 427 |
+
message=event.message,
|
| 428 |
+
)
|
| 429 |
|
| 430 |
+
stats_payload = build_stats_outputs(
|
| 431 |
+
None,
|
| 432 |
+
None,
|
| 433 |
+
use_early_exit,
|
| 434 |
+
compare_mode,
|
| 435 |
+
generated_at=initial_stats_timestamp,
|
| 436 |
+
)
|
| 437 |
+
yield (
|
| 438 |
+
combined_html,
|
| 439 |
+
"",
|
| 440 |
+
status,
|
| 441 |
+
*stats_payload_to_outputs(stats_payload),
|
| 442 |
+
legend,
|
| 443 |
+
)
|
| 444 |
|
| 445 |
+
# Final output with metrics from streaming result (no re-run needed)
|
| 446 |
+
html = tokens_to_html(final_tokens, head_layers)
|
| 447 |
+
stats_payload = build_stats_outputs(streaming_result, None, use_early_exit, compare_mode)
|
| 448 |
+
yield (
|
| 449 |
+
html,
|
| 450 |
+
"",
|
| 451 |
+
"",
|
| 452 |
+
*stats_payload_to_outputs(stats_payload),
|
| 453 |
+
legend,
|
| 454 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
else:
|
| 457 |
# Full model mode (streaming)
|
| 458 |
+
streaming_result = None
|
| 459 |
+
final_tokens = []
|
| 460 |
+
|
| 461 |
for event in decoder.generate_full_model_streaming(
|
| 462 |
prompt=prompt,
|
| 463 |
max_tokens=int(max_tokens),
|
| 464 |
use_chat_template=True,
|
| 465 |
):
|
| 466 |
+
# Handle "complete" event - extract result and break
|
| 467 |
+
if event.event_type == "complete":
|
| 468 |
+
streaming_result = event.result
|
| 469 |
+
final_tokens = event.tokens
|
| 470 |
+
break
|
| 471 |
+
|
| 472 |
+
final_tokens = event.tokens
|
| 473 |
html = tokens_to_html(event.tokens, head_layers)
|
| 474 |
+
status = (
|
| 475 |
+
"**Full Model:** {message}"
|
| 476 |
+
).format(
|
| 477 |
+
message=event.message,
|
| 478 |
+
)
|
| 479 |
+
stats_payload = build_stats_outputs(
|
| 480 |
+
None,
|
| 481 |
+
None,
|
| 482 |
+
use_early_exit,
|
| 483 |
+
compare_mode,
|
| 484 |
+
generated_at=initial_stats_timestamp,
|
| 485 |
+
)
|
| 486 |
+
yield (
|
| 487 |
+
html,
|
| 488 |
+
"",
|
| 489 |
+
status,
|
| 490 |
+
*stats_payload_to_outputs(stats_payload),
|
| 491 |
+
legend,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# Final output with metrics from streaming result (no re-run needed)
|
| 495 |
+
html = tokens_to_html(final_tokens, head_layers)
|
| 496 |
+
stats_payload = build_stats_outputs(None, streaming_result, use_early_exit, compare_mode)
|
| 497 |
+
yield (
|
| 498 |
+
html,
|
| 499 |
+
"",
|
| 500 |
+
"",
|
| 501 |
+
*stats_payload_to_outputs(stats_payload),
|
| 502 |
+
legend,
|
| 503 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
|
| 505 |
|
| 506 |
def build_demo():
|
|
|
|
| 566 |
gr.Markdown("### Full Model (Comparison)")
|
| 567 |
output_full = gr.HTML()
|
| 568 |
|
| 569 |
+
status_html = gr.Markdown()
|
| 570 |
+
|
| 571 |
+
with gr.Group():
|
| 572 |
+
gr.Markdown("### Speedup Recap")
|
| 573 |
+
speedup_md = gr.Markdown()
|
| 574 |
+
with gr.Row():
|
| 575 |
+
with gr.Column(visible=True) as ee_stats_col:
|
| 576 |
+
gr.Markdown("#### Early Exit")
|
| 577 |
+
ee_time = gr.Label(label="Time (s)")
|
| 578 |
+
ee_tps = gr.Label(label="Tokens/sec")
|
| 579 |
+
ee_avg = gr.Label(label="Avg Exit Layer")
|
| 580 |
+
with gr.Column(visible=False) as full_stats_col:
|
| 581 |
+
gr.Markdown("#### Full Model")
|
| 582 |
+
full_time = gr.Label(label="Time (s)")
|
| 583 |
+
full_tps = gr.Label(label="Tokens/sec")
|
| 584 |
+
full_avg = gr.Label(label="Avg Exit Layer")
|
| 585 |
|
| 586 |
def update_visibility(compare):
|
| 587 |
return gr.update(visible=compare)
|
|
|
|
| 602 |
max_tokens,
|
| 603 |
compare_mode,
|
| 604 |
],
|
| 605 |
+
outputs=[
|
| 606 |
+
output_ee,
|
| 607 |
+
output_full,
|
| 608 |
+
status_html,
|
| 609 |
+
speedup_md,
|
| 610 |
+
ee_time,
|
| 611 |
+
ee_tps,
|
| 612 |
+
ee_avg,
|
| 613 |
+
full_time,
|
| 614 |
+
full_tps,
|
| 615 |
+
full_avg,
|
| 616 |
+
ee_stats_col,
|
| 617 |
+
full_stats_col,
|
| 618 |
+
legend_html,
|
| 619 |
+
],
|
| 620 |
)
|
| 621 |
|
| 622 |
return demo
|
|
|
|
| 624 |
|
| 625 |
if __name__ == "__main__":
|
| 626 |
demo = build_demo()
|
| 627 |
+
demo.launch(share=False, debug=True)
|
src/inference.py
CHANGED
|
@@ -53,14 +53,47 @@ class TokenInfo:
|
|
| 53 |
uncertainty: float
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
@dataclass
|
| 57 |
class StreamEvent:
|
| 58 |
"""Event for streaming generation updates."""
|
| 59 |
|
| 60 |
-
event_type: str # "draft", "verify_start", "accept", "reject", "full_model"
|
| 61 |
tokens: List[TokenInfo] # All tokens so far (validated)
|
| 62 |
drafted_tokens: List[TokenInfo] # Currently drafted (pending verification)
|
| 63 |
message: str # Human-readable status
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
@dataclass
|
|
@@ -100,19 +133,8 @@ class DSSDecoder:
|
|
| 100 |
self.device = device
|
| 101 |
self.uncertainty_fn = compute_entropy
|
| 102 |
|
| 103 |
-
def
|
| 104 |
-
|
| 105 |
-
prompt: str,
|
| 106 |
-
max_tokens: int = 100,
|
| 107 |
-
use_early_exit: bool = True,
|
| 108 |
-
accuracy_level: float = 0.75,
|
| 109 |
-
use_chat_template: bool = True,
|
| 110 |
-
) -> GenerationResult:
|
| 111 |
-
"""
|
| 112 |
-
Generate text with optional early exit.
|
| 113 |
-
Returns detailed token-level information for visualization.
|
| 114 |
-
"""
|
| 115 |
-
# Format prompt - check if tokenizer has a chat template set
|
| 116 |
if (
|
| 117 |
use_chat_template
|
| 118 |
and hasattr(self.tokenizer, "chat_template")
|
|
@@ -123,18 +145,26 @@ class DSSDecoder:
|
|
| 123 |
formatted = self.tokenizer.apply_chat_template(
|
| 124 |
messages, add_generation_prompt=True, tokenize=False
|
| 125 |
)
|
| 126 |
-
|
| 127 |
self.device
|
| 128 |
)
|
| 129 |
except Exception:
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# Get thresholds
|
| 140 |
thresholds = {}
|
|
@@ -186,29 +216,9 @@ class DSSDecoder:
|
|
| 186 |
"""
|
| 187 |
Generate with streaming - yields events showing draft/verify process.
|
| 188 |
Each event shows current validated tokens and pending drafted tokens.
|
|
|
|
| 189 |
"""
|
| 190 |
-
|
| 191 |
-
if (
|
| 192 |
-
use_chat_template
|
| 193 |
-
and hasattr(self.tokenizer, "chat_template")
|
| 194 |
-
and self.tokenizer.chat_template is not None
|
| 195 |
-
):
|
| 196 |
-
try:
|
| 197 |
-
messages = [{"role": "user", "content": prompt}]
|
| 198 |
-
formatted = self.tokenizer.apply_chat_template(
|
| 199 |
-
messages, add_generation_prompt=True, tokenize=False
|
| 200 |
-
)
|
| 201 |
-
input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
|
| 202 |
-
self.device
|
| 203 |
-
)
|
| 204 |
-
except Exception:
|
| 205 |
-
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
| 206 |
-
self.device
|
| 207 |
-
)
|
| 208 |
-
else:
|
| 209 |
-
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
| 210 |
-
self.device
|
| 211 |
-
)
|
| 212 |
|
| 213 |
# Get thresholds
|
| 214 |
thresholds = {}
|
|
@@ -218,6 +228,7 @@ class DSSDecoder:
|
|
| 218 |
validated_tokens = []
|
| 219 |
current_ids = input_ids.clone()
|
| 220 |
num_layers = self.adapter.get_num_layers()
|
|
|
|
| 221 |
|
| 222 |
while len(validated_tokens) < max_tokens:
|
| 223 |
# ============================================================
|
|
@@ -226,6 +237,7 @@ class DSSDecoder:
|
|
| 226 |
drafted_tokens = []
|
| 227 |
draft_ids = current_ids.clone()
|
| 228 |
got_lm_head_token = False
|
|
|
|
| 229 |
|
| 230 |
for _ in range(max_draft_length):
|
| 231 |
if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
|
|
@@ -240,7 +252,8 @@ class DSSDecoder:
|
|
| 240 |
# EOS handling
|
| 241 |
if exit_head is not None and drafted_tokens:
|
| 242 |
break # Verify pending drafts first
|
| 243 |
-
|
|
|
|
| 244 |
|
| 245 |
token_text = self.tokenizer.decode([token_id])
|
| 246 |
drafted_token = TokenInfo(
|
|
@@ -274,6 +287,10 @@ class DSSDecoder:
|
|
| 274 |
message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
|
| 275 |
)
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
# ============================================================
|
| 278 |
# VERIFY PHASE
|
| 279 |
# ============================================================
|
|
@@ -382,6 +399,17 @@ class DSSDecoder:
|
|
| 382 |
):
|
| 383 |
break
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
def _generate_with_early_exit(
|
| 386 |
self,
|
| 387 |
input_ids: torch.Tensor,
|
|
@@ -773,33 +801,14 @@ class DSSDecoder:
|
|
| 773 |
):
|
| 774 |
"""
|
| 775 |
Generate with full model in streaming mode - yields each token as generated.
|
|
|
|
| 776 |
"""
|
| 777 |
-
|
| 778 |
-
if (
|
| 779 |
-
use_chat_template
|
| 780 |
-
and hasattr(self.tokenizer, "chat_template")
|
| 781 |
-
and self.tokenizer.chat_template is not None
|
| 782 |
-
):
|
| 783 |
-
try:
|
| 784 |
-
messages = [{"role": "user", "content": prompt}]
|
| 785 |
-
formatted = self.tokenizer.apply_chat_template(
|
| 786 |
-
messages, add_generation_prompt=True, tokenize=False
|
| 787 |
-
)
|
| 788 |
-
input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
|
| 789 |
-
self.device
|
| 790 |
-
)
|
| 791 |
-
except Exception:
|
| 792 |
-
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
| 793 |
-
self.device
|
| 794 |
-
)
|
| 795 |
-
else:
|
| 796 |
-
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
|
| 797 |
-
self.device
|
| 798 |
-
)
|
| 799 |
|
| 800 |
tokens = []
|
| 801 |
current_ids = input_ids.clone()
|
| 802 |
num_layers = self.adapter.get_num_layers()
|
|
|
|
| 803 |
|
| 804 |
for i in range(max_tokens):
|
| 805 |
with torch.no_grad():
|
|
@@ -832,6 +841,17 @@ class DSSDecoder:
|
|
| 832 |
message=f"Token {i + 1}: '{token_text}'",
|
| 833 |
)
|
| 834 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 835 |
|
| 836 |
def load_dssd_model(
|
| 837 |
model_name: str,
|
|
|
|
| 53 |
uncertainty: float
|
| 54 |
|
| 55 |
|
| 56 |
+
@dataclass
class StreamingResult:
    """Result from streaming generation with accumulated metrics."""

    tokens: List[TokenInfo]
    total_time: float
    tokens_per_second: float
    avg_exit_layer: float
    exit_distribution: Dict[str, int]

    @classmethod
    def from_tokens(cls, tokens: List[TokenInfo], total_time: float, num_layers: int) -> "StreamingResult":
        """Build a StreamingResult from a list of tokens and timing info.

        Tokens that carry an exit head are bucketed under the head index
        (stringified); tokens produced by the full model fall under "full".
        With no tokens at all, the average exit layer defaults to num_layers.
        """
        distribution: Dict[str, int] = {}
        for token in tokens:
            bucket = "full" if token.exit_head is None else str(token.exit_head)
            distribution[bucket] = distribution.get(bucket, 0) + 1

        if tokens:
            mean_exit = sum(token.exit_layer for token in tokens) / len(tokens)
        else:
            mean_exit = num_layers

        throughput = len(tokens) / total_time if total_time > 0 else 0

        return cls(
            tokens=tokens,
            total_time=total_time,
            tokens_per_second=throughput,
            avg_exit_layer=mean_exit,
            exit_distribution=distribution,
        )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
@dataclass
class StreamEvent:
    """Event for streaming generation updates.

    Emitted by the streaming generation loops so callers can render
    progress incrementally; the final event of a run has event_type
    "complete" and carries the accumulated metrics in `result`.
    """

    event_type: str  # "draft", "verify_start", "accept", "reject", "full_model", "complete"
    tokens: List[TokenInfo]  # All tokens so far (validated)
    drafted_tokens: List[TokenInfo]  # Currently drafted (pending verification)
    message: str  # Human-readable status
    result: Optional[StreamingResult] = None  # Set only on the final "complete" event
|
| 97 |
|
| 98 |
|
| 99 |
@dataclass
|
|
|
|
| 133 |
self.device = device
|
| 134 |
self.uncertainty_fn = compute_entropy
|
| 135 |
|
| 136 |
+
def _format_and_encode_prompt(self, prompt: str, use_chat_template: bool) -> torch.Tensor:
|
| 137 |
+
"""Format prompt with optional chat template and return input_ids tensor."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
if (
|
| 139 |
use_chat_template
|
| 140 |
and hasattr(self.tokenizer, "chat_template")
|
|
|
|
| 145 |
formatted = self.tokenizer.apply_chat_template(
|
| 146 |
messages, add_generation_prompt=True, tokenize=False
|
| 147 |
)
|
| 148 |
+
return self.tokenizer.encode(formatted, return_tensors="pt").to(
|
| 149 |
self.device
|
| 150 |
)
|
| 151 |
except Exception:
|
| 152 |
+
pass # Fall through to raw prompt encoding
|
| 153 |
+
return self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
| 154 |
+
|
| 155 |
+
def generate(
|
| 156 |
+
self,
|
| 157 |
+
prompt: str,
|
| 158 |
+
max_tokens: int = 100,
|
| 159 |
+
use_early_exit: bool = True,
|
| 160 |
+
accuracy_level: float = 0.75,
|
| 161 |
+
use_chat_template: bool = True,
|
| 162 |
+
) -> GenerationResult:
|
| 163 |
+
"""
|
| 164 |
+
Generate text with optional early exit.
|
| 165 |
+
Returns detailed token-level information for visualization.
|
| 166 |
+
"""
|
| 167 |
+
input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
|
| 168 |
|
| 169 |
# Get thresholds
|
| 170 |
thresholds = {}
|
|
|
|
| 216 |
"""
|
| 217 |
Generate with streaming - yields events showing draft/verify process.
|
| 218 |
Each event shows current validated tokens and pending drafted tokens.
|
| 219 |
+
Yields a final "complete" event with StreamingResult containing metrics.
|
| 220 |
"""
|
| 221 |
+
input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
# Get thresholds
|
| 224 |
thresholds = {}
|
|
|
|
| 228 |
validated_tokens = []
|
| 229 |
current_ids = input_ids.clone()
|
| 230 |
num_layers = self.adapter.get_num_layers()
|
| 231 |
+
start_time = time.time()
|
| 232 |
|
| 233 |
while len(validated_tokens) < max_tokens:
|
| 234 |
# ============================================================
|
|
|
|
| 237 |
drafted_tokens = []
|
| 238 |
draft_ids = current_ids.clone()
|
| 239 |
got_lm_head_token = False
|
| 240 |
+
should_stop = False
|
| 241 |
|
| 242 |
for _ in range(max_draft_length):
|
| 243 |
if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
|
|
|
|
| 252 |
# EOS handling
|
| 253 |
if exit_head is not None and drafted_tokens:
|
| 254 |
break # Verify pending drafts first
|
| 255 |
+
should_stop = True
|
| 256 |
+
break # Stop generation
|
| 257 |
|
| 258 |
token_text = self.tokenizer.decode([token_id])
|
| 259 |
drafted_token = TokenInfo(
|
|
|
|
| 287 |
message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
|
| 288 |
)
|
| 289 |
|
| 290 |
+
# Check if we should stop (EOS encountered with no pending drafts)
|
| 291 |
+
if should_stop:
|
| 292 |
+
break
|
| 293 |
+
|
| 294 |
# ============================================================
|
| 295 |
# VERIFY PHASE
|
| 296 |
# ============================================================
|
|
|
|
| 399 |
):
|
| 400 |
break
|
| 401 |
|
| 402 |
+
# Yield final "complete" event with metrics
|
| 403 |
+
total_time = time.time() - start_time
|
| 404 |
+
result = StreamingResult.from_tokens(validated_tokens, total_time, num_layers)
|
| 405 |
+
yield StreamEvent(
|
| 406 |
+
event_type="complete",
|
| 407 |
+
tokens=list(validated_tokens),
|
| 408 |
+
drafted_tokens=[],
|
| 409 |
+
message="Generation complete",
|
| 410 |
+
result=result,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
def _generate_with_early_exit(
|
| 414 |
self,
|
| 415 |
input_ids: torch.Tensor,
|
|
|
|
| 801 |
):
|
| 802 |
"""
|
| 803 |
Generate with full model in streaming mode - yields each token as generated.
|
| 804 |
+
Yields a final "complete" event with StreamingResult containing metrics.
|
| 805 |
"""
|
| 806 |
+
input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
|
| 808 |
tokens = []
|
| 809 |
current_ids = input_ids.clone()
|
| 810 |
num_layers = self.adapter.get_num_layers()
|
| 811 |
+
start_time = time.time()
|
| 812 |
|
| 813 |
for i in range(max_tokens):
|
| 814 |
with torch.no_grad():
|
|
|
|
| 841 |
message=f"Token {i + 1}: '{token_text}'",
|
| 842 |
)
|
| 843 |
|
| 844 |
+
# Yield final "complete" event with metrics
|
| 845 |
+
total_time = time.time() - start_time
|
| 846 |
+
result = StreamingResult.from_tokens(tokens, total_time, num_layers)
|
| 847 |
+
yield StreamEvent(
|
| 848 |
+
event_type="complete",
|
| 849 |
+
tokens=list(tokens),
|
| 850 |
+
drafted_tokens=[],
|
| 851 |
+
message="Generation complete",
|
| 852 |
+
result=result,
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
|
| 856 |
def load_dssd_model(
|
| 857 |
model_name: str,
|
src/jagged_cache.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
JaggedKVCache - Sparse KV Cache for Early Exit Inference.
|
| 3 |
+
|
| 4 |
+
This cache tracks per-layer sequence lengths, enabling efficient
|
| 5 |
+
generation with early exit heads that stop at different layers.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from typing import List, Tuple, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class JaggedKVCache:
    """
    Sparse KV Cache that tracks per-layer sequence lengths.

    Unlike standard KV caches where all layers have the same length,
    this cache allows different layers to have different cached lengths.
    This is essential for early exit inference where tokens may exit
    at different layers.

    Key features:
    - Per-layer KV storage with independent lengths
    - Lazy fill: missing positions are detected and can be computed on-demand
    - Truncation: efficient rollback on rejection
    - Cloning: snapshot for speculative drafting

    Attributes:
        num_layers: Total number of transformer layers
        batch_size: Batch size (typically 1 for inference)
        num_kv_heads: Number of key-value heads
        head_dim: Dimension of each head
        device: Device to store tensors on
        dtype: Data type for tensors
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int = 1,
        num_kv_heads: int = 8,
        head_dim: int = 128,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Per-layer storage: List of (key_cache, value_cache) or None
        self.layer_caches: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track sequence length per layer (capacity = max_position + 1)
        self.layer_seq_lengths: List[int] = [0] * num_layers

        # Track which positions are actually filled (for lazy fill detection)
        # This is a list of sets, one per layer
        self.filled_positions: List[set] = [set() for _ in range(num_layers)]

    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache for a layer at specific positions.

        Args:
            layer_idx: Layer index to update
            key_states: [B, num_kv_heads, seq_len, head_dim] new key states
            value_states: [B, num_kv_heads, seq_len, head_dim] new value states
            cache_position: [seq_len] tensor of positions to update
                (need not be sorted or contiguous)

        Returns:
            (full_keys, full_values) tuple with all cached data
        """
        positions = cache_position.tolist()
        # BUGFIX: derive the required capacity from the *maximum* position,
        # not the last element, so an unsorted cache_position tensor cannot
        # under-allocate the cache (which would make the scatter below
        # index out of range).
        new_len = max(positions) + 1
        input_seq_len = key_states.shape[2]

        if self.layer_caches[layer_idx] is None:
            # First update for this layer.
            # Fast path only when positions are exactly [0, 1, ..., n-1]
            # *in order*, so a plain clone preserves the position -> index
            # mapping. (Checking just the first element would misclassify
            # permutations such as [0, 2, 1].)
            if positions == list(range(new_len)) and input_seq_len == new_len:
                self.layer_caches[layer_idx] = (
                    key_states.clone(),
                    value_states.clone(),
                )
            else:
                # Non-contiguous, unordered, or not starting from 0:
                # allocate the full buffer and scatter into it.
                k_cache = torch.zeros(
                    (self.batch_size, self.num_kv_heads, new_len, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                v_cache = torch.zeros(
                    (self.batch_size, self.num_kv_heads, new_len, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                k_cache[:, :, cache_position.long(), :] = key_states
                v_cache[:, :, cache_position.long(), :] = value_states
                self.layer_caches[layer_idx] = (k_cache, v_cache)

            self.layer_seq_lengths[layer_idx] = new_len
        else:
            k_cache, v_cache = self.layer_caches[layer_idx]
            current_len = k_cache.shape[2]

            if new_len > current_len:
                # Need to extend cache with zero-filled slack
                extension_size = new_len - current_len
                k_extension = torch.zeros(
                    (self.batch_size, self.num_kv_heads, extension_size, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                v_extension = torch.zeros(
                    (self.batch_size, self.num_kv_heads, extension_size, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                k_cache = torch.cat([k_cache, k_extension], dim=2)
                v_cache = torch.cat([v_cache, v_extension], dim=2)

            # Update at cache_position
            k_cache[:, :, cache_position.long(), :] = key_states
            v_cache[:, :, cache_position.long(), :] = value_states

            self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = max(
                self.layer_seq_lengths[layer_idx], new_len
            )

        # Track filled positions for lazy-fill detection
        self.filled_positions[layer_idx].update(positions)

        return self.layer_caches[layer_idx]

    def get_kv(self, layer_idx: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get cached KV for a layer, or None if not cached."""
        return self.layer_caches[layer_idx]

    def get_seq_length(self, layer_idx: int) -> int:
        """Get the sequence length (capacity) for a layer."""
        return self.layer_seq_lengths[layer_idx]

    def has_position(self, layer_idx: int, position: int) -> bool:
        """Check if a specific position is filled for a layer."""
        return position in self.filled_positions[layer_idx]

    def get_unfilled_positions(self, layer_idx: int, up_to: int) -> List[int]:
        """Get list of positions that are not filled for a layer, up to `up_to` (exclusive)."""
        all_positions = set(range(up_to))
        filled = self.filled_positions[layer_idx]
        return sorted(all_positions - filled)

    def needs_fill(self, layer_idx: int, positions: List[int]) -> bool:
        """Check if any of the given positions need to be filled for a layer."""
        return not all(p in self.filled_positions[layer_idx] for p in positions)

    def get_missing_layers(self, position: int, target_layer: int) -> List[int]:
        """
        Get list of layers that need computation for a position.

        Args:
            position: The position we need KV for
            target_layer: The deepest layer we need to reach (inclusive)

        Returns:
            List of layer indices that need computation for this position
        """
        missing = []
        for layer_idx in range(target_layer + 1):
            if position not in self.filled_positions[layer_idx]:
                missing.append(layer_idx)
        return missing

    def truncate_from(self, position: int):
        """
        Truncate all layer caches from position onwards (exclusive).
        Used for rollback on rejection.

        Args:
            position: First position to remove (keeps 0..position-1)
        """
        for layer_idx in range(self.num_layers):
            if self.layer_caches[layer_idx] is not None:
                k, v = self.layer_caches[layer_idx]
                if k.shape[2] > position:
                    self.layer_caches[layer_idx] = (
                        k[:, :, :position, :].contiguous(),
                        v[:, :, :position, :].contiguous(),
                    )
                self.layer_seq_lengths[layer_idx] = min(
                    self.layer_seq_lengths[layer_idx], position
                )

            # Remove filled positions >= position
            self.filled_positions[layer_idx] = {
                p for p in self.filled_positions[layer_idx] if p < position
            }

    def clone(self) -> "JaggedKVCache":
        """
        Create a deep copy of the cache for speculative drafting.

        Returns:
            Independent copy that can be modified without affecting original
        """
        new_cache = JaggedKVCache(
            num_layers=self.num_layers,
            batch_size=self.batch_size,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            device=self.device,
            dtype=self.dtype,
        )
        for i, kv in enumerate(self.layer_caches):
            if kv is not None:
                new_cache.layer_caches[i] = (kv[0].clone(), kv[1].clone())
        new_cache.layer_seq_lengths = self.layer_seq_lengths.copy()
        new_cache.filled_positions = [s.copy() for s in self.filled_positions]
        return new_cache

    def reset(self):
        """Reset the cache to empty state."""
        self.layer_caches = [None for _ in range(self.num_layers)]
        self.layer_seq_lengths = [0] * self.num_layers
        self.filled_positions = [set() for _ in range(self.num_layers)]

    def __repr__(self) -> str:
        lines = [f"JaggedKVCache(num_layers={self.num_layers}, device={self.device})"]
        for i in range(min(self.num_layers, 10)):  # Show first 10 layers
            seq_len = self.layer_seq_lengths[i]
            filled = len(self.filled_positions[i])
            if seq_len > 0:
                lines.append(f"  Layer {i:2d}: {filled}/{seq_len} filled")
        if self.num_layers > 10:
            lines.append(f"  ... ({self.num_layers - 10} more layers)")
        return "\n".join(lines)
|
src/model_adapters.py
CHANGED
|
@@ -127,11 +127,7 @@ class LlamaStyleAdapter(ModelAdapter):
|
|
| 127 |
) -> Optional[Tuple[Tensor, Tensor]]:
|
| 128 |
if self._rotary is not None:
|
| 129 |
cos, sin = self._rotary(hidden_states, position_ids)
|
| 130 |
-
#
|
| 131 |
-
# This matches LlamaModel behavior which prepares embeddings for layers
|
| 132 |
-
if cos.dim() == 3:
|
| 133 |
-
cos = cos.unsqueeze(1)
|
| 134 |
-
sin = sin.unsqueeze(1)
|
| 135 |
return (cos, sin)
|
| 136 |
return None
|
| 137 |
|
|
|
|
| 127 |
) -> Optional[Tuple[Tensor, Tensor]]:
|
| 128 |
if self._rotary is not None:
|
| 129 |
cos, sin = self._rotary(hidden_states, position_ids)
|
| 130 |
+
# Return as-is - the model's apply_rotary_pos_emb handles unsqueezing
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
return (cos, sin)
|
| 132 |
return None
|
| 133 |
|
src/model_config.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
# Re-exported from the main package for demo use
|
| 3 |
|
| 4 |
import json
|
| 5 |
-
from dataclasses import dataclass, field
|
| 6 |
from typing import Dict, List, Optional
|
| 7 |
|
| 8 |
|
|
@@ -34,10 +34,6 @@ class ModelConfig:
|
|
| 34 |
training_config=data.get("training_config"),
|
| 35 |
)
|
| 36 |
|
| 37 |
-
def to_json(self, path: str) -> None:
|
| 38 |
-
with open(path, "w") as f:
|
| 39 |
-
json.dump(asdict(self), f, indent=2)
|
| 40 |
-
|
| 41 |
|
| 42 |
@dataclass
|
| 43 |
class CalibrationResult:
|
|
@@ -57,15 +53,6 @@ class CalibrationResult:
|
|
| 57 |
data = json.load(f)
|
| 58 |
return cls(**data)
|
| 59 |
|
| 60 |
-
def to_json(self, path: str) -> None:
|
| 61 |
-
with open(path, "w") as f:
|
| 62 |
-
json.dump(asdict(self), f, indent=2)
|
| 63 |
-
|
| 64 |
-
def get_threshold(self, accuracy_level: float, head_idx: int) -> float:
|
| 65 |
-
level_key = f"{accuracy_level:.2f}"
|
| 66 |
-
head_key = str(head_idx)
|
| 67 |
-
return self.thresholds[level_key][head_key]
|
| 68 |
-
|
| 69 |
def get_thresholds_for_level(self, accuracy_level: float) -> Dict[int, float]:
|
| 70 |
"""Get all thresholds for a given accuracy level."""
|
| 71 |
level_key = f"{accuracy_level:.2f}"
|
|
|
|
| 2 |
# Re-exported from the main package for demo use
|
| 3 |
|
| 4 |
import json
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
from typing import Dict, List, Optional
|
| 7 |
|
| 8 |
|
|
|
|
| 34 |
training_config=data.get("training_config"),
|
| 35 |
)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
@dataclass
|
| 39 |
class CalibrationResult:
|
|
|
|
| 53 |
data = json.load(f)
|
| 54 |
return cls(**data)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def get_thresholds_for_level(self, accuracy_level: float) -> Dict[int, float]:
|
| 57 |
"""Get all thresholds for a given accuracy level."""
|
| 58 |
level_key = f"{accuracy_level:.2f}"
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Tests package for DSSD demo
|
tests/benchmark_kv_cache.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Benchmark tests for KV Cache optimization in DSSD.
|
| 3 |
+
|
| 4 |
+
This module provides deterministic benchmarks to measure:
|
| 5 |
+
1. Layer forward counts (direct measure of computation)
|
| 6 |
+
2. Wall clock time for draft + verify phases
|
| 7 |
+
3. Optional FLOPs estimation
|
| 8 |
+
|
| 9 |
+
Run with: python -m tests.benchmark_kv_cache
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import time
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Dict, List, Optional, Tuple
|
| 17 |
+
from contextlib import contextmanager
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# =============================================================================
|
| 21 |
+
# Instrumentation
|
| 22 |
+
# =============================================================================
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class BenchmarkMetrics:
    """Tracks metrics during a benchmark run.

    Collects per-layer forward counts, draft/verify wall-clock timings,
    token acceptance statistics, and the early-exit layer distribution.
    """

    # Layer forward counts
    layer_forward_counts: Dict[int, int] = field(default_factory=dict)
    total_layer_forwards: int = 0

    # Timing
    draft_time_ms: float = 0.0
    verify_time_ms: float = 0.0
    total_time_ms: float = 0.0

    # Token counts
    tokens_drafted: int = 0
    tokens_accepted: int = 0
    tokens_rejected: int = 0

    # Early exit distribution
    exit_layers: List[int] = field(default_factory=list)

    def reset(self):
        """Zero every counter and clear the collections in place."""
        self.layer_forward_counts.clear()
        self.exit_layers.clear()
        self.total_layer_forwards = 0
        self.tokens_drafted = 0
        self.tokens_accepted = 0
        self.tokens_rejected = 0
        self.draft_time_ms = 0.0
        self.verify_time_ms = 0.0
        self.total_time_ms = 0.0

    def record_layer_forward(self, layer_idx: int):
        """Record one forward pass through the given layer."""
        previous = self.layer_forward_counts.get(layer_idx, 0)
        self.layer_forward_counts[layer_idx] = previous + 1
        self.total_layer_forwards += 1

    def summary(self) -> str:
        """Return a human-readable multi-line report of all metrics."""
        lines = [
            "=" * 50,
            "BENCHMARK METRICS",
            "=" * 50,
            f"Total Layer Forwards: {self.total_layer_forwards}",
            f"Tokens Drafted: {self.tokens_drafted}",
            f"Tokens Accepted: {self.tokens_accepted}",
            f"Tokens Rejected: {self.tokens_rejected}",
            f"Draft Time: {self.draft_time_ms:.2f} ms",
            f"Verify Time: {self.verify_time_ms:.2f} ms",
            f"Total Time: {self.total_time_ms:.2f} ms",
            "",
            "Layer Forward Distribution:",
        ]
        lines.extend(
            f"  Layer {idx:2d}: {self.layer_forward_counts[idx]} forwards"
            for idx in sorted(self.layer_forward_counts)
        )

        if self.exit_layers:
            mean_exit = sum(self.exit_layers) / len(self.exit_layers)
            lines.append(f"\nAverage Exit Layer: {mean_exit:.1f}")

        lines.append("=" * 50)
        return "\n".join(lines)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Global metrics instance for instrumentation.
# NOTE(review): a single module-level slot means only one benchmark can
# collect metrics at a time; not safe for concurrent runs — confirm that
# benchmarks are always run sequentially.
_metrics: Optional[BenchmarkMetrics] = None


def get_metrics() -> Optional[BenchmarkMetrics]:
    """Get the current metrics instance, or None outside benchmark_context()."""
    return _metrics


@contextmanager
def benchmark_context():
    """Context manager that enables metric collection.

    Installs a fresh BenchmarkMetrics as the module-level instance and
    yields it; the instance is uninstalled on exit (even on error), so
    instrument_layer_forward() becomes a no-op again afterwards.
    """
    global _metrics
    _metrics = BenchmarkMetrics()
    try:
        yield _metrics
    finally:
        _metrics = None


def instrument_layer_forward(layer_idx: int) -> None:
    """Call this from forward_layer to record layer execution.

    Does nothing when no benchmark_context() is active.
    """
    if _metrics is not None:
        _metrics.record_layer_forward(layer_idx)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# =============================================================================
|
| 120 |
+
# Timer Utilities
|
| 121 |
+
# =============================================================================
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Timer:
    """Simple wall-clock timer for benchmarking.

    Synchronizes CUDA (when available) before reading the clock so that
    asynchronously launched GPU work is included in the measured interval.
    """

    def __init__(self):
        self.start_time = None  # perf_counter value captured by start(), or None
        self.elapsed_ms = 0.0  # last measured interval in milliseconds

    def _sync(self) -> None:
        # Fixed idiom: the original used a conditional *expression*
        # (`torch.cuda.synchronize() if ... else None`) purely for its side
        # effect; a guarded statement states the intent directly.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self) -> None:
        """Begin timing from now."""
        self._sync()
        self.start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop timing and return elapsed milliseconds since start().

        If start() was never called, returns the previous elapsed_ms
        (0.0 initially) rather than raising.
        """
        self._sync()
        if self.start_time is not None:
            self.elapsed_ms = (time.perf_counter() - self.start_time) * 1000
        return self.elapsed_ms
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# =============================================================================
|
| 143 |
+
# Benchmark Test Scenarios
|
| 144 |
+
# =============================================================================
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@dataclass
class BenchmarkConfig:
    """Configuration for benchmark runs."""

    # Model setting (HuggingFace model identifier)
    model_name: str = "Qwen/Qwen3-0.6B"

    # Generation settings
    prompt: str = "Explain what machine learning is in simple terms."
    max_draft_length: int = 5  # max speculative tokens drafted per cycle
    num_iterations: int = 3  # Multiple iterations for averaging

    # Thresholds for early exit (simulated or real)
    accuracy_level: float = 0.75

    # Reproducibility: seeds torch (and CUDA when available)
    seed: int = 42
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def run_single_draft_verify_benchmark(
    decoder,  # DSSDecoder
    config: BenchmarkConfig,
    use_cache: bool = False,
) -> BenchmarkMetrics:
    """
    Run a single draft + verify cycle and measure metrics.

    Drafts up to config.max_draft_length tokens with the decoder's early-exit
    path, then verifies them in one full-model forward pass, recording wall
    clock times and acceptance counts into a BenchmarkMetrics.

    Args:
        decoder: The DSSDecoder instance
        config: Benchmark configuration
        use_cache: Whether to use JaggedKVCache (for comparison)
            # NOTE(review): currently unused in this function body — the
            # verify pass always runs with use_cache=False.

    Returns:
        BenchmarkMetrics with recorded data
    """
    # Set seed for reproducibility
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    with benchmark_context() as metrics:
        timer = Timer()

        # Tokenize prompt (raw encoding — no chat template here)
        input_ids = decoder.tokenizer.encode(config.prompt, return_tensors="pt").to(
            decoder.device
        )

        # Get thresholds for the configured accuracy level, if calibrated
        thresholds = {}
        if decoder.calibration:
            thresholds = decoder.calibration.get_thresholds_for_level(
                config.accuracy_level
            )

        # ========== DRAFT PHASE ==========
        timer.start()
        drafted_tokens = []
        draft_ids = input_ids.clone()

        for _ in range(config.max_draft_length):
            # Call the drafting function
            # Note: This will need to be modified to use our instrumented version
            draft_result = decoder._draft_single_token(draft_ids, thresholds)

            # None signals drafting could not produce a token — stop early
            if draft_result is None:
                break

            token_id, exit_head, exit_layer, uncertainty = draft_result
            drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
            metrics.exit_layers.append(exit_layer)

            # Stop drafting once EOS is produced
            if token_id == decoder.tokenizer.eos_token_id:
                break

            # Append the drafted token so the next draft sees it as context
            draft_ids = torch.cat(
                [draft_ids, torch.tensor([[token_id]], device=decoder.device)], dim=1
            )

        metrics.draft_time_ms = timer.stop()
        metrics.tokens_drafted = len(drafted_tokens)

        # ========== VERIFY PHASE ==========
        timer.start()

        if drafted_tokens:
            # Single full-model pass over prompt + all drafted tokens
            with torch.no_grad():
                outputs = decoder.model(draft_ids, use_cache=False)
                verify_logits = outputs.logits

            # Verify each token: the logits at position p predict token p+1,
            # so the first drafted token is checked at the last prompt position.
            start_pos = input_ids.shape[1] - 1
            accepted = 0

            for i, (token_id, exit_head, exit_layer, uncertainty) in enumerate(
                drafted_tokens
            ):
                verify_pos = start_pos + i
                verified_token = torch.argmax(verify_logits[0, verify_pos, :]).item()

                if token_id == verified_token:
                    accepted += 1
                else:
                    # First mismatch invalidates the rest of the draft
                    break

            metrics.tokens_accepted = accepted
            metrics.tokens_rejected = len(drafted_tokens) - accepted

        metrics.verify_time_ms = timer.stop()
        metrics.total_time_ms = metrics.draft_time_ms + metrics.verify_time_ms

    return metrics
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def run_baseline_benchmark(decoder, config: BenchmarkConfig) -> BenchmarkMetrics:
    """Benchmark the current draft/verify loop without cache optimization.

    Runs ``config.num_iterations`` independent passes, printing per-iteration
    stats, then returns a BenchmarkMetrics with averaged timings and layer
    counts. Token drafted/accepted/rejected stats are taken from the first
    run only.
    """
    print(f"\n{'=' * 60}")
    print("BASELINE BENCHMARK (No Cache)")
    print(f"{'=' * 60}")
    print(f"Model: {config.model_name}")
    print(f"Prompt: '{config.prompt[:50]}...'")
    print(f"Max Draft Length: {config.max_draft_length}")
    print(f"Iterations: {config.num_iterations}")

    runs = []
    for iteration in range(config.num_iterations):
        print(f"\nIteration {iteration + 1}/{config.num_iterations}...")
        run = run_single_draft_verify_benchmark(decoder, config, use_cache=False)
        runs.append(run)
        print(f"  Layer Forwards: {run.total_layer_forwards}")
        print(f"  Draft Time: {run.draft_time_ms:.2f} ms")
        print(f"  Verify Time: {run.verify_time_ms:.2f} ms")

    n = len(runs)

    # Aggregate the per-run metrics into a single averaged result.
    avg_metrics = BenchmarkMetrics()
    avg_metrics.total_layer_forwards = sum(r.total_layer_forwards for r in runs) // n
    avg_metrics.draft_time_ms = sum(r.draft_time_ms for r in runs) / n
    avg_metrics.verify_time_ms = sum(r.verify_time_ms for r in runs) / n
    avg_metrics.total_time_ms = sum(r.total_time_ms for r in runs) / n
    # Token counts come from the first run only (not averaged).
    avg_metrics.tokens_drafted = runs[0].tokens_drafted
    avg_metrics.tokens_accepted = runs[0].tokens_accepted
    avg_metrics.tokens_rejected = runs[0].tokens_rejected

    # Per-layer forward counts, averaged across runs (integer division per run).
    for r in runs:
        for layer_idx, count in r.layer_forward_counts.items():
            avg_metrics.layer_forward_counts[layer_idx] = (
                avg_metrics.layer_forward_counts.get(layer_idx, 0) + count // n
            )

    print("\n" + avg_metrics.summary())
    return avg_metrics
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# =============================================================================
|
| 314 |
+
# Main Entry Point
|
| 315 |
+
# =============================================================================
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def main():
    """Run benchmark suite."""
    import sys

    # Make the project importable when this file is run directly.
    sys.path.insert(0, "/home/fvalade/workspace/DSSD_demo")

    from src.inference import load_dssd_model

    config = BenchmarkConfig()

    print("Loading model...")
    try:
        # NOTE: checkpoint paths are hard-coded for the local setup.
        decoder, tokenizer = load_dssd_model(
            model_name=config.model_name,
            heads_path="../checkpoints/qwen3-0.6b/aux_heads.pt",
            config_path="../checkpoints/qwen3-0.6b/config.json",
            calibration_path="../checkpoints/qwen3-0.6b/calibration.json",
            device="auto",
        )
    except Exception as e:
        # Missing checkpoints are expected on fresh clones; explain and bail.
        print(f"Error loading model: {e}")
        print("\nTo run this benchmark, ensure you have:")
        print("  1. A trained auxiliary heads checkpoint")
        print("  2. The corresponding config.json")
        print("  3. (Optional) calibration.json for thresholds")
        return
    print("Model loaded successfully!")

    # Baseline pass: no cache optimization, for later comparison.
    baseline = run_baseline_benchmark(decoder, config)

    print("\n" + "=" * 60)
    print("BASELINE RESULTS SAVED")
    print("Run this again after implementing JaggedKVCache to compare.")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
tests/run_benchmark.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Benchmark comparison: Standard generation vs Cache-optimized generation.
|
| 4 |
+
|
| 5 |
+
This script measures and compares:
|
| 6 |
+
- Layer forward counts
|
| 7 |
+
- Wall clock time
|
| 8 |
+
- Tokens per second
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python tests/run_benchmark.py --model Qwen/Qwen3-0.6B --heads-path /path/to/heads.pt
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import time
|
| 16 |
+
import sys
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
# Add project to path
|
| 20 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def make_dummy_decoder():
    """Benchmark JaggedKVCache bookkeeping on CPU with synthetic tensors.

    Simulates prefill, an early-exit draft phase, and a lazy-fill verify
    phase using random K/V blocks, then prints a layer-operation analysis.
    Returns True on completion.
    """
    from src.jagged_cache import JaggedKVCache

    print("\n" + "=" * 60)
    print("BENCHMARK: JaggedKVCache Operations (No GPU Required)")
    print("=" * 60)

    # Synthetic model geometry (roughly Qwen3-0.6B sized).
    num_layers = 28
    batch_size = 1
    num_heads = 8
    head_dim = 128
    seq_len = 100

    cache = JaggedKVCache(
        num_layers=num_layers,
        batch_size=batch_size,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        device="cpu",
        dtype=torch.float32,
    )

    def _rand_pair():
        # One-token random key/value block matching the cache geometry.
        return (
            torch.randn(batch_size, num_heads, 1, head_dim),
            torch.randn(batch_size, num_heads, 1, head_dim),
        )

    # --- Prefill: every layer sees every prompt position ---
    print(f"\nSimulating prefill ({seq_len} tokens, {num_layers} layers)...")
    tick = time.perf_counter()
    for pos in range(seq_len):
        for layer_idx in range(num_layers):
            k, v = _rand_pair()
            cache.update(layer_idx, k, v, torch.tensor([pos]))
    prefill_time = (time.perf_counter() - tick) * 1000
    print(f"  Prefill time: {prefill_time:.2f} ms")

    # --- Draft: each token stops at its simulated exit layer ---
    print("\nSimulating draft phase (5 tokens, variable exit layers)...")
    exit_layers = [4, 8, 6, 12, 10]  # per-token simulated exit layers
    draft_cache = cache.clone()

    tick = time.perf_counter()
    for offset, exit_layer in enumerate(exit_layers):
        pos = seq_len + offset
        for layer_idx in range(exit_layer + 1):
            k, v = _rand_pair()
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    draft_time = (time.perf_counter() - tick) * 1000
    print(f"  Draft time: {draft_time:.2f} ms")

    print("\nCache state after drafting:")
    for layer_idx in [0, 4, 8, 12, 16, 20, 24, 27]:
        filled = len(draft_cache.filled_positions[layer_idx])
        print(f"  Layer {layer_idx:2d}: {filled} positions filled")

    # --- Verify: lazily fill the layers each drafted position skipped ---
    print("\nSimulating verification (lazy fill + full model)...")
    tick = time.perf_counter()
    for pos in range(seq_len, seq_len + 5):
        for layer_idx in draft_cache.get_missing_layers(pos, num_layers - 1):
            k, v = _rand_pair()
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    verify_time = (time.perf_counter() - tick) * 1000
    print(f"  Verify time: {verify_time:.2f} ms")

    # --- Report implied layer-operation savings ---
    print("\n" + "=" * 60)
    print("ANALYSIS: Layer Operations")
    print("=" * 60)

    prefill_ops = seq_len * num_layers  # one-time cost, same for all approaches
    print(f"\nPREFILL (one-time): {prefill_ops} layer ops")

    draft_ops = sum(exit_layer + 1 for exit_layer in exit_layers)
    draft_ops_full = 5 * num_layers  # cost without early exit
    print(f"\nDRAFT PHASE (5 tokens):")
    print(f"  With early exit: {draft_ops} ops (avg {draft_ops / 5:.1f} layers/token)")
    print(f"  Without early exit: {draft_ops_full} ops ({num_layers} layers/token)")
    print(
        f"  Draft savings: {draft_ops_full - draft_ops} ops ({100 * (1 - draft_ops / draft_ops_full):.0f}% reduction)"
    )

    # Key benefit: with the cache each draft token is O(1 token * exit_layer);
    # without it, the full context would be recomputed per draft token.
    print(f"\nCACHE BENEFIT:")
    print(f"  Without cache, each draft would recompute {seq_len}-token context")
    print(f"  With cache, each draft processes only 1 new token")
    per_token_savings = seq_len - 1  # context positions not recomputed
    total_context_savings = per_token_savings * draft_ops
    print(f"  Context reuse savings: ~{total_context_savings} avoided operations")

    verify_ops = 5 * num_layers
    print(f"\nVERIFY PHASE: {verify_ops} ops (fills all layers for drafted tokens)")

    print(f"\nTotal time: {prefill_time + draft_time + verify_time:.2f} ms")

    return True
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def run_full_benchmark(model_name, heads_path, config_path, calibration_path=None):
    """Compare standard vs cache-optimized generation on a real model.

    Loads the DSSD model, runs a warmup pass, times both generation paths
    on the same prompt, and prints a comparison. Returns False if the model
    fails to load, True otherwise.
    """
    from src.inference import load_dssd_model

    print("\n" + "=" * 60)
    print(f"BENCHMARK: Full Model Comparison")
    print(f"Model: {model_name}")
    print("=" * 60)

    try:
        decoder, tokenizer = load_dssd_model(
            model_name=model_name,
            heads_path=heads_path,
            config_path=config_path,
            calibration_path=calibration_path,
            device="auto",
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    prompt = "Explain what machine learning is in three sentences."
    max_tokens = 50

    # Warmup pass so one-time setup costs don't pollute the timings.
    print("\nWarming up...")
    _ = decoder.generate(
        prompt, max_tokens=10, use_early_exit=False, use_chat_template=True
    )

    # --- Standard (no-cache) path ---
    print("\nRunning standard generation (no cache)...")
    tick = time.perf_counter()
    standard_result = decoder.generate(
        prompt,
        max_tokens=max_tokens,
        use_early_exit=True,
        accuracy_level=0.75,
        use_chat_template=True,
    )
    standard_elapsed = time.perf_counter() - tick

    # --- Cache-optimized (fast) path ---
    print("Running cache-optimized generation (fast)...")
    tick = time.perf_counter()
    cached_result = decoder.generate_fast(
        prompt,
        max_tokens=max_tokens,
        accuracy_level=0.75,
        use_chat_template=True,
    )
    cached_elapsed = time.perf_counter() - tick

    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    print("\nStandard Generation:")
    print(f"  Tokens generated: {len(standard_result.tokens)}")
    print(f"  Time: {standard_elapsed:.2f}s")
    print(f"  Tokens/sec: {len(standard_result.tokens) / standard_elapsed:.2f}")
    print(f"  Avg exit layer: {standard_result.avg_exit_layer:.1f}")

    print("\nCache-Optimized Generation:")
    print(f"  Tokens generated: {len(cached_result.tokens)}")
    print(f"  Time: {cached_elapsed:.2f}s")
    print(f"  Tokens/sec: {len(cached_result.tokens) / cached_elapsed:.2f}")
    print(f"  Avg exit layer: {cached_result.avg_exit_layer:.1f}")
    # Draft/accept stats are only present when the fast path produced them.
    if "total_drafted" in cached_result.exit_distribution:
        print(f"  Drafted: {cached_result.exit_distribution['total_drafted']}")
        print(f"  Accepted: {cached_result.exit_distribution['total_accepted']}")
        print(
            f"  Acceptance rate: {cached_result.exit_distribution['acceptance_rate']:.1%}"
        )

    print("\nSpeedup:")
    speedup = standard_elapsed / cached_elapsed if cached_elapsed > 0 else 0
    print(f"  {speedup:.2f}x faster with cache")

    return True
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def main():
    """CLI entry point: dispatch to the CPU-only or full-model benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark DSSD generation")
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B", help="Model name")
    parser.add_argument("--heads-path", help="Path to aux heads checkpoint")
    parser.add_argument("--config-path", help="Path to model config")
    parser.add_argument("--calibration-path", help="Path to calibration file")
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run CPU-only cache benchmark"
    )
    opts = parser.parse_args()

    if opts.cpu_only or not opts.heads_path:
        # No checkpoint available (or explicitly requested): exercise only
        # the cache bookkeeping, which needs no GPU or model weights.
        make_dummy_decoder()
    else:
        run_full_benchmark(
            opts.model,
            opts.heads_path,
            opts.config_path,
            opts.calibration_path,
        )


if __name__ == "__main__":
    main()
|
tests/test_cache_integration.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests for JaggedKVCache with inference pipeline.
|
| 3 |
+
|
| 4 |
+
Run with: pytest tests/test_cache_integration.py -v
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import torch
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
|
| 11 |
+
# Import from production module
|
| 12 |
+
import sys
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, "/home/fvalade/workspace/DSSD_demo")
|
| 15 |
+
|
| 16 |
+
from src.jagged_cache import JaggedKVCache
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TestJaggedKVCacheProduction:
    """Test the production JaggedKVCache implementation.

    Uses a small CPU cache (8 layers, 4 KV heads, dim 64) so tests run fast.
    Fix over the original: `assert x == True` / `== False` comparisons
    replaced with idiomatic truthiness asserts (flake8 E712).
    """

    @pytest.fixture
    def cache(self):
        """Create a small test cache on CPU."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Factory fixture producing random (key, value) tensor pairs."""

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_filled_positions_tracking(self, cache, sample_kv):
        """filled_positions tracks exactly the (layer, position) pairs written."""
        # Update layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        assert cache.has_position(0, 0)
        assert not cache.has_position(0, 1)
        assert not cache.has_position(1, 0)  # Layer 1 not touched

    def test_needs_fill(self, cache, sample_kv):
        """needs_fill correctly identifies missing positions."""
        # Fill layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # Layer 0 has position 0, doesn't need fill
        assert not cache.needs_fill(0, [0])

        # Layer 0 doesn't have position 1
        assert cache.needs_fill(0, [1])

        # Layer 1 has nothing
        assert cache.needs_fill(1, [0])

    def test_get_unfilled_positions(self, cache, sample_kv):
        """get_unfilled_positions returns the gaps below the given bound."""
        # Fill positions 0 and 2 for layer 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([2]))

        # Unfilled up to position 4 should be [1, 3]
        assert cache.get_unfilled_positions(0, 4) == [1, 3]

    def test_truncate_clears_filled_positions(self, cache, sample_kv):
        """truncate_from drops both KV data and the position bookkeeping."""
        # Fill positions 0-4
        for pos in range(5):
            k, v = sample_kv()
            cache.update(0, k, v, torch.tensor([pos]))

        assert cache.has_position(0, 4)

        # Truncate at position 3: positions 3 and 4 should be gone
        cache.truncate_from(3)

        assert cache.has_position(0, 2)
        assert not cache.has_position(0, 3)
        assert not cache.has_position(0, 4)

    def test_clone_copies_filled_positions(self, cache, sample_kv):
        """clone deep-copies bookkeeping so later writes don't leak across."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cloned = cache.clone()
        assert cloned.has_position(0, 0)

        # Modify the original; the clone must be unaffected.
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([1]))

        assert cache.has_position(0, 1)
        assert not cloned.has_position(0, 1)

    def test_reset(self, cache, sample_kv):
        """reset returns the cache to its pristine empty state."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cache.reset()

        assert cache.get_kv(0) is None
        assert cache.get_seq_length(0) == 0
        assert not cache.has_position(0, 0)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TestLazyFillScenario:
    """Exercise a realistic early-exit draft + lazy-fill update sequence."""

    @pytest.fixture
    def cache(self):
        # Small CPU cache: 8 layers, 4 KV heads, head dim 64.
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        # Factory producing random (key, value) pairs of the cache geometry.
        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_lazy_fill_scenario(self, cache, sample_kv):
        """
        Simulate:
        - Prefill prompt (positions 0-4) through all layers
        - Draft token 5 exiting at layer 2
        - Draft token 6 exiting at layer 6 (needs lazy fill)
        """
        # Prefill: positions 0-4 through all 8 layers.
        for position in range(5):
            for layer in range(8):
                keys, vals = sample_kv()
                cache.update(layer, keys, vals, torch.tensor([position]))

        # Sanity: every layer holds every prompt position after prefill.
        for layer in range(8):
            assert cache.get_seq_length(layer) == 5
            for position in range(5):
                assert cache.has_position(layer, position)

        # Draft token 5 with an early exit at layer 2 (layers 0, 1, 2 only).
        for layer in range(3):
            keys, vals = sample_kv()
            cache.update(layer, keys, vals, torch.tensor([5]))

        # Position 5 exists only in the shallow layers.
        assert cache.has_position(0, 5)
        assert cache.has_position(2, 5)
        assert not cache.has_position(3, 5)

        # Draft token 6 needs depth 6 — layers 3..6 are missing position 5.
        missing_at_layer_6 = cache.get_missing_layers(5, 6)

        assert 3 in missing_at_layer_6
        assert 6 in missing_at_layer_6
        assert 0 not in missing_at_layer_6  # layer 0 already has position 5

        # Layer 6's unfilled positions (up to 6) must include position 5.
        unfilled = cache.get_unfilled_positions(6, 6)
        assert 5 in unfilled


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
tests/test_cache_operations.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Step-by-step verification tests for KV Cache operations.
|
| 3 |
+
|
| 4 |
+
These tests verify the correctness of the JaggedKVCache implementation
|
| 5 |
+
without requiring a full model. Run with: pytest tests/test_cache_operations.py -v
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import torch
|
| 10 |
+
from typing import List, Tuple, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# =============================================================================
|
| 14 |
+
# Mock Cache Implementation (to be replaced with real JaggedKVCache)
|
| 15 |
+
# =============================================================================
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class JaggedKVCache:
|
| 19 |
+
"""
|
| 20 |
+
Jagged KV Cache that tracks per-layer sequence lengths.
|
| 21 |
+
|
| 22 |
+
This is a reference implementation for testing. The production version
|
| 23 |
+
will be in src/jagged_cache.py.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
num_layers: int,
|
| 29 |
+
batch_size: int = 1,
|
| 30 |
+
num_kv_heads: int = 8,
|
| 31 |
+
head_dim: int = 128,
|
| 32 |
+
device: str = "cpu",
|
| 33 |
+
dtype: torch.dtype = torch.float32,
|
| 34 |
+
):
|
| 35 |
+
self.num_layers = num_layers
|
| 36 |
+
self.batch_size = batch_size
|
| 37 |
+
self.num_kv_heads = num_kv_heads
|
| 38 |
+
self.head_dim = head_dim
|
| 39 |
+
self.device = device
|
| 40 |
+
self.dtype = dtype
|
| 41 |
+
|
| 42 |
+
# Per-layer storage: List of (key_cache, value_cache) or None
|
| 43 |
+
self.layer_caches: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
|
| 44 |
+
None for _ in range(num_layers)
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
# Track sequence length per layer
|
| 48 |
+
self.layer_seq_lengths: List[int] = [0] * num_layers
|
| 49 |
+
|
| 50 |
+
    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache for a layer at specific positions.

        Args:
            layer_idx: Layer index to update
            key_states: [B, num_kv_heads, seq_len, head_dim]
            value_states: [B, num_kv_heads, seq_len, head_dim]
            cache_position: [seq_len] positions to update

        Returns:
            (full_keys, full_values) including cached + new
        """
        # Highest position written + 1 determines the minimum cache length.
        # NOTE(review): assumes cache_position is sorted ascending — confirm.
        new_len = cache_position[-1].item() + 1
        input_seq_len = key_states.shape[2]

        if self.layer_caches[layer_idx] is None:
            # First time - check if positions are contiguous starting from 0
            if cache_position[0].item() == 0 and input_seq_len == new_len:
                # Simple case: positions [0, 1, ..., n-1] - just clone
                self.layer_caches[layer_idx] = (
                    key_states.clone(),
                    value_states.clone(),
                )
            else:
                # Non-contiguous or not starting from 0 - need to allocate full size
                # (gaps are left as zeros until lazily filled later).
                k_cache = torch.zeros(
                    (self.batch_size, self.num_kv_heads, new_len, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                v_cache = torch.zeros(
                    (self.batch_size, self.num_kv_heads, new_len, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                k_cache[:, :, cache_position.long(), :] = key_states
                v_cache[:, :, cache_position.long(), :] = value_states
                self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = new_len
        else:
            k_cache, v_cache = self.layer_caches[layer_idx]
            current_len = k_cache.shape[2]

            if new_len > current_len:
                # Need to extend cache
                # Zero-padded extension; the new slots are written just below.
                extension_size = new_len - current_len
                k_extension = torch.zeros(
                    (self.batch_size, self.num_kv_heads, extension_size, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                v_extension = torch.zeros(
                    (self.batch_size, self.num_kv_heads, extension_size, self.head_dim),
                    device=self.device,
                    dtype=self.dtype,
                )
                k_cache = torch.cat([k_cache, k_extension], dim=2)
                v_cache = torch.cat([v_cache, v_extension], dim=2)

            # Update at cache_position (handles both extension and gap-filling)
            k_cache[:, :, cache_position.long(), :] = key_states
            v_cache[:, :, cache_position.long(), :] = value_states

            self.layer_caches[layer_idx] = (k_cache, v_cache)
            # max() so gap-filling below the current length never shrinks it.
            self.layer_seq_lengths[layer_idx] = max(
                self.layer_seq_lengths[layer_idx], new_len
            )

        return self.layer_caches[layer_idx]
|
| 126 |
+
|
| 127 |
+
def get_kv(self, layer_idx: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
| 128 |
+
"""Get cached KV for a layer, or None if not cached."""
|
| 129 |
+
return self.layer_caches[layer_idx]
|
| 130 |
+
|
| 131 |
+
def get_seq_length(self, layer_idx: int) -> int:
|
| 132 |
+
"""Get the sequence length cached for a layer."""
|
| 133 |
+
return self.layer_seq_lengths[layer_idx]
|
| 134 |
+
|
| 135 |
+
def truncate_from(self, position: int):
|
| 136 |
+
"""
|
| 137 |
+
Truncate all layer caches from position onwards.
|
| 138 |
+
Used for rollback on rejection.
|
| 139 |
+
"""
|
| 140 |
+
for layer_idx in range(self.num_layers):
|
| 141 |
+
if self.layer_caches[layer_idx] is not None:
|
| 142 |
+
k, v = self.layer_caches[layer_idx]
|
| 143 |
+
if k.shape[2] > position:
|
| 144 |
+
self.layer_caches[layer_idx] = (
|
| 145 |
+
k[:, :, :position, :],
|
| 146 |
+
v[:, :, :position, :],
|
| 147 |
+
)
|
| 148 |
+
self.layer_seq_lengths[layer_idx] = min(
|
| 149 |
+
self.layer_seq_lengths[layer_idx], position
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def clone(self) -> "JaggedKVCache":
|
| 153 |
+
"""Create a deep copy of the cache for speculation."""
|
| 154 |
+
new_cache = JaggedKVCache(
|
| 155 |
+
num_layers=self.num_layers,
|
| 156 |
+
batch_size=self.batch_size,
|
| 157 |
+
num_kv_heads=self.num_kv_heads,
|
| 158 |
+
head_dim=self.head_dim,
|
| 159 |
+
device=self.device,
|
| 160 |
+
dtype=self.dtype,
|
| 161 |
+
)
|
| 162 |
+
for i, kv in enumerate(self.layer_caches):
|
| 163 |
+
if kv is not None:
|
| 164 |
+
new_cache.layer_caches[i] = (kv[0].clone(), kv[1].clone())
|
| 165 |
+
new_cache.layer_seq_lengths = self.layer_seq_lengths.copy()
|
| 166 |
+
return new_cache
|
| 167 |
+
|
| 168 |
+
def get_missing_layers(self, position: int, target_layer: int) -> List[int]:
|
| 169 |
+
"""
|
| 170 |
+
Get list of layers that need computation for this position.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
position: The position we need KV for
|
| 174 |
+
target_layer: The deepest layer we need to reach
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
List of layer indices that need to be computed
|
| 178 |
+
"""
|
| 179 |
+
missing = []
|
| 180 |
+
for layer_idx in range(target_layer + 1):
|
| 181 |
+
if self.layer_seq_lengths[layer_idx] <= position:
|
| 182 |
+
missing.append(layer_idx)
|
| 183 |
+
return missing
|
| 184 |
+
|
| 185 |
+
def __repr__(self):
|
| 186 |
+
lines = [f"JaggedKVCache(num_layers={self.num_layers})"]
|
| 187 |
+
for i in range(self.num_layers):
|
| 188 |
+
seq_len = self.layer_seq_lengths[i]
|
| 189 |
+
lines.append(f" Layer {i:2d}: {seq_len} positions cached")
|
| 190 |
+
return "\n".join(lines)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# =============================================================================
|
| 194 |
+
# Test Fixtures
|
| 195 |
+
# =============================================================================
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@pytest.fixture
def small_cache():
    """A tiny CPU-only cache instance shared by the tests below."""
    cache = JaggedKVCache(
        num_layers=8,
        batch_size=1,
        num_kv_heads=4,
        head_dim=64,
        device="cpu",
        dtype=torch.float32,
    )
    return cache
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@pytest.fixture
def sample_kv():
    """Factory fixture producing random (key, value) tensor pairs."""

    def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
        shape = (batch_size, num_heads, seq_len, head_dim)
        return torch.randn(*shape), torch.randn(*shape)

    return _make_kv
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# =============================================================================
|
| 224 |
+
# Test 1: Basic Cache Operations
|
| 225 |
+
# =============================================================================
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class TestCacheBasicOperations:
    """Exercise the fundamental update / retrieval paths of the cache."""

    def test_cache_starts_empty(self, small_cache):
        """A fresh cache has no KV tensors and zero length everywhere."""
        for layer in range(small_cache.num_layers):
            assert small_cache.get_kv(layer) is None
            assert small_cache.get_seq_length(layer) == 0

    def test_single_position_update(self, small_cache, sample_kv):
        """Writing one position populates only the targeted layer."""
        key, value = sample_kv()
        small_cache.update(
            layer_idx=0,
            key_states=key,
            value_states=value,
            cache_position=torch.tensor([0]),
        )
        assert small_cache.get_kv(0) is not None
        assert small_cache.get_seq_length(0) == 1
        assert small_cache.get_kv(1) is None  # Other layers unchanged

    def test_multiple_positions_update(self, small_cache, sample_kv):
        """A batched write of several positions lands in one call."""
        key, value = sample_kv(seq_len=3)
        small_cache.update(
            layer_idx=0,
            key_states=key,
            value_states=value,
            cache_position=torch.tensor([0, 1, 2]),
        )
        assert small_cache.get_seq_length(0) == 3
        cached_k, cached_v = small_cache.get_kv(0)
        assert cached_k.shape[2] == 3

    def test_extending_cache(self, small_cache, sample_kv):
        """Appending new positions grows the cached sequence."""
        # Seed positions 0-1, then append positions 2-3.
        first_k, first_v = sample_kv(seq_len=2)
        small_cache.update(0, first_k, first_v, torch.tensor([0, 1]))
        second_k, second_v = sample_kv(seq_len=2)
        small_cache.update(0, second_k, second_v, torch.tensor([2, 3]))

        assert small_cache.get_seq_length(0) == 4
        cached_k, _ = small_cache.get_kv(0)
        assert cached_k.shape[2] == 4
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# =============================================================================
|
| 279 |
+
# Test 2: Jagged Cache Behavior
|
| 280 |
+
# =============================================================================
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class TestJaggedCacheBehavior:
    """Verify that per-layer cached lengths may legitimately diverge."""

    def test_different_layers_different_lengths(self, small_cache, sample_kv):
        """Simulate early exit where different layers have different cached lengths.

        Note: seq_length tracks capacity (max_pos + 1), not filled count.
        When layer 3 is first updated at position [1], it allocates space for
        positions [0, 1], but position 0 contains zeros (unfilled).
        The lazy fill mechanism will fill these gaps when needed.
        """
        # Token 0 exits at layer 2, so only layers 0-2 receive its KV.
        for layer in range(3):
            key, value = sample_kv()
            small_cache.update(layer, key, value, torch.tensor([0]))

        # Token 1 exits at layer 4, reaching layers 0-4.
        for layer in range(5):
            key, value = sample_kv()
            small_cache.update(layer, key, value, torch.tensor([1]))

        # Layers 0-2 saw both tokens; layers 3-4 were allocated capacity 2
        # (position 0 is zero-filled until lazily computed).
        for layer in range(5):
            assert small_cache.get_seq_length(layer) == 2
        assert small_cache.get_seq_length(5) == 0  # Never reached

    def test_get_missing_layers(self, small_cache, sample_kv):
        """get_missing_layers reports exactly the uncached layers."""
        # Only layers 0-2 receive KV for position 0.
        for layer in range(3):
            key, value = sample_kv()
            small_cache.update(layer, key, value, torch.tensor([0]))

        # Layers 3-5 still need computation for position 0.
        assert small_cache.get_missing_layers(position=0, target_layer=5) == [3, 4, 5]

        # Position 1 was never written anywhere, so every layer is missing.
        assert small_cache.get_missing_layers(position=1, target_layer=5) == [
            0,
            1,
            2,
            3,
            4,
            5,
        ]
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# =============================================================================
|
| 332 |
+
# Test 3: Truncation for Rollback
|
| 333 |
+
# =============================================================================
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class TestCacheTruncation:
    """Rollback behavior: truncate_from drops rejected positions."""

    def test_truncate_removes_positions(self, small_cache, sample_kv):
        """Truncating at position p keeps exactly positions [0, p)."""
        for pos in range(5):
            key, value = sample_kv()
            small_cache.update(0, key, value, torch.tensor([pos]))
        assert small_cache.get_seq_length(0) == 5

        small_cache.truncate_from(3)  # keep positions 0, 1, 2

        assert small_cache.get_seq_length(0) == 3
        cached_k, _ = small_cache.get_kv(0)
        assert cached_k.shape[2] == 3

    def test_truncate_all_layers(self, small_cache, sample_kv):
        """Every layer is clipped, even when lengths differ beforehand."""
        # Give layers 0-2 five positions each.
        for layer in range(3):
            for pos in range(5):
                key, value = sample_kv()
                small_cache.update(layer, key, value, torch.tensor([pos]))

        # Layer 0 additionally gets positions 5-7.
        for pos in range(5, 8):
            key, value = sample_kv()
            small_cache.update(0, key, value, torch.tensor([pos]))

        assert small_cache.get_seq_length(0) == 8
        assert small_cache.get_seq_length(1) == 5
        assert small_cache.get_seq_length(2) == 5

        small_cache.truncate_from(4)

        # All layers now end at position 4.
        assert small_cache.get_seq_length(0) == 4
        assert small_cache.get_seq_length(1) == 4
        assert small_cache.get_seq_length(2) == 4
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# =============================================================================
|
| 381 |
+
# Test 4: Clone for Speculation
|
| 382 |
+
# =============================================================================
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class TestCacheCloning:
    """clone() must yield a fully independent snapshot for drafting."""

    def test_clone_creates_independent_copy(self, small_cache, sample_kv):
        """Mutating the original after cloning leaves the clone untouched."""
        key, value = sample_kv(seq_len=3)
        small_cache.update(0, key, value, torch.tensor([0, 1, 2]))

        snapshot = small_cache.clone()

        # Extend only the original.
        extra_k, extra_v = sample_kv()
        small_cache.update(0, extra_k, extra_v, torch.tensor([3]))

        assert small_cache.get_seq_length(0) == 4
        assert snapshot.get_seq_length(0) == 3

    def test_clone_preserves_data(self, small_cache, sample_kv):
        """Cloned tensors hold exactly the same values as the originals."""
        key, value = sample_kv()
        small_cache.update(0, key, value, torch.tensor([0]))

        snapshot = small_cache.clone()

        orig_k, orig_v = small_cache.get_kv(0)
        copy_k, copy_v = snapshot.get_kv(0)
        assert torch.allclose(orig_k, copy_k)
        assert torch.allclose(orig_v, copy_v)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# =============================================================================
|
| 420 |
+
# Test 5: Simulated Draft/Verify Scenario
|
| 421 |
+
# =============================================================================
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class TestDraftVerifyScenario:
    """Simulate a realistic speculative draft/verify cycle end to end."""

    def test_draft_verify_with_full_accept(self, small_cache, sample_kv):
        """Draft 3 tokens at varying exit layers; all are accepted."""
        # Prompt prefill: positions 0-4 cached in every layer.
        for pos in range(5):
            for layer_idx in range(small_cache.num_layers):
                k, v = sample_kv()
                small_cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Clone so drafting never mutates the committed cache.
        draft_cache = small_cache.clone()

        # Draft 3 tokens (positions 5, 6, 7), exiting at different layers.
        exit_layers = [2, 4, 3]  # Token 5 exits at layer 2, etc.
        # Fix: the loop index from enumerate() was unused; iterate the zipped
        # pairs directly.
        for pos, exit_layer in zip([5, 6, 7], exit_layers):
            for layer_idx in range(exit_layer + 1):
                k, v = sample_kv()
                draft_cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Check jagged structure after drafting.
        assert draft_cache.get_seq_length(0) == 8  # All 8 positions
        assert draft_cache.get_seq_length(2) == 8  # All tokens reached layer 2
        assert draft_cache.get_seq_length(4) == 7  # Only tokens 5,6 reached layer 4

        # "Verification" - all accepted, fill remaining layers.
        for pos in [5, 6, 7]:
            for layer_idx in range(small_cache.num_layers):
                if draft_cache.get_seq_length(layer_idx) <= pos:
                    k, v = sample_kv()
                    draft_cache.update(layer_idx, k, v, torch.tensor([pos]))

        # After verification, all layers should have all positions.
        for layer_idx in range(small_cache.num_layers):
            assert draft_cache.get_seq_length(layer_idx) == 8

    def test_draft_verify_with_rejection(self, small_cache, sample_kv):
        """Draft 3 tokens; reject from position 6 and roll the cache back."""
        # Prompt prefill: positions 0-4 cached in every layer.
        for pos in range(5):
            for layer_idx in range(small_cache.num_layers):
                k, v = sample_kv()
                small_cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Clone for drafting.
        draft_cache = small_cache.clone()

        # Draft 3 tokens, all exiting at layer 2.
        for pos in [5, 6, 7]:
            for layer_idx in range(3):
                k, v = sample_kv()
                draft_cache.update(layer_idx, k, v, torch.tensor([pos]))

        # Simulate rejection at position 6: accept position 5, reject 6 and 7.
        draft_cache.truncate_from(6)

        # Should only have positions 0-5 in the drafted layers.
        assert draft_cache.get_seq_length(0) == 6
        assert draft_cache.get_seq_length(1) == 6
        assert draft_cache.get_seq_length(2) == 6
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# =============================================================================
|
| 490 |
+
# Run tests directly
|
| 491 |
+
# =============================================================================
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# Allow running this test module directly (outside a plain `pytest` invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|