Florian valade commited on
Commit
33efa44
·
1 Parent(s): 432ea6e

Track metrics during streaming, remove redundant generation re-runs

Browse files

- Add StreamingResult dataclass to hold metrics from streaming generation
- Update StreamEvent with 'complete' event type and result field
- Extract _format_and_encode_prompt() helper to reduce code duplication
- Update generate_streaming() and generate_full_model_streaming() to
yield final 'complete' event with accumulated metrics
- Refactor app.py to use streaming results instead of re-running generation
- Remove unused code: set_full_cache(), create_cache_from_model(),
to_json() methods, get_threshold()
- Add jagged_cache.py and test files for KV cache operations

This fixes the bug where output would change after streaming completed
and metrics would appear after a delay.

app.py CHANGED
@@ -4,10 +4,12 @@ Showcases early exit inference with color-coded tokens showing which head genera
4
  """
5
 
6
  import gradio as gr
 
7
  from pathlib import Path
 
8
  from huggingface_hub import hf_hub_download
9
 
10
- from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent
11
 
12
  # Available models configuration
13
  AVAILABLE_MODELS = {
@@ -33,6 +35,10 @@ HEAD_COLORS = [
33
  ]
34
  FULL_MODEL_COLOR = "#95D5B2" # Light green - Full model
35
 
 
 
 
 
36
  # Global decoder cache
37
  _decoder_cache = {}
38
 
@@ -103,7 +109,7 @@ def tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str:
103
 
104
  html_parts.append(
105
  f'<span style="background-color: {color}; padding: 2px 4px; '
106
- f'border-radius: 3px; margin: 1px; display: inline-block;" title="{title}">{text}</span>'
107
  )
108
 
109
  # Wrap in container with word-wrap to prevent overflow
@@ -121,8 +127,8 @@ def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> s
121
  layer = head_layers[token.exit_head]
122
  title = f"PENDING - Head {token.exit_head} (Layer {layer})"
123
  else:
124
- color = FULL_MODEL_COLOR
125
- title = "PENDING - Full Model"
126
 
127
  text = (
128
  token.token_text.replace("&", "&amp;")
@@ -134,7 +140,8 @@ def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> s
134
  html_parts.append(
135
  f'<span style="background-color: {color}; padding: 2px 4px; '
136
  f"border-radius: 3px; margin: 1px; display: inline-block; "
137
- f'border: 2px dashed #333; opacity: 0.7;" title="{title}">{text}</span>'
 
138
  )
139
 
140
  return "".join(html_parts)
@@ -156,17 +163,78 @@ def create_legend(head_layers: list[int]) -> str:
156
  return " ".join(legend_items)
157
 
158
 
159
- def create_stats_html(result, label: str) -> str:
160
- """Create statistics HTML display."""
161
- return f"""
162
- <div style="padding: 10px; background: #f5f5f5; border-radius: 8px; margin-top: 10px;">
163
- <h4 style="margin: 0 0 10px 0;">{label} Statistics</h4>
164
- <p><b>Time:</b> {result.total_time:.2f}s</p>
165
- <p><b>Tokens/sec:</b> {result.tokens_per_second:.2f}</p>
166
- <p><b>Avg Exit Layer:</b> {result.avg_exit_layer:.1f}</p>
167
- <p><b>Exit Distribution:</b> {result.exit_distribution}</p>
168
- </div>
169
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  def generate(
@@ -178,13 +246,29 @@ def generate(
178
  compare_mode: bool,
179
  ):
180
  """Main generation function for Gradio interface with streaming."""
 
181
  try:
182
  decoder = get_decoder(model_key)
183
  except Exception as e:
184
  error_msg = f"<p style='color: red;'>Error loading model: {e}</p>"
185
- yield (error_msg, "", "", error_msg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  return
187
 
 
188
  head_layers = decoder.model_config.head_layer_indices
189
  legend = create_legend(head_layers)
190
 
@@ -199,12 +283,21 @@ def generate(
199
  # Compare mode with streaming for early exit
200
  # First, stream the early exit generation
201
  final_ee_tokens = []
 
 
202
  for event in decoder.generate_streaming(
203
  prompt=prompt,
204
  max_tokens=int(max_tokens),
205
  accuracy_level=closest_level,
206
  use_chat_template=True,
207
  ):
 
 
 
 
 
 
 
208
  validated_html = ""
209
  if event.tokens:
210
  validated_html = tokens_to_html(event.tokens, head_layers)
@@ -219,91 +312,97 @@ def generate(
219
 
220
  combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
221
 
222
- status = f"""
223
- <div style="padding: 10px; background: #fff3cd; border-radius: 8px;">
224
- <b>Early Exit:</b> {event.message} | <b>Full Model:</b> Waiting...
225
- </div>
226
- """
 
227
 
 
 
 
 
 
 
 
228
  yield (
229
  combined_html,
230
- "<p style='color: #666;'>Waiting for early exit to complete...</p>",
231
  status,
 
232
  legend,
233
  )
234
- final_ee_tokens = event.tokens
235
 
236
  # Now stream full model
237
  final_full_tokens = []
 
 
238
  for event in decoder.generate_full_model_streaming(
239
  prompt=prompt,
240
  max_tokens=int(max_tokens),
241
  use_chat_template=True,
242
  ):
 
 
 
 
 
 
 
243
  html_full = tokens_to_html(event.tokens, head_layers)
244
- status = f"""
245
- <div style="padding: 10px; background: #fff3cd; border-radius: 8px;">
246
- <b>Full Model:</b> {event.message}
247
- </div>
248
- """
 
 
 
 
 
 
 
249
  yield (
250
  tokens_to_html(final_ee_tokens, head_layers),
251
  html_full,
252
  status,
 
253
  legend,
254
  )
255
- final_full_tokens = event.tokens
256
-
257
- # Final stats
258
- result_ee = decoder.generate(
259
- prompt=prompt,
260
- max_tokens=int(max_tokens),
261
- use_early_exit=True,
262
- accuracy_level=closest_level,
263
- use_chat_template=True,
264
- )
265
- result_full = decoder.generate(
266
- prompt=prompt,
267
- max_tokens=int(max_tokens),
268
- use_early_exit=False,
269
- use_chat_template=True,
270
- )
271
 
272
- html_ee = tokens_to_html(result_ee.tokens, head_layers)
273
- html_full = tokens_to_html(result_full.tokens, head_layers)
274
-
275
- speedup = (
276
- result_ee.tokens_per_second / result_full.tokens_per_second
277
- if result_full.tokens_per_second > 0
278
- else 0
 
 
 
 
279
  )
280
- stats = f"""
281
- <div style="padding: 15px; background: #e8f5e9; border-radius: 8px;">
282
- <h3 style="margin: 0 0 10px 0;">🚀 Speedup: {speedup:.2f}x</h3>
283
- <div style="display: flex; gap: 20px;">
284
- <div style="flex: 1; padding: 10px; background: white; border-radius: 8px;">
285
- <h4>Early Exit</h4>
286
- <p><b>Time:</b> {result_ee.total_time:.2f}s | <b>Tokens/sec:</b> {result_ee.tokens_per_second:.2f}</p>
287
- <p><b>Avg Exit Layer:</b> {result_ee.avg_exit_layer:.1f}</p>
288
- </div>
289
- <div style="flex: 1; padding: 10px; background: white; border-radius: 8px;">
290
- <h4>Full Model</h4>
291
- <p><b>Time:</b> {result_full.total_time:.2f}s | <b>Tokens/sec:</b> {result_full.tokens_per_second:.2f}</p>
292
- <p><b>Avg Exit Layer:</b> {result_full.avg_exit_layer:.1f}</p>
293
- </div>
294
- </div>
295
- </div>
296
- """
297
- yield (html_ee, html_full, stats, legend)
298
 
299
  elif use_early_exit:
300
  # STREAMING mode for early exit - show draft/verify process
 
 
 
301
  for event in decoder.generate_streaming(
302
  prompt=prompt,
303
  max_tokens=int(max_tokens),
304
  accuracy_level=closest_level,
305
  use_chat_template=True,
306
  ):
 
 
 
 
 
 
 
 
307
  # Build HTML showing validated + drafted tokens
308
  validated_html = ""
309
  if event.tokens:
@@ -322,63 +421,86 @@ def generate(
322
  combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
323
 
324
  # Status message
325
- status = f"""
326
- <div style="padding: 10px; background: #fff3cd; border-radius: 8px; margin-top: 5px;">
327
- <b>Status:</b> {event.message}
328
- </div>
329
- """
330
 
331
- yield (combined_html, "", status, legend)
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- # Final stats after streaming completes
334
- # Re-run to get final stats (or we could track during streaming)
335
- result = decoder.generate(
336
- prompt=prompt,
337
- max_tokens=int(max_tokens),
338
- use_early_exit=True,
339
- accuracy_level=closest_level,
340
- use_chat_template=True,
 
341
  )
342
- html = tokens_to_html(result.tokens, head_layers)
343
- stats = f"""
344
- <div style="padding: 15px; background: #f5f5f5; border-radius: 8px;">
345
- <h4 style="margin: 0 0 10px 0;">Early Exit Statistics (Final)</h4>
346
- <p><b>Tokens:</b> {len(result.tokens)} | <b>Tokens/sec:</b> {result.tokens_per_second:.2f} | <b>Avg Exit Layer:</b> {result.avg_exit_layer:.1f}</p>
347
- <p><b>Exit Distribution:</b> {result.exit_distribution}</p>
348
- </div>
349
- """
350
- yield (html, "", stats, legend)
351
 
352
  else:
353
  # Full model mode (streaming)
 
 
 
354
  for event in decoder.generate_full_model_streaming(
355
  prompt=prompt,
356
  max_tokens=int(max_tokens),
357
  use_chat_template=True,
358
  ):
 
 
 
 
 
 
 
359
  html = tokens_to_html(event.tokens, head_layers)
360
- status = f"""
361
- <div style="padding: 10px; background: #fff3cd; border-radius: 8px;">
362
- <b>Full Model:</b> {event.message}
363
- </div>
364
- """
365
- yield (html, "", status, legend)
366
-
367
- # Final stats
368
- result = decoder.generate(
369
- prompt=prompt,
370
- max_tokens=int(max_tokens),
371
- use_early_exit=False,
372
- use_chat_template=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  )
374
- html = tokens_to_html(result.tokens, head_layers)
375
- stats = f"""
376
- <div style="padding: 15px; background: #f5f5f5; border-radius: 8px;">
377
- <h4 style="margin: 0 0 10px 0;">Full Model Statistics</h4>
378
- <p><b>Tokens:</b> {len(result.tokens)} | <b>Time:</b> {result.total_time:.2f}s | <b>Tokens/sec:</b> {result.tokens_per_second:.2f}</p>
379
- </div>
380
- """
381
- yield (html, "", stats, legend)
382
 
383
 
384
  def build_demo():
@@ -444,8 +566,22 @@ def build_demo():
444
  gr.Markdown("### Full Model (Comparison)")
445
  output_full = gr.HTML()
446
 
447
- # Stats (full width)
448
- stats_html = gr.HTML()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
  def update_visibility(compare):
451
  return gr.update(visible=compare)
@@ -466,7 +602,21 @@ def build_demo():
466
  max_tokens,
467
  compare_mode,
468
  ],
469
- outputs=[output_ee, output_full, stats_html, legend_html],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  )
471
 
472
  return demo
@@ -474,4 +624,4 @@ def build_demo():
474
 
475
  if __name__ == "__main__":
476
  demo = build_demo()
477
- demo.launch(share=False)
 
4
  """
5
 
6
  import gradio as gr
7
+ from dataclasses import dataclass
8
  from pathlib import Path
9
+ import time
10
  from huggingface_hub import hf_hub_download
11
 
12
+ from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent, StreamingResult
13
 
14
  # Available models configuration
15
  AVAILABLE_MODELS = {
 
35
  ]
36
  FULL_MODEL_COLOR = "#95D5B2" # Light green - Full model
37
 
38
+ PENDING_TOKEN_BORDER = "var(--border-color-primary)"
39
+ PENDING_TOKEN_TEXT = "var(--body-text-color)"
40
+ DRAFTED_FALLBACK_COLOR = "var(--neutral-200)"
41
+
42
  # Global decoder cache
43
  _decoder_cache = {}
44
 
 
109
 
110
  html_parts.append(
111
  f'<span style="background-color: {color}; padding: 2px 4px; '
112
+ f'border-radius: 3px; margin: 1px; display: inline-block; color: #111827;" title="{title}">{text}</span>'
113
  )
114
 
115
  # Wrap in container with word-wrap to prevent overflow
 
127
  layer = head_layers[token.exit_head]
128
  title = f"PENDING - Head {token.exit_head} (Layer {layer})"
129
  else:
130
+ color = DRAFTED_FALLBACK_COLOR
131
+ title = "PENDING - Unassigned"
132
 
133
  text = (
134
  token.token_text.replace("&", "&amp;")
 
140
  html_parts.append(
141
  f'<span style="background-color: {color}; padding: 2px 4px; '
142
  f"border-radius: 3px; margin: 1px; display: inline-block; "
143
+ f"border: 2px dashed {PENDING_TOKEN_BORDER}; color: {PENDING_TOKEN_TEXT}; "
144
+ f'opacity: 0.75;" title="{title}">{text}</span>'
145
  )
146
 
147
  return "".join(html_parts)
 
163
  return " ".join(legend_items)
164
 
165
 
166
+
167
+ @dataclass
168
+ class StatsPayload:
169
+ generated_at: float
170
+ speedup_text: str
171
+ ee_time: str | None
172
+ ee_tps: str | None
173
+ ee_avg: str | None
174
+ full_time: str | None
175
+ full_tps: str | None
176
+ full_avg: str | None
177
+ show_ee: bool
178
+ show_full: bool
179
+
180
+
181
+ def build_stats_outputs(
182
+ result_ee,
183
+ result_full,
184
+ use_early_exit: bool,
185
+ compare_mode: bool,
186
+ generated_at: float | None = None,
187
+ ):
188
+ speedup_text = ""
189
+ if result_ee and result_full and result_full.tokens_per_second > 0:
190
+ speedup = result_ee.tokens_per_second / result_full.tokens_per_second
191
+ speedup_text = f"**Speedup:** {speedup:.2f}x"
192
+ elif result_ee:
193
+ speedup_text = "**Speedup:** N/A (full model not run)"
194
+ elif result_full:
195
+ speedup_text = "**Speedup:** N/A (early exit disabled)"
196
+
197
+ if not speedup_text:
198
+ speedup_text = "**Speedup:** N/A"
199
+
200
+ ee_time = f"{result_ee.total_time:.2f}" if result_ee else None
201
+ ee_tps = f"{result_ee.tokens_per_second:.2f}" if result_ee else None
202
+ ee_avg = f"{result_ee.avg_exit_layer:.1f}" if result_ee else None
203
+
204
+ full_time = f"{result_full.total_time:.2f}" if result_full else None
205
+ full_tps = f"{result_full.tokens_per_second:.2f}" if result_full else None
206
+ full_avg = f"{result_full.avg_exit_layer:.1f}" if result_full else None
207
+
208
+ show_ee = compare_mode or use_early_exit
209
+ show_full = compare_mode or not use_early_exit
210
+
211
+ return StatsPayload(
212
+ generated_at=generated_at if generated_at is not None else time.time(),
213
+ speedup_text=speedup_text,
214
+ ee_time=ee_time,
215
+ ee_tps=ee_tps,
216
+ ee_avg=ee_avg,
217
+ full_time=full_time,
218
+ full_tps=full_tps,
219
+ full_avg=full_avg,
220
+ show_ee=show_ee,
221
+ show_full=show_full,
222
+ )
223
+
224
+
225
+ def stats_payload_to_outputs(payload: StatsPayload):
226
+ return (
227
+ payload.speedup_text,
228
+ payload.ee_time,
229
+ payload.ee_tps,
230
+ payload.ee_avg,
231
+ payload.full_time,
232
+ payload.full_tps,
233
+ payload.full_avg,
234
+ gr.update(visible=payload.show_ee),
235
+ gr.update(visible=payload.show_full),
236
+ )
237
+
238
 
239
 
240
  def generate(
 
246
  compare_mode: bool,
247
  ):
248
  """Main generation function for Gradio interface with streaming."""
249
+ initial_stats_timestamp = time.time()
250
  try:
251
  decoder = get_decoder(model_key)
252
  except Exception as e:
253
  error_msg = f"<p style='color: red;'>Error loading model: {e}</p>"
254
+ status_msg = f"**Error loading model:** {e}"
255
+ stats_payload = build_stats_outputs(
256
+ None,
257
+ None,
258
+ use_early_exit,
259
+ compare_mode,
260
+ generated_at=initial_stats_timestamp,
261
+ )
262
+ yield (
263
+ error_msg,
264
+ "",
265
+ status_msg,
266
+ *stats_payload_to_outputs(stats_payload),
267
+ "",
268
+ )
269
  return
270
 
271
+
272
  head_layers = decoder.model_config.head_layer_indices
273
  legend = create_legend(head_layers)
274
 
 
283
  # Compare mode with streaming for early exit
284
  # First, stream the early exit generation
285
  final_ee_tokens = []
286
+ ee_streaming_result = None
287
+
288
  for event in decoder.generate_streaming(
289
  prompt=prompt,
290
  max_tokens=int(max_tokens),
291
  accuracy_level=closest_level,
292
  use_chat_template=True,
293
  ):
294
+ # Handle "complete" event - extract result and break
295
+ if event.event_type == "complete":
296
+ ee_streaming_result = event.result
297
+ final_ee_tokens = event.tokens
298
+ break
299
+
300
+ final_ee_tokens = event.tokens
301
  validated_html = ""
302
  if event.tokens:
303
  validated_html = tokens_to_html(event.tokens, head_layers)
 
312
 
313
  combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
314
 
315
+ status = (
316
+ "**Early Exit:** {message} \n"
317
+ "**Full Model:** Waiting..."
318
+ ).format(
319
+ message=event.message,
320
+ )
321
 
322
+ stats_payload = build_stats_outputs(
323
+ None,
324
+ None,
325
+ use_early_exit,
326
+ compare_mode,
327
+ generated_at=initial_stats_timestamp,
328
+ )
329
  yield (
330
  combined_html,
331
+ "<p style='color: var(--body-text-color-subdued);'>Waiting for early exit to complete...</p>",
332
  status,
333
+ *stats_payload_to_outputs(stats_payload),
334
  legend,
335
  )
 
336
 
337
  # Now stream full model
338
  final_full_tokens = []
339
+ full_streaming_result = None
340
+
341
  for event in decoder.generate_full_model_streaming(
342
  prompt=prompt,
343
  max_tokens=int(max_tokens),
344
  use_chat_template=True,
345
  ):
346
+ # Handle "complete" event - extract result and break
347
+ if event.event_type == "complete":
348
+ full_streaming_result = event.result
349
+ final_full_tokens = event.tokens
350
+ break
351
+
352
+ final_full_tokens = event.tokens
353
  html_full = tokens_to_html(event.tokens, head_layers)
354
+ status = (
355
+ "**Full Model:** {message}"
356
+ ).format(
357
+ message=event.message,
358
+ )
359
+ stats_payload = build_stats_outputs(
360
+ None,
361
+ None,
362
+ use_early_exit,
363
+ compare_mode,
364
+ generated_at=initial_stats_timestamp,
365
+ )
366
  yield (
367
  tokens_to_html(final_ee_tokens, head_layers),
368
  html_full,
369
  status,
370
+ *stats_payload_to_outputs(stats_payload),
371
  legend,
372
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
+ # Final output with metrics from streaming results (no re-run needed)
375
+ html_ee = tokens_to_html(final_ee_tokens, head_layers)
376
+ html_full = tokens_to_html(final_full_tokens, head_layers)
377
+
378
+ stats_payload = build_stats_outputs(ee_streaming_result, full_streaming_result, use_early_exit, compare_mode)
379
+ yield (
380
+ html_ee,
381
+ html_full,
382
+ "",
383
+ *stats_payload_to_outputs(stats_payload),
384
+ legend,
385
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  elif use_early_exit:
388
  # STREAMING mode for early exit - show draft/verify process
389
+ streaming_result = None
390
+ final_tokens = []
391
+
392
  for event in decoder.generate_streaming(
393
  prompt=prompt,
394
  max_tokens=int(max_tokens),
395
  accuracy_level=closest_level,
396
  use_chat_template=True,
397
  ):
398
+ # Handle "complete" event - extract result and break
399
+ if event.event_type == "complete":
400
+ streaming_result = event.result
401
+ final_tokens = event.tokens
402
+ break
403
+
404
+ final_tokens = event.tokens
405
+
406
  # Build HTML showing validated + drafted tokens
407
  validated_html = ""
408
  if event.tokens:
 
421
  combined_html = f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{validated_html}{drafted_html}</div>"""
422
 
423
  # Status message
424
+ status = (
425
+ "**Status:** {message}"
426
+ ).format(
427
+ message=event.message,
428
+ )
429
 
430
+ stats_payload = build_stats_outputs(
431
+ None,
432
+ None,
433
+ use_early_exit,
434
+ compare_mode,
435
+ generated_at=initial_stats_timestamp,
436
+ )
437
+ yield (
438
+ combined_html,
439
+ "",
440
+ status,
441
+ *stats_payload_to_outputs(stats_payload),
442
+ legend,
443
+ )
444
 
445
+ # Final output with metrics from streaming result (no re-run needed)
446
+ html = tokens_to_html(final_tokens, head_layers)
447
+ stats_payload = build_stats_outputs(streaming_result, None, use_early_exit, compare_mode)
448
+ yield (
449
+ html,
450
+ "",
451
+ "",
452
+ *stats_payload_to_outputs(stats_payload),
453
+ legend,
454
  )
 
 
 
 
 
 
 
 
 
455
 
456
  else:
457
  # Full model mode (streaming)
458
+ streaming_result = None
459
+ final_tokens = []
460
+
461
  for event in decoder.generate_full_model_streaming(
462
  prompt=prompt,
463
  max_tokens=int(max_tokens),
464
  use_chat_template=True,
465
  ):
466
+ # Handle "complete" event - extract result and break
467
+ if event.event_type == "complete":
468
+ streaming_result = event.result
469
+ final_tokens = event.tokens
470
+ break
471
+
472
+ final_tokens = event.tokens
473
  html = tokens_to_html(event.tokens, head_layers)
474
+ status = (
475
+ "**Full Model:** {message}"
476
+ ).format(
477
+ message=event.message,
478
+ )
479
+ stats_payload = build_stats_outputs(
480
+ None,
481
+ None,
482
+ use_early_exit,
483
+ compare_mode,
484
+ generated_at=initial_stats_timestamp,
485
+ )
486
+ yield (
487
+ html,
488
+ "",
489
+ status,
490
+ *stats_payload_to_outputs(stats_payload),
491
+ legend,
492
+ )
493
+
494
+ # Final output with metrics from streaming result (no re-run needed)
495
+ html = tokens_to_html(final_tokens, head_layers)
496
+ stats_payload = build_stats_outputs(None, streaming_result, use_early_exit, compare_mode)
497
+ yield (
498
+ html,
499
+ "",
500
+ "",
501
+ *stats_payload_to_outputs(stats_payload),
502
+ legend,
503
  )
 
 
 
 
 
 
 
 
504
 
505
 
506
  def build_demo():
 
566
  gr.Markdown("### Full Model (Comparison)")
567
  output_full = gr.HTML()
568
 
569
+ status_html = gr.Markdown()
570
+
571
+ with gr.Group():
572
+ gr.Markdown("### Speedup Recap")
573
+ speedup_md = gr.Markdown()
574
+ with gr.Row():
575
+ with gr.Column(visible=True) as ee_stats_col:
576
+ gr.Markdown("#### Early Exit")
577
+ ee_time = gr.Label(label="Time (s)")
578
+ ee_tps = gr.Label(label="Tokens/sec")
579
+ ee_avg = gr.Label(label="Avg Exit Layer")
580
+ with gr.Column(visible=False) as full_stats_col:
581
+ gr.Markdown("#### Full Model")
582
+ full_time = gr.Label(label="Time (s)")
583
+ full_tps = gr.Label(label="Tokens/sec")
584
+ full_avg = gr.Label(label="Avg Exit Layer")
585
 
586
  def update_visibility(compare):
587
  return gr.update(visible=compare)
 
602
  max_tokens,
603
  compare_mode,
604
  ],
605
+ outputs=[
606
+ output_ee,
607
+ output_full,
608
+ status_html,
609
+ speedup_md,
610
+ ee_time,
611
+ ee_tps,
612
+ ee_avg,
613
+ full_time,
614
+ full_tps,
615
+ full_avg,
616
+ ee_stats_col,
617
+ full_stats_col,
618
+ legend_html,
619
+ ],
620
  )
621
 
622
  return demo
 
624
 
625
  if __name__ == "__main__":
626
  demo = build_demo()
627
+ demo.launch(share=False, debug=True)
src/inference.py CHANGED
@@ -53,14 +53,47 @@ class TokenInfo:
53
  uncertainty: float
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @dataclass
57
  class StreamEvent:
58
  """Event for streaming generation updates."""
59
 
60
- event_type: str # "draft", "verify_start", "accept", "reject", "full_model"
61
  tokens: List[TokenInfo] # All tokens so far (validated)
62
  drafted_tokens: List[TokenInfo] # Currently drafted (pending verification)
63
  message: str # Human-readable status
 
64
 
65
 
66
  @dataclass
@@ -100,19 +133,8 @@ class DSSDecoder:
100
  self.device = device
101
  self.uncertainty_fn = compute_entropy
102
 
103
- def generate(
104
- self,
105
- prompt: str,
106
- max_tokens: int = 100,
107
- use_early_exit: bool = True,
108
- accuracy_level: float = 0.75,
109
- use_chat_template: bool = True,
110
- ) -> GenerationResult:
111
- """
112
- Generate text with optional early exit.
113
- Returns detailed token-level information for visualization.
114
- """
115
- # Format prompt - check if tokenizer has a chat template set
116
  if (
117
  use_chat_template
118
  and hasattr(self.tokenizer, "chat_template")
@@ -123,18 +145,26 @@ class DSSDecoder:
123
  formatted = self.tokenizer.apply_chat_template(
124
  messages, add_generation_prompt=True, tokenize=False
125
  )
126
- input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
127
  self.device
128
  )
129
  except Exception:
130
- # Fallback to raw prompt if chat template fails
131
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
132
- self.device
133
- )
134
- else:
135
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
136
- self.device
137
- )
 
 
 
 
 
 
 
 
138
 
