Florian Valade committed
Commit a781577 · Parent: 7848d77

Fix early exit inference loop to eliminate redundant computation


Key fixes:
- _draft_single_token now always returns a token (never None)
- When no early exit head is confident, drafting continues to lm_head instead of
returning None and triggering a redundant full-model pass
- Extracts 'bonus token' from verification pass when all drafts accepted
- Same fixes applied to generate_streaming method

This eliminates the double computation bug where layers were processed
twice when no head was confident, and adds the bonus token optimization
that extracts an extra token from each verification pass.

Adds comprehensive tests in tests/test_inference_loop.py
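To make the new contract concrete, here is a minimal sketch of one decode step under the fixed loop. This is illustrative only, not the code from this commit; `append_token` is a hypothetical helper, and the other names follow the diff below.

```python
# Hedged sketch of the fixed loop, assuming the interfaces shown in the diff
# below (_draft_single_token, verification logits from one full forward pass).

def sketch_decode_step(decoder, current_ids, thresholds, max_draft_length):
    drafted = []
    draft_ids = current_ids
    for _ in range(max_draft_length):
        # Always returns a token: from a confident early exit head, or, if no
        # head is confident, from lm_head in the SAME forward pass
        # (exit_head is None in that case). It never returns None, so the
        # layers are never re-run just to produce a fallback token.
        token_id, exit_head, exit_layer, uncertainty = decoder._draft_single_token(
            draft_ids, thresholds
        )
        drafted.append((token_id, exit_head, exit_layer, uncertainty))
        draft_ids = append_token(draft_ids, token_id)  # hypothetical helper
        if exit_head is None:
            break  # lm_head token ends the draft phase and triggers verification

    # One full forward pass verifies every draft...
    verify_logits = decoder.model(draft_ids).logits
    start_pos = current_ids.shape[1] - 1
    # ...and when all N drafts are accepted, position start_pos + N already
    # holds the logits for the NEXT token: the "bonus token" comes for free.
    bonus_id = verify_logits[0, start_pos + len(drafted), :].argmax().item()
    return drafted, bonus_id
```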

Files changed (2)
  1. src/inference.py +296 -141
  2. tests/test_inference_loop.py +559 -0
src/inference.py CHANGED
@@ -218,28 +218,29 @@ class DSSDecoder:
         validated_tokens = []
         current_ids = input_ids.clone()
         num_layers = self.adapter.get_num_layers()
-        head_layers = self.model_config.head_layer_indices
 
         while len(validated_tokens) < max_tokens:
             # ============================================================
-            # DRAFT PHASE: Generate tokens using early exit heads
+            # DRAFT PHASE: Generate tokens using early exit or lm_head
             # ============================================================
             drafted_tokens = []
             draft_ids = current_ids.clone()
+            got_lm_head_token = False
 
             for _ in range(max_draft_length):
                 if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
                     break
 
-                draft_result = self._draft_single_token(draft_ids, thresholds)
-
-                if draft_result is None:
-                    break
-
-                token_id, exit_head, exit_layer, uncertainty = draft_result
+                # Generate a token (always returns a result)
+                token_id, exit_head, exit_layer, uncertainty = self._draft_single_token(
+                    draft_ids, thresholds
+                )
 
                 if token_id == self.tokenizer.eos_token_id:
-                    break
+                    # EOS handling
+                    if exit_head is not None and drafted_tokens:
+                        break  # Verify pending drafts first
+                    return  # Stop generation
 
                 token_text = self.tokenizer.decode([token_id])
                 drafted_token = TokenInfo(
@@ -254,109 +255,126 @@ class DSSDecoder:
                     [draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
                 )
 
-                # Yield draft event
-                yield StreamEvent(
-                    event_type="draft",
-                    tokens=list(validated_tokens),
-                    drafted_tokens=list(drafted_tokens),
-                    message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
-                )
+                if exit_head is None:
+                    # Token from lm_head - triggers verification
+                    got_lm_head_token = True
+                    yield StreamEvent(
+                        event_type="draft",
+                        tokens=list(validated_tokens),
+                        drafted_tokens=list(drafted_tokens),
+                        message=f"Drafting token {len(drafted_tokens)} using Full Model",
+                    )
+                    break
+                else:
+                    # Token from early exit head
+                    yield StreamEvent(
+                        event_type="draft",
+                        tokens=list(validated_tokens),
+                        drafted_tokens=list(drafted_tokens),
+                        message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
+                    )
 
             # ============================================================
             # VERIFY PHASE
             # ============================================================
-            if drafted_tokens:
-                yield StreamEvent(
-                    event_type="verify_start",
-                    tokens=list(validated_tokens),
-                    drafted_tokens=list(drafted_tokens),
-                    message=f"Verifying {len(drafted_tokens)} drafted tokens...",
-                )
-
-                with torch.no_grad():
-                    outputs = self.model(draft_ids, use_cache=False)
-                    verify_logits = outputs.logits
-
-                start_pos = current_ids.shape[1] - 1
-
-                for i, drafted_token in enumerate(drafted_tokens):
-                    verify_pos = start_pos + i
-                    verified_token_id = torch.argmax(
-                        verify_logits[0, verify_pos, :]
-                    ).item()
-
-                    if drafted_token.token_id == verified_token_id:
-                        # Accept
-                        validated_tokens.append(drafted_token)
-                        current_ids = torch.cat(
-                            [
-                                current_ids,
-                                torch.tensor(
-                                    [[drafted_token.token_id]], device=self.device
-                                ),
-                            ],
-                            dim=1,
-                        )
-                        yield StreamEvent(
-                            event_type="accept",
-                            tokens=list(validated_tokens),
-                            drafted_tokens=[],
-                            message=f"✓ Accepted '{drafted_token.token_text}'",
-                        )
-                    else:
-                        # Reject - use full model's token
-                        token_text = self.tokenizer.decode([verified_token_id])
-                        corrected_token = TokenInfo(
-                            token_id=verified_token_id,
-                            token_text=token_text,
-                            exit_head=None,
-                            exit_layer=num_layers,
-                            uncertainty=0.0,
-                        )
-                        validated_tokens.append(corrected_token)
-                        current_ids = torch.cat(
-                            [
-                                current_ids,
-                                torch.tensor([[verified_token_id]], device=self.device),
-                            ],
-                            dim=1,
-                        )
-                        yield StreamEvent(
-                            event_type="reject",
-                            tokens=list(validated_tokens),
-                            drafted_tokens=[],
-                            message=f"✗ Rejected '{drafted_token.token_text}' → '{token_text}'",
-                        )
-                        break
-            else:
-                # No drafts - generate with full model
-                with torch.no_grad():
-                    outputs = self.model(current_ids, use_cache=False)
-                    logits = outputs.logits
-
-                token_id = torch.argmax(logits[0, -1, :]).item()
-
-                if token_id == self.tokenizer.eos_token_id:
-                    break
-
-                token_text = self.tokenizer.decode([token_id])
-                full_token = TokenInfo(
-                    token_id=token_id,
-                    token_text=token_text,
-                    exit_head=None,
-                    exit_layer=num_layers,
-                    uncertainty=0.0,
-                )
-                validated_tokens.append(full_token)
-                current_ids = torch.cat(
-                    [current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
-                )
-                yield StreamEvent(
-                    event_type="full_model",
-                    tokens=list(validated_tokens),
-                    drafted_tokens=[],
-                    message=f"Full model: '{token_text}'",
-                )
+            if not drafted_tokens:
+                break
+
+            yield StreamEvent(
+                event_type="verify_start",
+                tokens=list(validated_tokens),
+                drafted_tokens=list(drafted_tokens),
+                message=f"Verifying {len(drafted_tokens)} drafted tokens...",
+            )
+
+            with torch.no_grad():
+                outputs = self.model(draft_ids, use_cache=False)
+                verify_logits = outputs.logits
+
+            start_pos = current_ids.shape[1] - 1
+            all_accepted = True
+
+            for i, drafted_token in enumerate(drafted_tokens):
+                verify_pos = start_pos + i
+                verified_token_id = torch.argmax(
+                    verify_logits[0, verify_pos, :]
+                ).item()
+
+                if drafted_token.token_id == verified_token_id:
+                    # Accept
+                    validated_tokens.append(drafted_token)
+                    current_ids = torch.cat(
+                        [
+                            current_ids,
+                            torch.tensor(
+                                [[drafted_token.token_id]], device=self.device
+                            ),
+                        ],
+                        dim=1,
+                    )
+                    yield StreamEvent(
+                        event_type="accept",
+                        tokens=list(validated_tokens),
+                        drafted_tokens=[],
+                        message=f"✓ Accepted '{drafted_token.token_text}'",
+                    )
+                else:
+                    # Reject - use full model's token
+                    all_accepted = False
+                    token_text = self.tokenizer.decode([verified_token_id])
+                    corrected_token = TokenInfo(
+                        token_id=verified_token_id,
+                        token_text=token_text,
+                        exit_head=None,
+                        exit_layer=num_layers,
+                        uncertainty=0.0,
+                    )
+                    validated_tokens.append(corrected_token)
+                    current_ids = torch.cat(
+                        [
+                            current_ids,
+                            torch.tensor([[verified_token_id]], device=self.device),
+                        ],
+                        dim=1,
+                    )
+                    yield StreamEvent(
+                        event_type="reject",
+                        tokens=list(validated_tokens),
+                        drafted_tokens=[],
+                        message=f"✗ Rejected '{drafted_token.token_text}' → '{token_text}'",
+                    )
+                    break
+
+            # BONUS TOKEN: If all tokens were accepted, get bonus from last position
+            if all_accepted and len(validated_tokens) < max_tokens:
+                bonus_pos = start_pos + len(drafted_tokens)
+                if bonus_pos < verify_logits.shape[1]:
+                    bonus_token_id = torch.argmax(
+                        verify_logits[0, bonus_pos, :]
+                    ).item()
+                    if bonus_token_id != self.tokenizer.eos_token_id:
+                        bonus_text = self.tokenizer.decode([bonus_token_id])
+                        bonus_token = TokenInfo(
+                            token_id=bonus_token_id,
+                            token_text=bonus_text,
+                            exit_head=None,
+                            exit_layer=num_layers,
+                            uncertainty=0.0,
+                        )
+                        validated_tokens.append(bonus_token)
+                        current_ids = torch.cat(
+                            [
+                                current_ids,
+                                torch.tensor([[bonus_token_id]], device=self.device),
+                            ],
+                            dim=1,
+                        )
+                        yield StreamEvent(
+                            event_type="accept",
+                            tokens=list(validated_tokens),
+                            drafted_tokens=[],
+                            message=f" Bonus token '{bonus_text}'",
+                        )
 
         if (
             validated_tokens
@@ -374,55 +392,81 @@ class DSSDecoder:
         """
         Speculative decoding with early exit heads.
 
-        GUARANTEES same output as full model by:
-        1. DRAFT: Generate tokens using early exit heads (fast, partial compute)
-        2. VERIFY: When full model needed, verify ALL drafted tokens
-        3. ACCEPT: Keep matching tokens, take model's token at first mismatch
+        The flow:
+        1. Generate tokens using _draft_single_token (which may early exit or use lm_head)
+        2. Tokens from early exit heads are "drafts" that need verification
+        3. When we get a token from lm_head (exit_head=None), it triggers verification
+           of all pending drafts, and the lm_head token is accepted as verified
+        4. All accepted tokens are guaranteed to match full model output
        """
         tokens = []
         current_ids = input_ids.clone()
         num_layers = self.adapter.get_num_layers()
-        head_layers = self.model_config.head_layer_indices
 
         while len(tokens) < max_tokens:
             # ============================================================
-            # DRAFT PHASE: Generate tokens using early exit heads
+            # DRAFT PHASE: Generate tokens, collecting early exit drafts
            # ============================================================
             drafted_tokens = []  # List of (token_id, exit_head, exit_layer, uncertainty)
             draft_ids = current_ids.clone()
+            got_lm_head_token = False
 
             for _ in range(max_draft_length):
                 if len(tokens) + len(drafted_tokens) >= max_tokens:
                     break
 
-                # Try to draft a token using early exit
-                draft_result = self._draft_single_token(draft_ids, thresholds)
-
-                if draft_result is None:
-                    # No head was confident enough - need to verify
-                    break
-
-                token_id, exit_head, exit_layer, uncertainty = draft_result
+                # Generate a token (always returns a result, never None)
+                token_id, exit_head, exit_layer, uncertainty = self._draft_single_token(
+                    draft_ids, thresholds
+                )
 
                 if token_id == self.tokenizer.eos_token_id:
-                    break
-
-                drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
-                draft_ids = torch.cat(
-                    [draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
-                )
+                    # If EOS from early exit, we still need to verify pending drafts
+                    if exit_head is not None and drafted_tokens:
+                        # Don't add EOS to drafts, just break to verify
+                        break
+                    # If EOS from lm_head or no pending drafts, we're done
+                    return tokens
+
+                if exit_head is None:
+                    # Token from lm_head - this is verified, triggers verification of drafts
+                    got_lm_head_token = True
+                    # Add to drafts for unified handling, but mark as already verified
+                    drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
+                    draft_ids = torch.cat(
+                        [draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
+                    )
+                    break  # Stop drafting, go to verification
+                else:
+                    # Token from early exit head - add to drafts for later verification
+                    drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
+                    draft_ids = torch.cat(
+                        [draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
+                    )
 
             # ============================================================
-            # VERIFY PHASE: Run full model to verify drafted tokens
+            # VERIFY PHASE: Verify drafted tokens with full model
             # ============================================================
-            if drafted_tokens:
-                # Run full model on current_ids + all drafted tokens
+            if not drafted_tokens:
+                # No tokens generated (shouldn't happen with the new logic)
+                break
+
+            # If the last token is from lm_head, we already have full model output
+            # for all positions. Use it for verification.
+            last_token = drafted_tokens[-1]
+            _, last_exit_head, _, _ = last_token
+
+            if last_exit_head is None:
+                # Last token is from lm_head - all earlier tokens need verification
+                # The lm_head pass already computed logits for all positions
+                # We can use the model output to verify
+
+                # Need to run full model to get logits for verification
                 with torch.no_grad():
                     outputs = self.model(draft_ids, use_cache=False)
                     verify_logits = outputs.logits
 
-                # Verify each drafted token
-                start_pos = current_ids.shape[1] - 1  # Position before drafting
+                start_pos = current_ids.shape[1] - 1
 
                 for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
                     drafted_tokens
@@ -433,7 +477,7 @@ class DSSDecoder:
                    ).item()
 
                     if drafted_token == verified_token:
-                        # Token matches - accept it with early exit info
+                        # Token matches - accept it
                         token_text = self.tokenizer.decode([drafted_token])
                         tokens.append(
                             TokenInfo(
@@ -472,30 +516,126 @@ class DSSDecoder:
                        )
                         # Stop - discard remaining drafted tokens
                         break
+
+                # BONUS TOKEN: If all drafted tokens were accepted, use the last position
+                # to get an additional token (this is the "free" token from lm_head)
+                if len(tokens) >= len(drafted_tokens):
+                    # All drafts were accepted, check for bonus token
+                    bonus_pos = start_pos + len(drafted_tokens)
+                    if bonus_pos < verify_logits.shape[1]:
+                        bonus_token_id = torch.argmax(
+                            verify_logits[0, bonus_pos, :]
+                        ).item()
+                        if (
+                            bonus_token_id != self.tokenizer.eos_token_id
+                            and len(tokens) < max_tokens
+                        ):
+                            bonus_text = self.tokenizer.decode([bonus_token_id])
+                            tokens.append(
+                                TokenInfo(
+                                    token_id=bonus_token_id,
+                                    token_text=bonus_text,
+                                    exit_head=None,  # Full model
+                                    exit_layer=num_layers,
+                                    uncertainty=0.0,
+                                )
+                            )
+                            current_ids = torch.cat(
+                                [
+                                    current_ids,
+                                    torch.tensor(
+                                        [[bonus_token_id]], device=self.device
+                                    ),
+                                ],
+                                dim=1,
+                            )
             else:
-                # No tokens drafted - generate one with full model
+                # All tokens are from early exit heads - need to run full model for verification
                 with torch.no_grad():
-                    outputs = self.model(current_ids, use_cache=False)
-                    logits = outputs.logits
-
-                token_id = torch.argmax(logits[0, -1, :]).item()
-
-                if token_id == self.tokenizer.eos_token_id:
-                    break
-
-                token_text = self.tokenizer.decode([token_id])
-                tokens.append(
-                    TokenInfo(
-                        token_id=token_id,
-                        token_text=token_text,
-                        exit_head=None,
-                        exit_layer=num_layers,
-                        uncertainty=0.0,
-                    )
-                )
-                current_ids = torch.cat(
-                    [current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
-                )
+                    outputs = self.model(draft_ids, use_cache=False)
+                    verify_logits = outputs.logits
+
+                start_pos = current_ids.shape[1] - 1
+
+                for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
+                    drafted_tokens
+                ):
+                    verify_pos = start_pos + i
+                    verified_token = torch.argmax(
+                        verify_logits[0, verify_pos, :]
+                    ).item()
+
+                    if drafted_token == verified_token:
+                        # Token matches - accept it with early exit info
+                        token_text = self.tokenizer.decode([drafted_token])
+                        tokens.append(
+                            TokenInfo(
+                                token_id=drafted_token,
+                                token_text=token_text,
+                                exit_head=exit_head,
+                                exit_layer=exit_layer,
+                                uncertainty=uncertainty,
+                            )
+                        )
+                        current_ids = torch.cat(
+                            [
+                                current_ids,
+                                torch.tensor([[drafted_token]], device=self.device),
+                            ],
+                            dim=1,
+                        )
+                    else:
+                        # Mismatch - use full model's token
+                        token_text = self.tokenizer.decode([verified_token])
+                        tokens.append(
+                            TokenInfo(
+                                token_id=verified_token,
+                                token_text=token_text,
+                                exit_head=None,  # Full model
+                                exit_layer=num_layers,
+                                uncertainty=0.0,
+                            )
+                        )
+                        current_ids = torch.cat(
+                            [
+                                current_ids,
+                                torch.tensor([[verified_token]], device=self.device),
+                            ],
+                            dim=1,
+                        )
+                        # Stop - discard remaining drafted tokens
+                        break
+
+                # BONUS TOKEN from verification pass
+                if len(tokens) >= len(drafted_tokens):
+                    bonus_pos = start_pos + len(drafted_tokens)
+                    if bonus_pos < verify_logits.shape[1]:
+                        bonus_token_id = torch.argmax(
+                            verify_logits[0, bonus_pos, :]
+                        ).item()
+                        if (
+                            bonus_token_id != self.tokenizer.eos_token_id
+                            and len(tokens) < max_tokens
+                        ):
+                            bonus_text = self.tokenizer.decode([bonus_token_id])
+                            tokens.append(
+                                TokenInfo(
+                                    token_id=bonus_token_id,
+                                    token_text=bonus_text,
+                                    exit_head=None,  # Full model
+                                    exit_layer=num_layers,
+                                    uncertainty=0.0,
+                                )
+                            )
+                            current_ids = torch.cat(
+                                [
+                                    current_ids,
+                                    torch.tensor(
+                                        [[bonus_token_id]], device=self.device
+                                    ),
+                                ],
+                                dim=1,
+                            )
 
         # Check for EOS in accepted tokens
         if tokens and tokens[-1].token_id == self.tokenizer.eos_token_id:
@@ -507,15 +647,20 @@ class DSSDecoder:
         self,
         input_ids: torch.Tensor,
         thresholds: Dict[int, float],
-    ) -> Optional[Tuple[int, int, int, float]]:
+    ) -> Tuple[int, Optional[int], int, float]:
         """
-        Try to draft a single token using early exit heads.
-        Returns (token_id, exit_head, exit_layer, uncertainty) if confident enough.
-        Returns None if no head is confident enough (need full model verification).
+        Generate a single token using early exit or full model.
+
+        Returns (token_id, exit_head, exit_layer, uncertainty):
+        - If an early exit head is confident: returns token with that head's info
+        - If no head is confident: continues to lm_head and returns token from there
+
+        This function ALWAYS returns a token (never returns None).
        """
         device = input_ids.device
         seq_len = input_ids.shape[1]
         head_layers = self.model_config.head_layer_indices
+        num_layers = self.adapter.get_num_layers()
 
         # Position IDs
         position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(
@@ -570,8 +715,18 @@ class DSSDecoder:
                 token_id = torch.argmax(head_logits[0, -1, :]).item()
                 return (token_id, head_idx, layer_idx, uncertainty)
 
-        # No head was confident enough - need full model verification
-        return None
+        # No head was confident - use lm_head to get the token
+        # Apply final norm and lm_head
+        final_hidden = self.adapter.apply_final_norm(hidden_states)
+        logits = self.adapter.get_lm_head_output(final_hidden)
+
+        # Get token from last position
+        token_id = torch.argmax(logits[0, -1, :]).item()
+
+        # Compute uncertainty for the lm_head output
+        uncertainty = self.uncertainty_fn(logits[0, -1, :].unsqueeze(0), dim=-1).item()
+
+        return (token_id, None, num_layers, uncertainty)
 
     def _generate_full_model(
         self,
tests/test_inference_loop.py ADDED
@@ -0,0 +1,559 @@
+"""
+Tests for the correct early exit inference loop behavior.
+
+The inference loop should work as follows:
+
+1. SINGLE FORWARD PASS per token attempt:
+   - Process layers sequentially
+   - At each head checkpoint, check if confident enough
+   - If confident: EARLY EXIT - return token immediately (save compute)
+   - If no head confident: continue to lm_head, return token from there
+   - NEVER return None - always produce exactly one token per forward pass
+
+2. SPECULATIVE DECODING:
+   - Drafted tokens (from early exit heads) are unverified
+   - When we eventually run to lm_head (full model), we verify all pending drafts
+   - The lm_head pass also produces a BONUS token (the next prediction)
+   - On mismatch: use full model's token, discard remaining drafts
+
+Key invariants:
+- _draft_single_token NEVER returns None
+- When all drafts are accepted, we get N+1 tokens (N verified + 1 bonus)
+- No redundant computation (never run layers twice for same token)
+"""
+
+import pytest
+import torch
+import torch.nn as nn
+from unittest.mock import Mock, MagicMock, patch
+from typing import List, Tuple, Optional
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.inference import DSSDecoder, TokenInfo, AuxiliaryHead, compute_entropy
+from src.model_adapters import ModelAdapter
+from src.model_config import ModelConfig, CalibrationResult
+
+
+class MockAdapter(ModelAdapter):
+    """Mock adapter for testing without a real model."""
+
+    def __init__(self, num_layers: int = 8, hidden_size: int = 64, vocab_size: int = 100):
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self._layers = nn.ModuleList([nn.Identity() for _ in range(num_layers)])
+        self._embed = nn.Embedding(vocab_size, hidden_size)
+        self._norm = nn.LayerNorm(hidden_size)
+        self._lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+
+        # Create a mapping from layer to index
+        self._layer_to_idx = {layer: idx for idx, layer in enumerate(self._layers)}
+
+        # Track calls for verification
+        self.layer_calls = []
+        self.final_norm_calls = 0
+        self.lm_head_calls = 0
+
+    def get_embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self._embed(input_ids)
+
+    def get_layers(self) -> nn.ModuleList:
+        return self._layers
+
+    def get_num_layers(self) -> int:
+        return self.num_layers
+
+    def forward_layer(
+        self,
+        layer: nn.Module,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Tuple],
+        position_embeddings: Optional[Tuple],
+        use_cache: bool = True,
+        cache_position: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple]]:
+        layer_idx = self._layer_to_idx.get(layer, -1)
+        self.layer_calls.append(layer_idx)
+        return hidden_states, None
+
+    def apply_final_norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        self.final_norm_calls += 1
+        return self._norm(hidden_states)
+
+    def get_lm_head_output(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        self.lm_head_calls += 1
+        return self._lm_head(hidden_states)
+
+    def get_position_embeddings(
+        self, hidden_states: torch.Tensor, position_ids: torch.Tensor
+    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+        # Return dummy cos/sin embeddings
+        seq_len = hidden_states.shape[1]
+        cos = torch.ones(1, seq_len, self.hidden_size)
+        sin = torch.zeros(1, seq_len, self.hidden_size)
+        return (cos, sin)
+
+    def reset_tracking(self):
+        self.layer_calls = []
+        self.final_norm_calls = 0
+        self.lm_head_calls = 0
+
+
+class MockTokenizer:
+    """Mock tokenizer for testing."""
+
+    def __init__(self, vocab_size: int = 100):
+        self.vocab_size = vocab_size
+        self.eos_token_id = 0
+        self.pad_token = "<pad>"
+        self.chat_template = None  # Disable chat template
+
+    def encode(self, text: str, return_tensors: str = None) -> torch.Tensor:
+        # Simple mock encoding
+        tokens = [ord(c) % self.vocab_size for c in text[:10]]
+        if return_tensors == "pt":
+            return torch.tensor([tokens])
+        return tokens
+
+    def decode(self, token_ids: List[int]) -> str:
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        return "".join(chr(t + 65) for t in token_ids)
+
+
+@pytest.fixture
+def mock_model_config():
+    """Create a mock model config with 2 heads."""
+    return ModelConfig(
+        model_name="mock-model",
+        num_heads=2,
+        head_layer_indices=[2, 5],  # Heads at layers 2 and 5
+        quantization="none",
+        hidden_size=64,
+        vocab_size=100,
+        num_hidden_layers=8,
+    )
+
+
+@pytest.fixture
+def mock_calibration():
+    """Create mock calibration with thresholds."""
+    return CalibrationResult(
+        model_config_path="mock",
+        calibration_dataset="mock",
+        calibration_samples=100,
+        uncertainty_metric="entropy",
+        accuracy_levels=[0.75],
+        thresholds={
+            "0.75": {
+                "0": 0.5,  # Head 0 threshold
+                "1": 0.7,  # Head 1 threshold
+            }
+        },
+    )
+
+
+@pytest.fixture
+def mock_aux_heads():
+    """Create mock auxiliary heads."""
+    heads = nn.ModuleList([
+        AuxiliaryHead(hidden_size=64, vocab_size=100),
+        AuxiliaryHead(hidden_size=64, vocab_size=100),
+    ])
+    return heads
+
+
+class MockModel:
+    """Mock model that can be configured to return specific outputs."""
+
+    def __init__(self):
+        self._forward_fn = None
+
+    def parameters(self):
+        return iter([torch.zeros(1)])
+
+    def set_forward(self, fn):
+        """Set the forward function to use."""
+        self._forward_fn = fn
+
+    def __call__(self, input_ids, **kwargs):
+        if self._forward_fn is not None:
+            return self._forward_fn(input_ids, **kwargs)
+        # Default: return zeros
+        seq_len = input_ids.shape[1]
+        class Output:
+            def __init__(self):
+                self.logits = torch.zeros(1, seq_len, 100)
+        return Output()
+
+
+class MockOutput:
+    """Simple output wrapper."""
+    def __init__(self, logits):
+        self.logits = logits
+
+
+@pytest.fixture
+def mock_decoder(mock_model_config, mock_calibration, mock_aux_heads):
+    """Create a decoder with mocked components."""
+    adapter = MockAdapter(num_layers=8, hidden_size=64, vocab_size=100)
+    tokenizer = MockTokenizer(vocab_size=100)
+
+    # Create a configurable mock model
+    mock_model = MockModel()
+
+    decoder = DSSDecoder(
+        model=mock_model,
+        adapter=adapter,
+        aux_heads=mock_aux_heads,
+        tokenizer=tokenizer,
+        model_config=mock_model_config,
+        calibration=mock_calibration,
+        device="cpu",
+    )
+    return decoder
+
+
+class TestDraftSingleTokenNeverReturnsNone:
+    """
+    _draft_single_token should NEVER return None.
+
+    It should always return a token:
+    - From an early exit head if confident, OR
+    - From the lm_head if no head is confident
+    """
+
+    def test_returns_token_when_head_confident(self, mock_decoder):
+        """When a head is confident, return token with that head's info."""
+        # Make head 0 very confident (low entropy)
+        with patch.object(mock_decoder.aux_heads[0], 'forward') as mock_head:
+            # Create logits with very peaked distribution (low entropy)
+            logits = torch.zeros(1, 1, 100)
+            logits[0, 0, 42] = 100.0  # Very confident about token 42
+            mock_head.return_value = logits
+
+            input_ids = torch.tensor([[1, 2, 3]])
+            thresholds = {0: 0.5, 1: 0.7}
+
+            result = mock_decoder._draft_single_token(input_ids, thresholds)
+
+            assert result is not None, "_draft_single_token returned None!"
+            token_id, exit_head, exit_layer, uncertainty = result
+            assert token_id == 42
+            assert exit_head == 0
+            assert exit_layer == 2  # Head 0 is at layer 2
+
+    def test_returns_token_from_lm_head_when_no_head_confident(self, mock_decoder):
+        """
+        When NO head is confident, should continue to lm_head and return token.
+        This is the critical fix - currently the code returns None here.
+        """
+        # Make all heads NOT confident (high entropy)
+        def make_uncertain_logits(*args, **kwargs):
+            logits = torch.randn(1, 1, 100)  # Random = high entropy
+            return logits
+
+        for head in mock_decoder.aux_heads:
+            head.forward = make_uncertain_logits
+
+        input_ids = torch.tensor([[1, 2, 3]])
+        thresholds = {0: 0.001, 1: 0.001}  # Very strict thresholds
+
+        result = mock_decoder._draft_single_token(input_ids, thresholds)
+
+        # THIS IS THE KEY ASSERTION - currently fails!
+        assert result is not None, (
+            "_draft_single_token returned None when no head was confident. "
+            "It should have continued to lm_head and returned a token."
+        )
+
+        token_id, exit_head, exit_layer, uncertainty = result
+        assert exit_head is None, "Token should be from lm_head, not a head"
+        assert exit_layer == mock_decoder.adapter.get_num_layers()
+
+    def test_no_redundant_computation_when_lm_head_used(self, mock_decoder):
+        """
+        When falling back to lm_head, layers should only be computed ONCE.
+        The current bug: layers are computed in _draft_single_token,
+        then computed AGAIN in the fallback full model call.
+        """
+        adapter = mock_decoder.adapter
+        adapter.reset_tracking()
+
+        # Make all heads NOT confident
+        def make_uncertain_logits(*args, **kwargs):
+            return torch.randn(1, 1, 100)
+
+        for head in mock_decoder.aux_heads:
+            head.forward = make_uncertain_logits
+
+        input_ids = torch.tensor([[1, 2, 3]])
+        thresholds = {0: 0.001, 1: 0.001}
+
+        result = mock_decoder._draft_single_token(input_ids, thresholds)
+
+        # Count how many times each layer was called
+        layer_call_counts = {}
+        for layer_idx in adapter.layer_calls:
+            layer_call_counts[layer_idx] = layer_call_counts.get(layer_idx, 0) + 1
+
+        # Each layer should be called exactly ONCE
+        for layer_idx in range(adapter.num_layers):
+            count = layer_call_counts.get(layer_idx, 0)
+            assert count == 1, (
+                f"Layer {layer_idx} was called {count} times. "
+                "Should be exactly 1 (no redundant computation)."
+            )
+
+
+class TestBonusTokenOnFullVerification:
+    """
+    When we run to lm_head (for verification or no confident head),
+    we should get N+1 tokens: N verified drafts + 1 bonus.
+    """
+
+    def test_bonus_token_when_all_drafts_accepted(self, mock_decoder):
+        """
+        If all drafted tokens are verified correct, we should get:
+        - All drafted tokens (verified)
+        - PLUS one bonus token from the last lm_head position
+        """
+        num_layers = mock_decoder.adapter.get_num_layers()
+
+        # Scenario: 3 tokens drafted with early exit, then one from lm_head (triggers verify)
+        # The lm_head token triggers verification of all previous drafts
+        drafted_sequence = [
+            (10, 0, 2, 0.1),  # token 10, head 0, layer 2 (early exit)
+            (20, 1, 5, 0.2),  # token 20, head 1, layer 5 (early exit)
+            (30, 1, 5, 0.3),  # token 30, head 1, layer 5 (early exit)
+            (40, None, num_layers, 0.0),  # token 40, lm_head (triggers verify)
+        ]
+
+        draft_call_count = [0]
+
+        def mock_draft(*args, **kwargs):
+            if draft_call_count[0] < len(drafted_sequence):
+                result = drafted_sequence[draft_call_count[0]]
+                draft_call_count[0] += 1
+                return result
+            # Return EOS to stop
+            return (mock_decoder.tokenizer.eos_token_id, None, num_layers, 0.0)
+
+        # Mock the full model verification
+        def mock_model_forward(input_ids, **kwargs):
+            seq_len = input_ids.shape[1]
+            logits = torch.zeros(1, seq_len, 100)
+
+            # Make all drafted tokens verify correctly
+            # base_pos = prompt length - 1 = 3 - 1 = 2
+            base_pos = 2
+            for i, (token_id, _, _, _) in enumerate(drafted_sequence):
+                if i < len(drafted_sequence):
+                    logits[0, base_pos + i, token_id] = 100.0
+
+            # Bonus token prediction at last position
+            logits[0, -1, 99] = 100.0  # Predict token 99 as bonus
+
+            return MockOutput(logits)
+
+        mock_decoder.model.set_forward(mock_model_forward)
+
+        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
+            input_ids = torch.tensor([[1, 2, 3]])
+            thresholds = {0: 0.5, 1: 0.7}
+
+            tokens = mock_decoder._generate_with_early_exit(
+                input_ids, max_tokens=10, thresholds=thresholds
+            )
+
+        # Should get 5 tokens: 4 drafted/lm_head + 1 bonus
+        assert len(tokens) >= 5, (
+            f"Expected at least 5 tokens (4 drafted + 1 bonus), got {len(tokens)}. "
+            f"Tokens: {[(t.token_id, t.exit_head) for t in tokens]}"
+        )
+
+        # First 3 should be early exit tokens
+        assert tokens[0].token_id == 10
+        assert tokens[0].exit_head == 0
+        assert tokens[1].token_id == 20
+        assert tokens[1].exit_head == 1
+        assert tokens[2].token_id == 30
+        assert tokens[2].exit_head == 1
+
+        # 4th is the lm_head token that triggered verification
+        assert tokens[3].token_id == 40
+        assert tokens[3].exit_head is None
+
+        # 5th is the bonus token
+        assert tokens[4].token_id == 99, (
+            f"5th token should be bonus token 99, got {tokens[4].token_id}"
+        )
+        assert tokens[4].exit_head is None
+
+
+class TestVerificationOnMismatch:
+    """Test that verification correctly handles mismatches."""
+
+    def test_rejected_draft_uses_full_model_token(self, mock_decoder):
+        """
+        When a draft is rejected (mismatch), we should:
+        1. Use the full model's token instead
+        2. Discard remaining drafted tokens
+        """
+        num_layers = mock_decoder.adapter.get_num_layers()
+
+        # Scenario: 3 early exit tokens drafted, then lm_head token triggers verify
+        # The second drafted token will NOT match
+        drafted_sequence = [
+            (10, 0, 2, 0.1),  # Matches
+            (20, 1, 5, 0.2),  # Will NOT match - full model says 25
+            (30, 1, 5, 0.3),  # Should be discarded
+            (40, None, num_layers, 0.0),  # lm_head triggers verification
+        ]
+
+        draft_call_count = [0]
+
+        def mock_draft(*args, **kwargs):
+            if draft_call_count[0] < len(drafted_sequence):
+                result = drafted_sequence[draft_call_count[0]]
+                draft_call_count[0] += 1
+                return result
+            # Return EOS to stop
+            return (mock_decoder.tokenizer.eos_token_id, None, num_layers, 0.0)
+
+        def mock_model_forward(input_ids, **kwargs):
+            seq_len = input_ids.shape[1]
+            logits = torch.zeros(1, seq_len, 100)
+
+            # base_pos = prompt_len - 1 = 3 - 1 = 2
+            base_pos = 2
+
+            # First draft matches
+            logits[0, base_pos, 10] = 100.0
+
+            # Second draft does NOT match - full model says 25
+            logits[0, base_pos + 1, 25] = 100.0  # Different from drafted 20!
+
+            return MockOutput(logits)
+
+        mock_decoder.model.set_forward(mock_model_forward)
+
+        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
+            input_ids = torch.tensor([[1, 2, 3]])
+            thresholds = {0: 0.5, 1: 0.7}
+
+            tokens = mock_decoder._generate_with_early_exit(
+                input_ids, max_tokens=10, thresholds=thresholds
+            )
+
+        # Should get exactly 2 tokens: first accepted, second corrected
+        # Third drafted token should be discarded
+        assert len(tokens) >= 2, f"Expected at least 2 tokens, got {len(tokens)}"
+
+        # First token: accepted draft
+        assert tokens[0].token_id == 10
+        assert tokens[0].exit_head == 0
+
+        # Second token: full model's correction
+        assert tokens[1].token_id == 25, (
+            f"Second token should be full model's 25, not drafted 20. Got {tokens[1].token_id}"
+        )
+        assert tokens[1].exit_head is None, "Corrected token should have exit_head=None"
+
+
+class TestEarlyExitSavesCompute:
+    """Test that early exit actually skips layer computation."""
+
+    def test_early_exit_stops_at_confident_layer(self, mock_decoder):
+        """When head 0 (layer 2) is confident, layers 3-7 should NOT be computed."""
+        adapter = mock_decoder.adapter
+        adapter.reset_tracking()
+
+        # Make head 0 (at layer 2) very confident
+        with patch.object(mock_decoder.aux_heads[0], 'forward') as mock_head:
+            logits = torch.zeros(1, 1, 100)
+            logits[0, 0, 42] = 100.0
+            mock_head.return_value = logits
+
+            input_ids = torch.tensor([[1, 2, 3]])
+            thresholds = {0: 10.0, 1: 10.0}  # High thresholds, easy to beat
+
+            result = mock_decoder._draft_single_token(input_ids, thresholds)
+
+            # Should have exited at layer 2
+            assert result is not None
+            _, exit_head, exit_layer, _ = result
+            assert exit_layer == 2
+
+            # Only layers 0, 1, 2 should have been called
+            max_layer_called = max(adapter.layer_calls) if adapter.layer_calls else -1
+            assert max_layer_called == 2, (
+                f"Expected to stop at layer 2, but layers up to {max_layer_called} were called. "
+                f"Layer calls: {adapter.layer_calls}"
+            )
+
+
+class TestGenerationTermination:
+    """Test that generation terminates correctly."""
+
+    def test_stops_on_eos_token_from_draft(self, mock_decoder):
+        """Generation should stop when EOS token is produced during drafting."""
+        # Return EOS token on first draft
+        def mock_draft(input_ids, thresholds):
+            return (mock_decoder.tokenizer.eos_token_id, 0, 2, 0.1)
+
+        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
+            input_ids = torch.tensor([[1, 2, 3]])
+            thresholds = {0: 10.0, 1: 10.0}
+
+            tokens = mock_decoder._generate_with_early_exit(
+                input_ids, max_tokens=100, thresholds=thresholds
+            )
+
+        # Should stop immediately (0 tokens since EOS is not appended)
+        assert len(tokens) == 0, f"Should stop on EOS, got {len(tokens)} tokens"
+
+    def test_stops_at_max_tokens(self, mock_decoder):
+        """Generation should stop at max_tokens limit."""
+        num_layers = mock_decoder.adapter.get_num_layers()
+
+        # Make draft return alternating early exit / lm_head tokens
+        draft_count = [0]
+
+        def mock_draft(input_ids, thresholds):
+            draft_count[0] += 1
+            # Alternate between early exit and lm_head to trigger verification
+            if draft_count[0] % 2 == 1:
+                return (10 + draft_count[0], 0, 2, 0.1)  # early exit
+            else:
+                return (20 + draft_count[0], None, num_layers, 0.0)  # lm_head
+
+        def mock_model_forward(input_ids, **kwargs):
+            seq_len = input_ids.shape[1]
+            # Return logits that match the drafted tokens
+            logits = torch.zeros(1, seq_len, 100)
+            # Match all positions to their drafted values
+            for pos in range(seq_len):
+                expected_token = 10 + (pos + 1) if (pos + 1) % 2 == 1 else 20 + (pos + 1)
+                logits[0, pos, expected_token % 100] = 100.0
+            return MockOutput(logits)
+
+        mock_decoder.model.set_forward(mock_model_forward)
+
+        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
+            input_ids = torch.tensor([[1, 2, 3]])
+            thresholds = {0: 10.0, 1: 10.0}
+
+            tokens = mock_decoder._generate_with_early_exit(
+                input_ids, max_tokens=5, thresholds=thresholds
+            )
+
+        assert len(tokens) <= 5, f"Should stop at max_tokens=5, got {len(tokens)} tokens"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "--tb=short"])