139
  # Get thresholds
140
  thresholds = {}
@@ -186,29 +216,9 @@ class DSSDecoder:
186
  """
187
  Generate with streaming - yields events showing draft/verify process.
188
  Each event shows current validated tokens and pending drafted tokens.
 
189
  """
190
- # Format prompt
191
- if (
192
- use_chat_template
193
- and hasattr(self.tokenizer, "chat_template")
194
- and self.tokenizer.chat_template is not None
195
- ):
196
- try:
197
- messages = [{"role": "user", "content": prompt}]
198
- formatted = self.tokenizer.apply_chat_template(
199
- messages, add_generation_prompt=True, tokenize=False
200
- )
201
- input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
202
- self.device
203
- )
204
- except Exception:
205
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
206
- self.device
207
- )
208
- else:
209
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
210
- self.device
211
- )
212
 
213
  # Get thresholds
214
  thresholds = {}
@@ -218,6 +228,7 @@ class DSSDecoder:
218
  validated_tokens = []
219
  current_ids = input_ids.clone()
220
  num_layers = self.adapter.get_num_layers()
 
221
 
222
  while len(validated_tokens) < max_tokens:
223
  # ============================================================
@@ -226,6 +237,7 @@ class DSSDecoder:
226
  drafted_tokens = []
227
  draft_ids = current_ids.clone()
228
  got_lm_head_token = False
 
229
 
230
  for _ in range(max_draft_length):
231
  if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
@@ -240,7 +252,8 @@ class DSSDecoder:
240
  # EOS handling
241
  if exit_head is not None and drafted_tokens:
242
  break # Verify pending drafts first
243
- return # Stop generation
 
244
 
245
  token_text = self.tokenizer.decode([token_id])
246
  drafted_token = TokenInfo(
@@ -274,6 +287,10 @@ class DSSDecoder:
274
  message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
275
  )
276
 
 
 
 
 
277
  # ============================================================
278
  # VERIFY PHASE
279
  # ============================================================
@@ -382,6 +399,17 @@ class DSSDecoder:
382
  ):
383
  break
384
 
 
 
 
 
 
 
 
 
 
 
 
385
  def _generate_with_early_exit(
386
  self,
387
  input_ids: torch.Tensor,
@@ -773,33 +801,14 @@ class DSSDecoder:
773
  ):
774
  """
775
  Generate with full model in streaming mode - yields each token as generated.
 
776
  """
777
- # Format prompt
778
- if (
779
- use_chat_template
780
- and hasattr(self.tokenizer, "chat_template")
781
- and self.tokenizer.chat_template is not None
782
- ):
783
- try:
784
- messages = [{"role": "user", "content": prompt}]
785
- formatted = self.tokenizer.apply_chat_template(
786
- messages, add_generation_prompt=True, tokenize=False
787
- )
788
- input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
789
- self.device
790
- )
791
- except Exception:
792
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
793
- self.device
794
- )
795
- else:
796
- input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
797
- self.device
798
- )
799
 
800
  tokens = []
801
  current_ids = input_ids.clone()
802
  num_layers = self.adapter.get_num_layers()
 
803
 
804
  for i in range(max_tokens):
805
  with torch.no_grad():
@@ -832,6 +841,17 @@ class DSSDecoder:
832
  message=f"Token {i + 1}: '{token_text}'",
833
  )
834
 
 
 
 
 
 
 
 
 
 
 
 
835
 
836
  def load_dssd_model(
837
  model_name: str,
 
53
  uncertainty: float
54
 
55
 
56
+ @dataclass
57
+ class StreamingResult:
58
+ """Result from streaming generation with accumulated metrics."""
59
+
60
+ tokens: List[TokenInfo]
61
+ total_time: float
62
+ tokens_per_second: float
63
+ avg_exit_layer: float
64
+ exit_distribution: Dict[str, int]
65
+
66
+ @classmethod
67
+ def from_tokens(cls, tokens: List[TokenInfo], total_time: float, num_layers: int) -> "StreamingResult":
68
+ """Build a StreamingResult from a list of tokens and timing info."""
69
+ exit_dist: Dict[str, int] = {}
70
+ layer_sum = 0
71
+
72
+ for t in tokens:
73
+ key = str(t.exit_head) if t.exit_head is not None else "full"
74
+ exit_dist[key] = exit_dist.get(key, 0) + 1
75
+ layer_sum += t.exit_layer
76
+
77
+ avg_layer = layer_sum / len(tokens) if tokens else num_layers
78
+
79
+ return cls(
80
+ tokens=tokens,
81
+ total_time=total_time,
82
+ tokens_per_second=len(tokens) / total_time if total_time > 0 else 0,
83
+ avg_exit_layer=avg_layer,
84
+ exit_distribution=exit_dist,
85
+ )
86
+
87
+
88
  @dataclass
89
  class StreamEvent:
90
  """Event for streaming generation updates."""
91
 
92
+ event_type: str # "draft", "verify_start", "accept", "reject", "full_model", "complete"
93
  tokens: List[TokenInfo] # All tokens so far (validated)
94
  drafted_tokens: List[TokenInfo] # Currently drafted (pending verification)
95
  message: str # Human-readable status
96
+ result: Optional[StreamingResult] = None # Set on final "complete" event
97
 
98
 
99
  @dataclass
 
133
  self.device = device
134
  self.uncertainty_fn = compute_entropy
135
 
136
+ def _format_and_encode_prompt(self, prompt: str, use_chat_template: bool) -> torch.Tensor:
137
+ """Format prompt with optional chat template and return input_ids tensor."""
 
 
 
 
 
 
 
 
 
 
 
138
  if (
139
  use_chat_template
140
  and hasattr(self.tokenizer, "chat_template")
 
145
  formatted = self.tokenizer.apply_chat_template(
146
  messages, add_generation_prompt=True, tokenize=False
147
  )
148
+ return self.tokenizer.encode(formatted, return_tensors="pt").to(
149
  self.device
150
  )
151
  except Exception:
152
+ pass # Fall through to raw prompt encoding
153
+ return self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
154
+
155
+ def generate(
156
+ self,
157
+ prompt: str,
158
+ max_tokens: int = 100,
159
+ use_early_exit: bool = True,
160
+ accuracy_level: float = 0.75,
161
+ use_chat_template: bool = True,
162
+ ) -> GenerationResult:
163
+ """
164
+ Generate text with optional early exit.
165
+ Returns detailed token-level information for visualization.
166
+ """
167
+ input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
168
 
169
  # Get thresholds
170
  thresholds = {}
 
216
  """
217
  Generate with streaming - yields events showing draft/verify process.
218
  Each event shows current validated tokens and pending drafted tokens.
219
+ Yields a final "complete" event with StreamingResult containing metrics.
220
  """
221
+ input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  # Get thresholds
224
  thresholds = {}
 
228
  validated_tokens = []
229
  current_ids = input_ids.clone()
230
  num_layers = self.adapter.get_num_layers()
231
+ start_time = time.time()
232
 
233
  while len(validated_tokens) < max_tokens:
234
  # ============================================================
 
237
  drafted_tokens = []
238
  draft_ids = current_ids.clone()
239
  got_lm_head_token = False
240
+ should_stop = False
241
 
242
  for _ in range(max_draft_length):
243
  if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
 
252
  # EOS handling
253
  if exit_head is not None and drafted_tokens:
254
  break # Verify pending drafts first
255
+ should_stop = True
256
+ break # Stop generation
257
 
258
  token_text = self.tokenizer.decode([token_id])
259
  drafted_token = TokenInfo(
 
287
  message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
288
  )
289
 
290
+ # Check if we should stop (EOS encountered with no pending drafts)
291
+ if should_stop:
292
+ break
293
+
294
  # ============================================================
295
  # VERIFY PHASE
296
  # ============================================================
 
399
  ):
400
  break
401
 
402
+ # Yield final "complete" event with metrics
403
+ total_time = time.time() - start_time
404
+ result = StreamingResult.from_tokens(validated_tokens, total_time, num_layers)
405
+ yield StreamEvent(
406
+ event_type="complete",
407
+ tokens=list(validated_tokens),
408
+ drafted_tokens=[],
409
+ message="Generation complete",
410
+ result=result,
411
+ )
412
+
413
  def _generate_with_early_exit(
414
  self,
415
  input_ids: torch.Tensor,
 
801
  ):
802
  """
803
  Generate with full model in streaming mode - yields each token as generated.
804
+ Yields a final "complete" event with StreamingResult containing metrics.
805
  """
806
+ input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
 
808
  tokens = []
809
  current_ids = input_ids.clone()
810
  num_layers = self.adapter.get_num_layers()
811
+ start_time = time.time()
812
 
813
  for i in range(max_tokens):
814
  with torch.no_grad():
 
841
  message=f"Token {i + 1}: '{token_text}'",
842
  )
843
 
844
+ # Yield final "complete" event with metrics
845
+ total_time = time.time() - start_time
846
+ result = StreamingResult.from_tokens(tokens, total_time, num_layers)
847
+ yield StreamEvent(
848
+ event_type="complete",
849
+ tokens=list(tokens),
850
+ drafted_tokens=[],
851
+ message="Generation complete",
852
+ result=result,
853
+ )
854
+
855
 
856
  def load_dssd_model(
857
  model_name: str,
src/jagged_cache.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ JaggedKVCache - Sparse KV Cache for Early Exit Inference.
3
+
4
+ This cache tracks per-layer sequence lengths, enabling efficient
5
+ generation with early exit heads that stop at different layers.
6
+ """
7
+
8
+ import torch
9
+ from typing import List, Tuple, Optional
10
+
11
+
12
class JaggedKVCache:
    """
    Sparse KV Cache that tracks per-layer sequence lengths.

    Unlike standard KV caches where all layers have the same length,
    this cache allows different layers to have different cached lengths.
    This is essential for early exit inference where tokens may exit
    at different layers.

    Key features:
    - Per-layer KV storage with independent lengths
    - Lazy fill: missing positions are detected and can be computed on-demand
    - Truncation: efficient rollback on rejection
    - Cloning: snapshot for speculative drafting

    Attributes:
        num_layers: Total number of transformer layers
        batch_size: Batch size (typically 1 for inference)
        num_kv_heads: Number of key-value heads
        head_dim: Dimension of each head
        device: Device to store tensors on
        dtype: Data type for tensors
        layer_caches: Per-layer (key, value) tensor pair, or None if untouched
        layer_seq_lengths: Per-layer capacity (max cached position + 1)
        filled_positions: Per-layer sets of positions actually written
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int = 1,
        num_kv_heads: int = 8,
        head_dim: int = 128,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Per-layer storage: List of (key_cache, value_cache) or None
        self.layer_caches: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track sequence length per layer (capacity = max_position + 1)
        self.layer_seq_lengths: List[int] = [0] * num_layers

        # Track which positions are actually filled (for lazy fill detection)
        # This is a list of sets, one per layer
        self.filled_positions: List[set] = [set() for _ in range(num_layers)]

    def _zeros(self, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Allocate zero-initialized (key, value) buffers covering `length` positions."""
        shape = (self.batch_size, self.num_kv_heads, length, self.head_dim)
        return (
            torch.zeros(shape, device=self.device, dtype=self.dtype),
            torch.zeros(shape, device=self.device, dtype=self.dtype),
        )

    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache for a layer at specific positions.

        Args:
            layer_idx: Layer index to update
            key_states: [B, num_kv_heads, seq_len, head_dim] new key states
            value_states: [B, num_kv_heads, seq_len, head_dim] new value states
            cache_position: [seq_len] tensor of positions to update

        Returns:
            (full_keys, full_values) tuple with all cached data
        """
        positions = cache_position.tolist()
        # Use max() rather than the last element so an unsorted position
        # tensor cannot under-size the buffer (the last entry is not
        # necessarily the largest).
        new_len = int(cache_position.max().item()) + 1

        if self.layer_caches[layer_idx] is None:
            # First write for this layer.
            if positions == list(range(new_len)):
                # Fast path: positions are exactly [0, 1, ..., n-1] in order,
                # so the incoming tensors ARE the cache — just clone them.
                self.layer_caches[layer_idx] = (
                    key_states.clone(),
                    value_states.clone(),
                )
            else:
                # Sparse or out-of-order first write: allocate zeros covering
                # the full span and scatter the new states into place.
                # Unwritten rows stay zero until lazily filled.
                k_cache, v_cache = self._zeros(new_len)
                k_cache[:, :, cache_position.long(), :] = key_states
                v_cache[:, :, cache_position.long(), :] = value_states
                self.layer_caches[layer_idx] = (k_cache, v_cache)

            self.layer_seq_lengths[layer_idx] = new_len
        else:
            k_cache, v_cache = self.layer_caches[layer_idx]
            current_len = k_cache.shape[2]

            if new_len > current_len:
                # Grow the buffers with zero padding before scattering.
                k_ext, v_ext = self._zeros(new_len - current_len)
                k_cache = torch.cat([k_cache, k_ext], dim=2)
                v_cache = torch.cat([v_cache, v_ext], dim=2)

            # Update at cache_position
            k_cache[:, :, cache_position.long(), :] = key_states
            v_cache[:, :, cache_position.long(), :] = value_states

            self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = max(
                self.layer_seq_lengths[layer_idx], new_len
            )

        # Track filled positions so lazy-fill logic can find the gaps.
        self.filled_positions[layer_idx].update(positions)

        return self.layer_caches[layer_idx]

    def get_kv(self, layer_idx: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get cached KV for a layer, or None if not cached."""
        return self.layer_caches[layer_idx]

    def get_seq_length(self, layer_idx: int) -> int:
        """Get the sequence length (capacity) for a layer."""
        return self.layer_seq_lengths[layer_idx]

    def has_position(self, layer_idx: int, position: int) -> bool:
        """Check if a specific position is filled for a layer."""
        return position in self.filled_positions[layer_idx]

    def get_unfilled_positions(self, layer_idx: int, up_to: int) -> List[int]:
        """Get sorted positions that are not filled for a layer, up to `up_to` (exclusive)."""
        return sorted(set(range(up_to)) - self.filled_positions[layer_idx])

    def needs_fill(self, layer_idx: int, positions: List[int]) -> bool:
        """Check if any of the given positions need to be filled for a layer."""
        return not self.filled_positions[layer_idx].issuperset(positions)

    def get_missing_layers(self, position: int, target_layer: int) -> List[int]:
        """
        Get list of layers that need computation for a position.

        Args:
            position: The position we need KV for
            target_layer: The deepest layer we need to reach (inclusive)

        Returns:
            List of layer indices that need computation for this position
        """
        return [
            layer_idx
            for layer_idx in range(target_layer + 1)
            if position not in self.filled_positions[layer_idx]
        ]

    def truncate_from(self, position: int):
        """
        Drop cached data at `position` and beyond for every layer
        (keeps positions 0..position-1). Used for rollback on rejection.

        Args:
            position: First position to remove
        """
        for layer_idx in range(self.num_layers):
            if self.layer_caches[layer_idx] is not None:
                k, v = self.layer_caches[layer_idx]
                if k.shape[2] > position:
                    # .contiguous() so the slices own their memory and the
                    # dropped tail can be freed.
                    self.layer_caches[layer_idx] = (
                        k[:, :, :position, :].contiguous(),
                        v[:, :, :position, :].contiguous(),
                    )
                self.layer_seq_lengths[layer_idx] = min(
                    self.layer_seq_lengths[layer_idx], position
                )

            # Remove filled positions >= position
            self.filled_positions[layer_idx] = {
                p for p in self.filled_positions[layer_idx] if p < position
            }

    def clone(self) -> "JaggedKVCache":
        """
        Create a deep copy of the cache for speculative drafting.

        Returns:
            Independent copy that can be modified without affecting original
        """
        new_cache = JaggedKVCache(
            num_layers=self.num_layers,
            batch_size=self.batch_size,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            device=self.device,
            dtype=self.dtype,
        )
        for i, kv in enumerate(self.layer_caches):
            if kv is not None:
                new_cache.layer_caches[i] = (kv[0].clone(), kv[1].clone())
        new_cache.layer_seq_lengths = self.layer_seq_lengths.copy()
        new_cache.filled_positions = [s.copy() for s in self.filled_positions]
        return new_cache

    def reset(self):
        """Reset the cache to empty state."""
        self.layer_caches = [None for _ in range(self.num_layers)]
        self.layer_seq_lengths = [0] * self.num_layers
        self.filled_positions = [set() for _ in range(self.num_layers)]

    def __repr__(self) -> str:
        lines = [f"JaggedKVCache(num_layers={self.num_layers}, device={self.device})"]
        for i in range(min(self.num_layers, 10)):  # Show first 10 layers
            seq_len = self.layer_seq_lengths[i]
            filled = len(self.filled_positions[i])
            if seq_len > 0:
                lines.append(f"  Layer {i:2d}: {filled}/{seq_len} filled")
        if self.num_layers > 10:
            lines.append(f"  ... ({self.num_layers - 10} more layers)")
        return "\n".join(lines)
src/model_adapters.py CHANGED
@@ -127,11 +127,7 @@ class LlamaStyleAdapter(ModelAdapter):
127
  ) -> Optional[Tuple[Tensor, Tensor]]:
128
  if self._rotary is not None:
129
  cos, sin = self._rotary(hidden_states, position_ids)
130
- # Unsqueeze to (batch, 1, seq_len, head_dim) to support broadcasting
131
- # This matches LlamaModel behavior which prepares embeddings for layers
132
- if cos.dim() == 3:
133
- cos = cos.unsqueeze(1)
134
- sin = sin.unsqueeze(1)
135
  return (cos, sin)
136
  return None
137
 
 
127
  ) -> Optional[Tuple[Tensor, Tensor]]:
128
  if self._rotary is not None:
129
  cos, sin = self._rotary(hidden_states, position_ids)
130
+ # Return as-is - the model's apply_rotary_pos_emb handles unsqueezing
 
 
 
 
131
  return (cos, sin)
132
  return None
133
 
src/model_config.py CHANGED
@@ -2,7 +2,7 @@
2
  # Re-exported from the main package for demo use
3
 
4
  import json
5
- from dataclasses import dataclass, field, asdict
6
  from typing import Dict, List, Optional
7
 
8
 
@@ -34,10 +34,6 @@ class ModelConfig:
34
  training_config=data.get("training_config"),
35
  )
36
 
37
- def to_json(self, path: str) -> None:
38
- with open(path, "w") as f:
39
- json.dump(asdict(self), f, indent=2)
40
-
41
 
42
  @dataclass
43
  class CalibrationResult:
@@ -57,15 +53,6 @@ class CalibrationResult:
57
  data = json.load(f)
58
  return cls(**data)
59
 
60
- def to_json(self, path: str) -> None:
61
- with open(path, "w") as f:
62
- json.dump(asdict(self), f, indent=2)
63
-
64
- def get_threshold(self, accuracy_level: float, head_idx: int) -> float:
65
- level_key = f"{accuracy_level:.2f}"
66
- head_key = str(head_idx)
67
- return self.thresholds[level_key][head_key]
68
-
69
  def get_thresholds_for_level(self, accuracy_level: float) -> Dict[int, float]:
70
  """Get all thresholds for a given accuracy level."""
71
  level_key = f"{accuracy_level:.2f}"
 
2
  # Re-exported from the main package for demo use
3
 
4
  import json
5
+ from dataclasses import dataclass, field
6
  from typing import Dict, List, Optional
7
 
8
 
 
34
  training_config=data.get("training_config"),
35
  )
36
 
 
 
 
 
37
 
38
  @dataclass
39
  class CalibrationResult:
 
53
  data = json.load(f)
54
  return cls(**data)
55
 
 
 
 
 
 
 
 
 
 
56
  def get_thresholds_for_level(self, accuracy_level: float) -> Dict[int, float]:
57
  """Get all thresholds for a given accuracy level."""
58
  level_key = f"{accuracy_level:.2f}"
tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Tests package for DSSD demo
tests/benchmark_kv_cache.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark tests for KV Cache optimization in DSSD.
3
+
4
+ This module provides deterministic benchmarks to measure:
5
+ 1. Layer forward counts (direct measure of computation)
6
+ 2. Wall clock time for draft + verify phases
7
+ 3. Optional FLOPs estimation
8
+
9
+ Run with: python -m tests.benchmark_kv_cache
10
+ """
11
+
12
+ import time
13
+ import torch
14
+ import torch.nn as nn
15
+ from dataclasses import dataclass, field
16
+ from typing import Dict, List, Optional, Tuple
17
+ from contextlib import contextmanager
18
+
19
+
20
+ # =============================================================================
21
+ # Instrumentation
22
+ # =============================================================================
23
+
24
+
25
@dataclass
class BenchmarkMetrics:
    """Tracks metrics during benchmark run."""

    # Layer forward counts
    layer_forward_counts: Dict[int, int] = field(default_factory=dict)
    total_layer_forwards: int = 0

    # Timing
    draft_time_ms: float = 0.0
    verify_time_ms: float = 0.0
    total_time_ms: float = 0.0

    # Token counts
    tokens_drafted: int = 0
    tokens_accepted: int = 0
    tokens_rejected: int = 0

    # Early exit distribution
    exit_layers: List[int] = field(default_factory=list)

    def reset(self):
        """Reset all metrics."""
        # Clear in place so external references to the containers stay valid.
        self.layer_forward_counts.clear()
        self.exit_layers.clear()
        self.total_layer_forwards = 0
        self.tokens_drafted = 0
        self.tokens_accepted = 0
        self.tokens_rejected = 0
        self.draft_time_ms = 0.0
        self.verify_time_ms = 0.0
        self.total_time_ms = 0.0

    def record_layer_forward(self, layer_idx: int):
        """Record a layer forward pass."""
        counts = self.layer_forward_counts
        counts[layer_idx] = counts.get(layer_idx, 0) + 1
        self.total_layer_forwards += 1

    def summary(self) -> str:
        """Return human-readable summary."""
        sep = "=" * 50
        lines = [
            sep,
            "BENCHMARK METRICS",
            sep,
            f"Total Layer Forwards: {self.total_layer_forwards}",
            f"Tokens Drafted: {self.tokens_drafted}",
            f"Tokens Accepted: {self.tokens_accepted}",
            f"Tokens Rejected: {self.tokens_rejected}",
            f"Draft Time: {self.draft_time_ms:.2f} ms",
            f"Verify Time: {self.verify_time_ms:.2f} ms",
            f"Total Time: {self.total_time_ms:.2f} ms",
            "",
            "Layer Forward Distribution:",
        ]
        lines.extend(
            f"  Layer {layer_idx:2d}: {self.layer_forward_counts[layer_idx]} forwards"
            for layer_idx in sorted(self.layer_forward_counts)
        )

        if self.exit_layers:
            avg_exit = sum(self.exit_layers) / len(self.exit_layers)
            lines.append(f"\nAverage Exit Layer: {avg_exit:.1f}")

        lines.append(sep)
        return "\n".join(lines)
91
+
92
+
93
+ # Global metrics instance for instrumentation
94
+ _metrics: Optional[BenchmarkMetrics] = None
95
+
96
+
97
+ def get_metrics() -> Optional[BenchmarkMetrics]:
98
+ """Get the current metrics instance."""
99
+ return _metrics
100
+
101
+
102
@contextmanager
def benchmark_context():
    """Enable metric collection for the duration of the with-block.

    Installs a fresh BenchmarkMetrics as the module-level collector and
    yields it; the collector is always uninstalled on exit, even on error.
    """
    global _metrics
    collector = BenchmarkMetrics()
    _metrics = collector
    try:
        yield collector
    finally:
        # Guarantee instrumentation is disabled once the block exits.
        _metrics = None
111
+
112
+
113
def instrument_layer_forward(layer_idx: int):
    """Call this from forward_layer to record layer execution.

    No-op unless a benchmark_context() is currently active.
    """
    metrics = _metrics
    if metrics is None:
        return
    metrics.record_layer_forward(layer_idx)
117
+
118
+
119
+ # =============================================================================
120
+ # Timer Utilities
121
+ # =============================================================================
122
+
123
+
124
class Timer:
    """Simple wall-clock timer for benchmarking, reporting milliseconds.

    Synchronizes CUDA (when available) around each measurement so that
    asynchronous GPU kernels are included in the timed interval.
    """

    def __init__(self):
        self.start_time = None  # perf_counter value at start(), or None
        self.elapsed_ms = 0.0   # last measured interval in milliseconds

    @staticmethod
    def _sync():
        """Block until pending CUDA work finishes (no-op on CPU-only hosts)."""
        # Proper statement form instead of the original expression-statement
        # conditional (`f() if cond else None`), which is an idiom smell.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self):
        """Begin timing."""
        self._sync()
        self.start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop timing and return the elapsed milliseconds.

        Returns:
            Milliseconds since the last start(). If start() was never
            called, returns the previous elapsed_ms (0.0 initially).
        """
        self._sync()
        if self.start_time is not None:
            self.elapsed_ms = (time.perf_counter() - self.start_time) * 1000
        return self.elapsed_ms
140
+
141
+
142
+ # =============================================================================
143
+ # Benchmark Test Scenarios
144
+ # =============================================================================
145
+
146
+
147
@dataclass
class BenchmarkConfig:
    """Configuration for benchmark runs."""

    # Model setting
    # Hugging Face model id of the base model to benchmark.
    model_name: str = "Qwen/Qwen3-0.6B"

    # Generation settings
    prompt: str = "Explain what machine learning is in simple terms."
    # Maximum number of tokens drafted per draft/verify cycle.
    max_draft_length: int = 5
    num_iterations: int = 3  # Multiple iterations for averaging

    # Thresholds for early exit (simulated or real)
    accuracy_level: float = 0.75

    # Reproducibility
    # Seed for torch RNGs (CPU and, when available, CUDA).
    seed: int = 42
165
+
166
def run_single_draft_verify_benchmark(
    decoder,  # DSSDecoder
    config: BenchmarkConfig,
    use_cache: bool = False,
) -> BenchmarkMetrics:
    """
    Run a single draft + verify cycle and measure metrics.

    Drafts up to config.max_draft_length tokens via the decoder's early-exit
    path, then verifies them with one full-model forward pass, recording
    per-phase wall-clock time and acceptance counts.

    Args:
        decoder: The DSSDecoder instance
        config: Benchmark configuration
        use_cache: Whether to use JaggedKVCache (for comparison).
            NOTE(review): currently unused in this function body.

    Returns:
        BenchmarkMetrics with recorded data
    """
    # Set seed for reproducibility
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    with benchmark_context() as metrics:
        timer = Timer()

        # Tokenize prompt
        input_ids = decoder.tokenizer.encode(config.prompt, return_tensors="pt").to(
            decoder.device
        )

        # Get thresholds (empty dict when no calibration is loaded)
        thresholds = {}
        if decoder.calibration:
            thresholds = decoder.calibration.get_thresholds_for_level(
                config.accuracy_level
            )

        # ========== DRAFT PHASE ==========
        timer.start()
        drafted_tokens = []
        draft_ids = input_ids.clone()

        for _ in range(config.max_draft_length):
            # Call the drafting function
            # Note: This will need to be modified to use our instrumented version
            # _draft_single_token returns (token_id, exit_head, exit_layer,
            # uncertainty) or None when no head is confident enough.
            draft_result = decoder._draft_single_token(draft_ids, thresholds)

            if draft_result is None:
                break

            token_id, exit_head, exit_layer, uncertainty = draft_result
            drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
            metrics.exit_layers.append(exit_layer)

            if token_id == decoder.tokenizer.eos_token_id:
                break

            # Append the drafted token so the next draft conditions on it.
            draft_ids = torch.cat(
                [draft_ids, torch.tensor([[token_id]], device=decoder.device)], dim=1
            )

        metrics.draft_time_ms = timer.stop()
        metrics.tokens_drafted = len(drafted_tokens)

        # ========== VERIFY PHASE ==========
        timer.start()

        if drafted_tokens:
            with torch.no_grad():
                outputs = decoder.model(draft_ids, use_cache=False)
                verify_logits = outputs.logits

            # Verify each token: the logit at position (prompt_len - 1 + i)
            # predicts drafted token i.
            start_pos = input_ids.shape[1] - 1
            accepted = 0

            for i, (token_id, exit_head, exit_layer, uncertainty) in enumerate(
                drafted_tokens
            ):
                verify_pos = start_pos + i
                verified_token = torch.argmax(verify_logits[0, verify_pos, :]).item()

                if token_id == verified_token:
                    accepted += 1
                else:
                    # First mismatch invalidates all subsequent drafts.
                    break

            metrics.tokens_accepted = accepted
            metrics.tokens_rejected = len(drafted_tokens) - accepted

        metrics.verify_time_ms = timer.stop()
        metrics.total_time_ms = metrics.draft_time_ms + metrics.verify_time_ms

    # The metrics object remains usable after the context exits; only the
    # module-level instrumentation hook is cleared.
    return metrics
259
+
260
+
261
def run_baseline_benchmark(decoder, config: BenchmarkConfig) -> BenchmarkMetrics:
    """
    Run baseline benchmark (current implementation without cache optimization).

    Runs config.num_iterations draft/verify cycles, prints per-iteration
    results, and returns metrics averaged across iterations.

    Args:
        decoder: The DSSDecoder instance
        config: Benchmark configuration

    Returns:
        BenchmarkMetrics averaged over all iterations (token counts are
        taken from the first iteration, which is deterministic by seed).
    """
    print(f"\n{'=' * 60}")
    print("BASELINE BENCHMARK (No Cache)")
    print(f"{'=' * 60}")
    print(f"Model: {config.model_name}")
    print(f"Prompt: '{config.prompt[:50]}...'")
    print(f"Max Draft Length: {config.max_draft_length}")
    print(f"Iterations: {config.num_iterations}")

    all_metrics = []

    for i in range(config.num_iterations):
        print(f"\nIteration {i + 1}/{config.num_iterations}...")
        metrics = run_single_draft_verify_benchmark(decoder, config, use_cache=False)
        all_metrics.append(metrics)
        print(f"  Layer Forwards: {metrics.total_layer_forwards}")
        print(f"  Draft Time: {metrics.draft_time_ms:.2f} ms")
        print(f"  Verify Time: {metrics.verify_time_ms:.2f} ms")

    # Average metrics (integer division for counts, float for timings)
    avg_metrics = BenchmarkMetrics()
    avg_metrics.total_layer_forwards = sum(
        m.total_layer_forwards for m in all_metrics
    ) // len(all_metrics)
    avg_metrics.draft_time_ms = sum(m.draft_time_ms for m in all_metrics) / len(
        all_metrics
    )
    avg_metrics.verify_time_ms = sum(m.verify_time_ms for m in all_metrics) / len(
        all_metrics
    )
    avg_metrics.total_time_ms = sum(m.total_time_ms for m in all_metrics) / len(
        all_metrics
    )
    # Token counts come from the first iteration only (seeded, so identical
    # across iterations in practice).
    avg_metrics.tokens_drafted = all_metrics[0].tokens_drafted
    avg_metrics.tokens_accepted = all_metrics[0].tokens_accepted
    avg_metrics.tokens_rejected = all_metrics[0].tokens_rejected

    # Combine layer counts (per-iteration counts divided, then summed)
    for m in all_metrics:
        for layer_idx, count in m.layer_forward_counts.items():
            avg_metrics.layer_forward_counts[layer_idx] = (
                avg_metrics.layer_forward_counts.get(layer_idx, 0)
                + count // len(all_metrics)
            )

    print("\n" + avg_metrics.summary())
    return avg_metrics
311
+
312
+
313
+ # =============================================================================
314
+ # Main Entry Point
315
+ # =============================================================================
316
+
317
+
318
def main():
    """Run benchmark suite.

    Loads the DSSD model from hard-coded checkpoint paths and runs the
    baseline (no-cache) benchmark; prints guidance and exits early if the
    checkpoints are unavailable.
    """
    import sys

    # Hard-coded dev path so `src.inference` resolves when run from elsewhere.
    sys.path.insert(0, "/home/fvalade/workspace/DSSD_demo")

    from src.inference import load_dssd_model

    config = BenchmarkConfig()

    print("Loading model...")
    try:
        # You'll need to update these paths to match your setup
        decoder, tokenizer = load_dssd_model(
            model_name=config.model_name,
            heads_path="../checkpoints/qwen3-0.6b/aux_heads.pt",
            config_path="../checkpoints/qwen3-0.6b/config.json",
            calibration_path="../checkpoints/qwen3-0.6b/calibration.json",
            device="auto",
        )
        print("Model loaded successfully!")
    except Exception as e:
        # Broad catch is deliberate here: any load failure (missing files,
        # bad checkpoint) should print guidance instead of a traceback.
        print(f"Error loading model: {e}")
        print("\nTo run this benchmark, ensure you have:")
        print("  1. A trained auxiliary heads checkpoint")
        print("  2. The corresponding config.json")
        print("  3. (Optional) calibration.json for thresholds")
        return

    # Run baseline benchmark
    baseline_metrics = run_baseline_benchmark(decoder, config)

    # Save results for later comparison
    print("\n" + "=" * 60)
    print("BASELINE RESULTS SAVED")
    print("Run this again after implementing JaggedKVCache to compare.")
    print("=" * 60)
355
+
356
+
357
+ if __name__ == "__main__":
358
+ main()
tests/run_benchmark.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Benchmark comparison: Standard generation vs Cache-optimized generation.
4
+
5
+ This script measures and compares:
6
+ - Layer forward counts
7
+ - Wall clock time
8
+ - Tokens per second
9
+
10
+ Usage:
11
+ python tests/run_benchmark.py --model Qwen/Qwen3-0.6B --heads-path /path/to/heads.pt
12
+ """
13
+
14
+ import argparse
15
+ import time
16
+ import sys
17
+ import os
18
+
19
+ # Add project to path
20
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
+
22
+ import torch
23
+
24
+
25
def make_dummy_decoder():
    """Create a minimal decoder for benchmarking without GPU.

    Despite the name, this does not build a decoder: it exercises raw
    JaggedKVCache operations (prefill, draft with early exit, lazy-fill
    verification) on CPU with random tensors, then prints a cost analysis.

    Returns:
        True (always; used as a completion flag by the caller).
    """
    from src.jagged_cache import JaggedKVCache

    print("\n" + "=" * 60)
    print("BENCHMARK: JaggedKVCache Operations (No GPU Required)")
    print("=" * 60)

    # Test cache performance with Qwen3-0.6B-like dimensions.
    num_layers = 28
    batch_size = 1
    num_heads = 8
    head_dim = 128
    seq_len = 100

    cache = JaggedKVCache(
        num_layers=num_layers,
        batch_size=batch_size,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        device="cpu",
        dtype=torch.float32,
    )

    # Simulate prefill: every layer filled for every prompt position.
    print(f"\nSimulating prefill ({seq_len} tokens, {num_layers} layers)...")
    start = time.perf_counter()
    for pos in range(seq_len):
        for layer_idx in range(num_layers):
            k = torch.randn(batch_size, num_heads, 1, head_dim)
            v = torch.randn(batch_size, num_heads, 1, head_dim)
            cache.update(layer_idx, k, v, torch.tensor([pos]))
    prefill_time = (time.perf_counter() - start) * 1000
    print(f"  Prefill time: {prefill_time:.2f} ms")

    # Simulate draft phase (early exit at different layers)
    print("\nSimulating draft phase (5 tokens, variable exit layers)...")
    exit_layers = [4, 8, 6, 12, 10]  # Simulate different exit layers
    # Clone so rejection could roll back without touching the prefill cache.
    draft_cache = cache.clone()

    start = time.perf_counter()
    for i, exit_layer in enumerate(exit_layers):
        pos = seq_len + i
        for layer_idx in range(exit_layer + 1):
            k = torch.randn(batch_size, num_heads, 1, head_dim)
            v = torch.randn(batch_size, num_heads, 1, head_dim)
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    draft_time = (time.perf_counter() - start) * 1000
    print(f"  Draft time: {draft_time:.2f} ms")

    # Print cache state (sampled layers show the jagged fill pattern)
    print("\nCache state after drafting:")
    for layer_idx in [0, 4, 8, 12, 16, 20, 24, 27]:
        filled = len(draft_cache.filled_positions[layer_idx])
        print(f"  Layer {layer_idx:2d}: {filled} positions filled")

    # Simulate verification (fill all layers for all positions)
    print("\nSimulating verification (lazy fill + full model)...")
    start = time.perf_counter()
    for pos in range(seq_len, seq_len + 5):
        # Find missing layers
        missing = draft_cache.get_missing_layers(pos, num_layers - 1)
        for layer_idx in missing:
            k = torch.randn(batch_size, num_heads, 1, head_dim)
            v = torch.randn(batch_size, num_heads, 1, head_dim)
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    verify_time = (time.perf_counter() - start) * 1000
    print(f"  Verify time: {verify_time:.2f} ms")

    # Calculate and explain savings
    print("\n" + "=" * 60)
    print("ANALYSIS: Layer Operations")
    print("=" * 60)

    # Prefill ops (same for all approaches - one-time cost)
    prefill_ops = seq_len * num_layers
    print(f"\nPREFILL (one-time): {prefill_ops} layer ops")

    # Draft phase with early exit
    draft_ops = sum(exit_layer + 1 for exit_layer in exit_layers)
    draft_ops_full = 5 * num_layers  # Without early exit
    print(f"\nDRAFT PHASE (5 tokens):")
    print(f"  With early exit: {draft_ops} ops (avg {draft_ops / 5:.1f} layers/token)")
    print(f"  Without early exit: {draft_ops_full} ops ({num_layers} layers/token)")
    print(
        f"  Draft savings: {draft_ops_full - draft_ops} ops ({100 * (1 - draft_ops / draft_ops_full):.0f}% reduction)"
    )

    # The KEY benefit: with cache, each draft token is O(1 token * exit_layer)
    # Without cache, it would be O(seq_len * exit_layer) per token
    print(f"\nCACHE BENEFIT:")
    print(f"  Without cache, each draft would recompute {seq_len}-token context")
    print(f"  With cache, each draft processes only 1 new token")
    per_token_savings = seq_len - 1  # Positions we don't recompute
    total_context_savings = per_token_savings * draft_ops
    print(f"  Context reuse savings: ~{total_context_savings} avoided operations")

    # Verify phase
    verify_ops = 5 * num_layers
    print(f"\nVERIFY PHASE: {verify_ops} ops (fills all layers for drafted tokens)")

    print(f"\nTotal time: {prefill_time + draft_time + verify_time:.2f} ms")

    return True
129
+
130
+
131
def run_full_benchmark(model_name, heads_path, config_path, calibration_path=None):
    """Run full benchmark with actual model.

    Compares standard early-exit generation against the cache-optimized
    fast path on one prompt and prints timing / acceptance statistics.

    Args:
        model_name: HF model id to load.
        heads_path: Path to the auxiliary heads checkpoint.
        config_path: Path to the model config JSON.
        calibration_path: Optional path to the calibration thresholds file.

    Returns:
        True on success, False if the model failed to load.
    """
    from src.inference import load_dssd_model

    print("\n" + "=" * 60)
    print(f"BENCHMARK: Full Model Comparison")
    print(f"Model: {model_name}")
    print("=" * 60)

    try:
        decoder, tokenizer = load_dssd_model(
            model_name=model_name,
            heads_path=heads_path,
            config_path=config_path,
            calibration_path=calibration_path,
            device="auto",
        )
    except Exception as e:
        # Any load failure is reported rather than raised so the CLI can
        # exit cleanly.
        print(f"Error loading model: {e}")
        return False

    prompt = "Explain what machine learning is in three sentences."
    max_tokens = 50

    # Warmup (excluded from timings; primes kernels/caches)
    print("\nWarming up...")
    _ = decoder.generate(
        prompt, max_tokens=10, use_early_exit=False, use_chat_template=True
    )

    # Benchmark standard generation
    print("\nRunning standard generation (no cache)...")
    start = time.perf_counter()
    result_standard = decoder.generate(
        prompt,
        max_tokens=max_tokens,
        use_early_exit=True,
        accuracy_level=0.75,
        use_chat_template=True,
    )
    time_standard = time.perf_counter() - start

    # Benchmark cache-optimized generation (fast version)
    print("Running cache-optimized generation (fast)...")
    start = time.perf_counter()
    result_cached = decoder.generate_fast(
        prompt,
        max_tokens=max_tokens,
        accuracy_level=0.75,
        use_chat_template=True,
    )
    time_cached = time.perf_counter() - start

    # Print results
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    print("\nStandard Generation:")
    print(f"  Tokens generated: {len(result_standard.tokens)}")
    print(f"  Time: {time_standard:.2f}s")
    print(f"  Tokens/sec: {len(result_standard.tokens) / time_standard:.2f}")
    print(f"  Avg exit layer: {result_standard.avg_exit_layer:.1f}")

    print("\nCache-Optimized Generation:")
    print(f"  Tokens generated: {len(result_cached.tokens)}")
    print(f"  Time: {time_cached:.2f}s")
    print(f"  Tokens/sec: {len(result_cached.tokens) / time_cached:.2f}")
    print(f"  Avg exit layer: {result_cached.avg_exit_layer:.1f}")
    # Draft/accept stats are only present when the fast path recorded them.
    if "total_drafted" in result_cached.exit_distribution:
        print(f"  Drafted: {result_cached.exit_distribution['total_drafted']}")
        print(f"  Accepted: {result_cached.exit_distribution['total_accepted']}")
        print(
            f"  Acceptance rate: {result_cached.exit_distribution['acceptance_rate']:.1%}"
        )

    print("\nSpeedup:")
    # Guard against division by zero on pathologically fast runs.
    speedup = time_standard / time_cached if time_cached > 0 else 0
    print(f"  {speedup:.2f}x faster with cache")

    return True
212
+
213
+
214
def main():
    """CLI entry point: parse arguments and dispatch to the right benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark DSSD generation")
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B", help="Model name")
    parser.add_argument("--heads-path", help="Path to aux heads checkpoint")
    parser.add_argument("--config-path", help="Path to model config")
    parser.add_argument("--calibration-path", help="Path to calibration file")
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run CPU-only cache benchmark"
    )
    args = parser.parse_args()

    # Without a heads checkpoint we can only exercise the cache itself,
    # so fall back to the CPU-only cache-operations benchmark.
    if args.cpu_only or not args.heads_path:
        make_dummy_decoder()
        return

    # Full benchmark against the real model.
    run_full_benchmark(
        args.model,
        args.heads_path,
        args.config_path,
        args.calibration_path,
    )
236
+
237
+
238
+ if __name__ == "__main__":
239
+ main()
tests/test_cache_integration.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration tests for JaggedKVCache with inference pipeline.
3
+
4
+ Run with: pytest tests/test_cache_integration.py -v
5
+ """
6
+
7
+ import pytest
8
+ import torch
9
+ from typing import List, Optional
10
+
11
# Import from production module
import sys
from pathlib import Path

# Derive the repo root relative to this test file so the suite is portable.
# (The previous hard-coded absolute path only worked on the author's machine.)
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.jagged_cache import JaggedKVCache
17
+
18
+
19
class TestJaggedKVCacheProduction:
    """Test the production JaggedKVCache implementation.

    Boolean results are asserted with plain truthiness rather than
    `== True` / `== False` comparisons (PEP 8 / flake8 E712).
    """

    @pytest.fixture
    def cache(self):
        """Create a test cache."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Create sample KV tensors."""

        def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            k = torch.randn(batch_size, num_heads, seq_len, head_dim)
            v = torch.randn(batch_size, num_heads, seq_len, head_dim)
            return k, v

        return _make_kv

    def test_filled_positions_tracking(self, cache, sample_kv):
        """Test that filled_positions correctly tracks which positions are filled."""
        # Update layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        assert cache.has_position(0, 0)
        assert not cache.has_position(0, 1)
        assert not cache.has_position(1, 0)  # Layer 1 not touched

    def test_needs_fill(self, cache, sample_kv):
        """Test needs_fill correctly identifies missing positions."""
        # Fill layer 0 with position 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        # Layer 0 has position 0, doesn't need fill
        assert not cache.needs_fill(0, [0])

        # Layer 0 doesn't have position 1
        assert cache.needs_fill(0, [1])

        # Layer 1 has nothing
        assert cache.needs_fill(1, [0])

    def test_get_unfilled_positions(self, cache, sample_kv):
        """Test getting unfilled positions."""
        # Fill positions 0 and 2 for layer 0
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([2]))

        # Unfilled up to position 4 should be [1, 3]
        unfilled = cache.get_unfilled_positions(0, 4)
        assert unfilled == [1, 3]

    def test_truncate_clears_filled_positions(self, cache, sample_kv):
        """Test that truncation also clears filled_positions."""
        # Fill positions 0-4
        for pos in range(5):
            k, v = sample_kv()
            cache.update(0, k, v, torch.tensor([pos]))

        assert cache.has_position(0, 4)

        # Truncate at position 3
        cache.truncate_from(3)

        # Positions 3 and 4 should be gone
        assert cache.has_position(0, 2)
        assert not cache.has_position(0, 3)
        assert not cache.has_position(0, 4)

    def test_clone_copies_filled_positions(self, cache, sample_kv):
        """Test that clone also copies filled_positions."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cloned = cache.clone()

        assert cloned.has_position(0, 0)

        # Modify original; the clone must be unaffected
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([1]))

        assert cache.has_position(0, 1)
        assert not cloned.has_position(0, 1)

    def test_reset(self, cache, sample_kv):
        """Test that reset clears everything."""
        k, v = sample_kv()
        cache.update(0, k, v, torch.tensor([0]))

        cache.reset()

        assert cache.get_kv(0) is None
        assert cache.get_seq_length(0) == 0
        assert not cache.has_position(0, 0)
126
+
127
+
128
class TestLazyFillScenario:
    """Test realistic lazy fill scenarios."""

    @pytest.fixture
    def cache(self):
        """An 8-layer cache small enough for fast CPU tests."""
        return JaggedKVCache(
            num_layers=8,
            batch_size=1,
            num_kv_heads=4,
            head_dim=64,
            device="cpu",
            dtype=torch.float32,
        )

    @pytest.fixture
    def sample_kv(self):
        """Factory fixture producing random (key, value) tensor pairs."""

        def _build(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
            shape = (batch_size, num_heads, seq_len, head_dim)
            return torch.randn(shape), torch.randn(shape)

        return _build

    def test_lazy_fill_scenario(self, cache, sample_kv):
        """
        Simulate:
        - Prefill prompt (positions 0-4) through all layers
        - Draft token 5 exiting at layer 2
        - Draft token 6 exiting at layer 6 (needs lazy fill)
        """
        # Prefill: positions 0-4 through all 8 layers.
        for position in range(5):
            for layer in range(8):
                keys, values = sample_kv()
                cache.update(layer, keys, values, torch.tensor([position]))

        # Every layer now holds all five prompt positions.
        for layer in range(8):
            assert cache.get_seq_length(layer) == 5
            assert all(cache.has_position(layer, p) for p in range(5))

        # Draft token 5 exits early at layer 2, so only layers 0-2 are updated.
        for layer in range(3):
            keys, values = sample_kv()
            cache.update(layer, keys, values, torch.tensor([5]))

        # Position 5 exists only in layers 0-2.
        assert cache.has_position(0, 5)
        assert cache.has_position(2, 5)
        assert not cache.has_position(3, 5)

        # Token 6 wants to exit at layer 6 — layers 3-6 still miss position 5.
        missing = cache.get_missing_layers(5, 6)
        assert 3 in missing
        assert 6 in missing
        assert 0 not in missing  # layer 0 already has position 5

        # Layer 6 must report position 5 as unfilled up to position 6.
        assert 5 in cache.get_unfilled_positions(6, 6)
193
+
194
# Allow running this test module directly (outside a `pytest` CLI invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/test_cache_operations.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step-by-step verification tests for KV Cache operations.
3
+
4
+ These tests verify the correctness of the JaggedKVCache implementation
5
+ without requiring a full model. Run with: pytest tests/test_cache_operations.py -v
6
+ """
7
+
8
+ import pytest
9
+ import torch
10
+ from typing import List, Tuple, Optional
11
+
12
+
13
+ # =============================================================================
14
+ # Mock Cache Implementation (to be replaced with real JaggedKVCache)
15
+ # =============================================================================
16
+
17
+
18
class JaggedKVCache:
    """
    Jagged KV Cache that tracks per-layer sequence lengths.

    This is a reference implementation for testing. The production version
    will be in src/jagged_cache.py.
    """

    def __init__(
        self,
        num_layers: int,
        batch_size: int = 1,
        num_kv_heads: int = 8,
        head_dim: int = 128,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Per-layer storage: List of (key_cache, value_cache) or None
        self.layer_caches: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track sequence length (capacity = max position + 1) per layer
        self.layer_seq_lengths: List[int] = [0] * num_layers

    def _alloc(self, seq_len: int) -> torch.Tensor:
        """Allocate one zero-filled cache tensor with room for `seq_len` positions."""
        return torch.zeros(
            (self.batch_size, self.num_kv_heads, seq_len, self.head_dim),
            device=self.device,
            dtype=self.dtype,
        )

    def update(
        self,
        layer_idx: int,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update cache for a layer at specific positions.

        Args:
            layer_idx: Layer index to update
            key_states: [B, num_kv_heads, seq_len, head_dim]
            value_states: [B, num_kv_heads, seq_len, head_dim]
            cache_position: [seq_len] positions to update (need not be sorted)

        Returns:
            (full_keys, full_values) including cached + new
        """
        # Capacity must cover the largest position written. Using max() rather
        # than the last element keeps this correct even when cache_position is
        # not sorted (the previous version mis-sized the cache for e.g. [1, 0]).
        new_len = int(cache_position.max().item()) + 1
        input_seq_len = key_states.shape[2]

        if self.layer_caches[layer_idx] is None:
            # Fast path only when positions are exactly [0, 1, ..., n-1] in
            # order — then input order equals position order and we can clone.
            expected = torch.arange(new_len, device=cache_position.device)
            if input_seq_len == new_len and torch.equal(
                cache_position.long(), expected
            ):
                self.layer_caches[layer_idx] = (
                    key_states.clone(),
                    value_states.clone(),
                )
            else:
                # Sparse/unsorted positions: allocate full capacity, scatter in.
                k_cache = self._alloc(new_len)
                v_cache = self._alloc(new_len)
                k_cache[:, :, cache_position.long(), :] = key_states
                v_cache[:, :, cache_position.long(), :] = value_states
                self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = new_len
        else:
            k_cache, v_cache = self.layer_caches[layer_idx]
            current_len = k_cache.shape[2]

            if new_len > current_len:
                # Grow both tensors with zero-padding up to the new capacity.
                extension_size = new_len - current_len
                k_cache = torch.cat([k_cache, self._alloc(extension_size)], dim=2)
                v_cache = torch.cat([v_cache, self._alloc(extension_size)], dim=2)

            # Scatter at cache_position (handles both extension and gap-filling)
            k_cache[:, :, cache_position.long(), :] = key_states
            v_cache[:, :, cache_position.long(), :] = value_states

            self.layer_caches[layer_idx] = (k_cache, v_cache)
            self.layer_seq_lengths[layer_idx] = max(
                self.layer_seq_lengths[layer_idx], new_len
            )

        return self.layer_caches[layer_idx]

    def get_kv(self, layer_idx: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Get cached KV for a layer, or None if not cached."""
        return self.layer_caches[layer_idx]

    def get_seq_length(self, layer_idx: int) -> int:
        """Get the sequence length cached for a layer."""
        return self.layer_seq_lengths[layer_idx]

    def truncate_from(self, position: int):
        """
        Truncate all layer caches from position onwards.
        Used for rollback on rejection.
        """
        for layer_idx in range(self.num_layers):
            if self.layer_caches[layer_idx] is not None:
                k, v = self.layer_caches[layer_idx]
                if k.shape[2] > position:
                    self.layer_caches[layer_idx] = (
                        k[:, :, :position, :],
                        v[:, :, :position, :],
                    )
                self.layer_seq_lengths[layer_idx] = min(
                    self.layer_seq_lengths[layer_idx], position
                )

    def clone(self) -> "JaggedKVCache":
        """Create a deep copy of the cache for speculation."""
        new_cache = JaggedKVCache(
            num_layers=self.num_layers,
            batch_size=self.batch_size,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            device=self.device,
            dtype=self.dtype,
        )
        for i, kv in enumerate(self.layer_caches):
            if kv is not None:
                new_cache.layer_caches[i] = (kv[0].clone(), kv[1].clone())
        new_cache.layer_seq_lengths = self.layer_seq_lengths.copy()
        return new_cache

    def get_missing_layers(self, position: int, target_layer: int) -> List[int]:
        """
        Get list of layers that need computation for this position.

        Args:
            position: The position we need KV for
            target_layer: The deepest layer we need to reach

        Returns:
            List of layer indices that need to be computed
        """
        missing = []
        for layer_idx in range(target_layer + 1):
            # A layer whose cached capacity does not extend past `position`
            # has not yet stored KV for it.
            if self.layer_seq_lengths[layer_idx] <= position:
                missing.append(layer_idx)
        return missing

    def __repr__(self):
        lines = [f"JaggedKVCache(num_layers={self.num_layers})"]
        for i in range(self.num_layers):
            seq_len = self.layer_seq_lengths[i]
            lines.append(f"  Layer {i:2d}: {seq_len} positions cached")
        return "\n".join(lines)
191
+
192
+
193
+ # =============================================================================
194
+ # Test Fixtures
195
+ # =============================================================================
196
+
197
+
198
@pytest.fixture
def small_cache():
    """Provide a compact 8-layer cache suitable for fast unit tests."""
    config = dict(
        num_layers=8,
        batch_size=1,
        num_kv_heads=4,
        head_dim=64,
        device="cpu",
        dtype=torch.float32,
    )
    return JaggedKVCache(**config)
209
+
210
+
211
@pytest.fixture
def sample_kv():
    """Factory fixture producing random (key, value) tensor pairs."""

    def _make_kv(batch_size=1, num_heads=4, seq_len=1, head_dim=64):
        shape = (batch_size, num_heads, seq_len, head_dim)
        return torch.randn(shape), torch.randn(shape)

    return _make_kv
221
+
222
+
223
+ # =============================================================================
224
+ # Test 1: Basic Cache Operations
225
+ # =============================================================================
226
+
227
+
228
class TestCacheBasicOperations:
    """Test basic cache update and retrieval."""

    def test_cache_starts_empty(self, small_cache):
        """Cache should start with no entries."""
        for layer in range(small_cache.num_layers):
            assert small_cache.get_kv(layer) is None
            assert small_cache.get_seq_length(layer) == 0

    def test_single_position_update(self, small_cache, sample_kv):
        """Test updating cache with a single position."""
        keys, values = sample_kv()
        small_cache.update(
            layer_idx=0,
            key_states=keys,
            value_states=values,
            cache_position=torch.tensor([0]),
        )

        assert small_cache.get_kv(0) is not None
        assert small_cache.get_seq_length(0) == 1
        # Layers other than 0 must remain untouched.
        assert small_cache.get_kv(1) is None

    def test_multiple_positions_update(self, small_cache, sample_kv):
        """Test updating cache with multiple positions at once."""
        keys, values = sample_kv(seq_len=3)
        small_cache.update(
            layer_idx=0,
            key_states=keys,
            value_states=values,
            cache_position=torch.tensor([0, 1, 2]),
        )

        assert small_cache.get_seq_length(0) == 3
        stored_k, _stored_v = small_cache.get_kv(0)
        assert stored_k.shape[2] == 3

    def test_extending_cache(self, small_cache, sample_kv):
        """Test extending cache with new positions."""
        # Write positions [0, 1], then extend with [2, 3].
        for start in (0, 2):
            keys, values = sample_kv(seq_len=2)
            small_cache.update(0, keys, values, torch.tensor([start, start + 1]))

        assert small_cache.get_seq_length(0) == 4
        stored_k, _ = small_cache.get_kv(0)
        assert stored_k.shape[2] == 4
276
+
277
+
278
+ # =============================================================================
279
+ # Test 2: Jagged Cache Behavior
280
+ # =============================================================================
281
+
282
+
283
class TestJaggedCacheBehavior:
    """Test that cache correctly handles different layers with different lengths."""

    def test_different_layers_different_lengths(self, small_cache, sample_kv):
        """Simulate early exit where layers end up with different cached lengths.

        Note: seq_length tracks capacity (max_pos + 1), not filled count.
        Updating layer 3 for the first time at position [1] allocates room for
        positions [0, 1], leaving position 0 zero-filled until lazy fill.
        """
        # Token 0 exits at layer 2; token 1 exits at layer 4.
        for position, deepest_layer in ((0, 2), (1, 4)):
            for layer in range(deepest_layer + 1):
                keys, values = sample_kv()
                small_cache.update(layer, keys, values, torch.tensor([position]))

        # Capacity per layer: layers 0-4 cover positions 0-1; layer 5 untouched.
        expected_capacity = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 0}
        for layer, capacity in expected_capacity.items():
            assert small_cache.get_seq_length(layer) == capacity

    def test_get_missing_layers(self, small_cache, sample_kv):
        """Test detecting which layers need computation."""
        # Only layers 0-2 receive position 0.
        for layer in range(3):
            keys, values = sample_kv()
            small_cache.update(layer, keys, values, torch.tensor([0]))

        # Layers 3-5 never saw position 0.
        assert small_cache.get_missing_layers(position=0, target_layer=5) == [3, 4, 5]

        # Position 1 was never cached anywhere, so every layer is missing.
        assert small_cache.get_missing_layers(position=1, target_layer=5) == [
            0,
            1,
            2,
            3,
            4,
            5,
        ]
329
+
330
+
331
+ # =============================================================================
332
+ # Test 3: Truncation for Rollback
333
+ # =============================================================================
334
+
335
+
336
class TestCacheTruncation:
    """Test cache truncation for rejection rollback."""

    def test_truncate_removes_positions(self, small_cache, sample_kv):
        """Test that truncation removes positions correctly."""
        # Populate five positions on layer 0.
        for position in range(5):
            keys, values = sample_kv()
            small_cache.update(0, keys, values, torch.tensor([position]))

        assert small_cache.get_seq_length(0) == 5

        # Truncate at position 3, keeping positions 0-2.
        small_cache.truncate_from(3)

        assert small_cache.get_seq_length(0) == 3
        kept_k, _ = small_cache.get_kv(0)
        assert kept_k.shape[2] == 3

    def test_truncate_all_layers(self, small_cache, sample_kv):
        """Test that truncation affects all layers."""
        # Fill layers 0-2 with five positions each.
        for layer in range(3):
            for position in range(5):
                keys, values = sample_kv()
                small_cache.update(layer, keys, values, torch.tensor([position]))

        # Layer 0 additionally gets positions 5-7.
        for position in range(5, 8):
            keys, values = sample_kv()
            small_cache.update(0, keys, values, torch.tensor([position]))

        assert [small_cache.get_seq_length(i) for i in range(3)] == [8, 5, 5]

        # Truncate at position 4 — every layer is clamped.
        small_cache.truncate_from(4)

        assert [small_cache.get_seq_length(i) for i in range(3)] == [4, 4, 4]
378
+
379
+
380
+ # =============================================================================
381
+ # Test 4: Clone for Speculation
382
+ # =============================================================================
383
+
384
+
385
class TestCacheCloning:
    """Test cache cloning for speculative drafting."""

    def test_clone_creates_independent_copy(self, small_cache, sample_kv):
        """Test that clone creates truly independent copy."""
        keys, values = sample_kv(seq_len=3)
        small_cache.update(0, keys, values, torch.tensor([0, 1, 2]))

        snapshot = small_cache.clone()

        # Mutating the original must not leak into the snapshot.
        extra_k, extra_v = sample_kv()
        small_cache.update(0, extra_k, extra_v, torch.tensor([3]))

        assert small_cache.get_seq_length(0) == 4
        assert snapshot.get_seq_length(0) == 3

    def test_clone_preserves_data(self, small_cache, sample_kv):
        """Test that clone preserves actual tensor values."""
        keys, values = sample_kv()
        small_cache.update(0, keys, values, torch.tensor([0]))

        snapshot = small_cache.clone()

        for original_t, copied_t in zip(small_cache.get_kv(0), snapshot.get_kv(0)):
            assert torch.allclose(original_t, copied_t)
417
+
418
+
419
+ # =============================================================================
420
+ # Test 5: Simulated Draft/Verify Scenario
421
+ # =============================================================================
422
+
423
+
424
class TestDraftVerifyScenario:
    """Simulate a realistic draft/verify scenario."""

    @staticmethod
    def _prefill(cache, sample_kv, num_positions):
        """Run every position through every layer, as prompt prefill would."""
        for position in range(num_positions):
            for layer in range(cache.num_layers):
                keys, values = sample_kv()
                cache.update(layer, keys, values, torch.tensor([position]))

    def test_draft_verify_with_full_accept(self, small_cache, sample_kv):
        """Simulate drafting 3 tokens, all accepted."""
        # Prompt prefill (positions 0-4).
        self._prefill(small_cache, sample_kv, 5)

        # Clone for drafting.
        draft_cache = small_cache.clone()

        # Draft positions 5-7, each exiting at a different layer.
        for position, exit_layer in zip((5, 6, 7), (2, 4, 3)):
            for layer in range(exit_layer + 1):
                keys, values = sample_kv()
                draft_cache.update(layer, keys, values, torch.tensor([position]))

        # Jagged structure after drafting:
        assert draft_cache.get_seq_length(0) == 8  # all 8 positions
        assert draft_cache.get_seq_length(2) == 8  # every token reached layer 2
        # Deepest position to reach layer 4 is 6, so capacity stops at 7.
        assert draft_cache.get_seq_length(4) == 7

        # "Verification" — all accepted; backfill layers that skipped positions.
        for position in (5, 6, 7):
            for layer in range(small_cache.num_layers):
                if draft_cache.get_seq_length(layer) <= position:
                    keys, values = sample_kv()
                    draft_cache.update(layer, keys, values, torch.tensor([position]))

        # After verification, every layer holds every position.
        for layer in range(small_cache.num_layers):
            assert draft_cache.get_seq_length(layer) == 8

    def test_draft_verify_with_rejection(self, small_cache, sample_kv):
        """Simulate drafting 3 tokens, rejected at position 6."""
        # Prompt prefill.
        self._prefill(small_cache, sample_kv, 5)

        # Clone for drafting.
        draft_cache = small_cache.clone()

        # Draft positions 5-7, all exiting at layer 2.
        for position in (5, 6, 7):
            for layer in range(3):
                keys, values = sample_kv()
                draft_cache.update(layer, keys, values, torch.tensor([position]))

        # Accept position 5 only; roll back 6 (and 7).
        draft_cache.truncate_from(6)

        # Only positions 0-5 remain.
        for layer in range(3):
            assert draft_cache.get_seq_length(layer) == 6
487
+
488
+
489
+ # =============================================================================
490
+ # Run tests directly
491
+ # =============================================================================
492
+
493
+
494
# Allow running this test module directly (outside a `pytest` CLI invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])