Florian valade committed on
Commit ·
a781577
1
Parent(s): 7848d77
Fix early exit inference loop to eliminate redundant computation
Browse files
Key fixes:
- _draft_single_token now always returns a token (never None)
- When no early exit head is confident, continues to lm_head instead of
returning None and triggering a redundant full model pass
- Extracts 'bonus token' from verification pass when all drafts accepted
- Same fixes applied to generate_streaming method
This eliminates the double computation bug where layers were processed
twice when no head was confident, and adds the bonus token optimization
that extracts an extra token from each verification pass.
Adds comprehensive tests in tests/test_inference_loop.py
- src/inference.py +296 -141
- tests/test_inference_loop.py +559 -0
src/inference.py
CHANGED
|
@@ -218,28 +218,29 @@ class DSSDecoder:
|
|
| 218 |
validated_tokens = []
|
| 219 |
current_ids = input_ids.clone()
|
| 220 |
num_layers = self.adapter.get_num_layers()
|
| 221 |
-
head_layers = self.model_config.head_layer_indices
|
| 222 |
|
| 223 |
while len(validated_tokens) < max_tokens:
|
| 224 |
# ============================================================
|
| 225 |
-
# DRAFT PHASE: Generate tokens using early exit
|
| 226 |
# ============================================================
|
| 227 |
drafted_tokens = []
|
| 228 |
draft_ids = current_ids.clone()
|
|
|
|
| 229 |
|
| 230 |
for _ in range(max_draft_length):
|
| 231 |
if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
|
| 232 |
break
|
| 233 |
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
token_id, exit_head, exit_layer, uncertainty = draft_result
|
| 240 |
|
| 241 |
if token_id == self.tokenizer.eos_token_id:
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
token_text = self.tokenizer.decode([token_id])
|
| 245 |
drafted_token = TokenInfo(
|
|
@@ -254,109 +255,126 @@ class DSSDecoder:
|
|
| 254 |
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
|
| 255 |
)
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
# ============================================================
|
| 266 |
# VERIFY PHASE
|
| 267 |
# ============================================================
|
| 268 |
-
if drafted_tokens:
|
| 269 |
-
|
| 270 |
-
event_type="verify_start",
|
| 271 |
-
tokens=list(validated_tokens),
|
| 272 |
-
drafted_tokens=list(drafted_tokens),
|
| 273 |
-
message=f"Verifying {len(drafted_tokens)} drafted tokens...",
|
| 274 |
-
)
|
| 275 |
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
| 286 |
).item()
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
[
|
| 293 |
-
current_ids,
|
| 294 |
-
torch.tensor(
|
| 295 |
-
[[drafted_token.token_id]], device=self.device
|
| 296 |
-
),
|
| 297 |
-
],
|
| 298 |
-
dim=1,
|
| 299 |
-
)
|
| 300 |
-
yield StreamEvent(
|
| 301 |
-
event_type="accept",
|
| 302 |
-
tokens=list(validated_tokens),
|
| 303 |
-
drafted_tokens=[],
|
| 304 |
-
message=f"✓ Accepted '{drafted_token.token_text}'",
|
| 305 |
-
)
|
| 306 |
-
else:
|
| 307 |
-
# Reject - use full model's token
|
| 308 |
-
token_text = self.tokenizer.decode([verified_token_id])
|
| 309 |
-
corrected_token = TokenInfo(
|
| 310 |
-
token_id=verified_token_id,
|
| 311 |
-
token_text=token_text,
|
| 312 |
exit_head=None,
|
| 313 |
exit_layer=num_layers,
|
| 314 |
uncertainty=0.0,
|
| 315 |
)
|
| 316 |
-
validated_tokens.append(
|
| 317 |
current_ids = torch.cat(
|
| 318 |
[
|
| 319 |
current_ids,
|
| 320 |
-
torch.tensor([[
|
| 321 |
],
|
| 322 |
dim=1,
|
| 323 |
)
|
| 324 |
yield StreamEvent(
|
| 325 |
-
event_type="
|
| 326 |
tokens=list(validated_tokens),
|
| 327 |
drafted_tokens=[],
|
| 328 |
-
message=f"
|
| 329 |
)
|
| 330 |
-
break
|
| 331 |
-
else:
|
| 332 |
-
# No drafts - generate with full model
|
| 333 |
-
with torch.no_grad():
|
| 334 |
-
outputs = self.model(current_ids, use_cache=False)
|
| 335 |
-
logits = outputs.logits
|
| 336 |
-
|
| 337 |
-
token_id = torch.argmax(logits[0, -1, :]).item()
|
| 338 |
-
|
| 339 |
-
if token_id == self.tokenizer.eos_token_id:
|
| 340 |
-
break
|
| 341 |
-
|
| 342 |
-
token_text = self.tokenizer.decode([token_id])
|
| 343 |
-
full_token = TokenInfo(
|
| 344 |
-
token_id=token_id,
|
| 345 |
-
token_text=token_text,
|
| 346 |
-
exit_head=None,
|
| 347 |
-
exit_layer=num_layers,
|
| 348 |
-
uncertainty=0.0,
|
| 349 |
-
)
|
| 350 |
-
validated_tokens.append(full_token)
|
| 351 |
-
current_ids = torch.cat(
|
| 352 |
-
[current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
|
| 353 |
-
)
|
| 354 |
-
yield StreamEvent(
|
| 355 |
-
event_type="full_model",
|
| 356 |
-
tokens=list(validated_tokens),
|
| 357 |
-
drafted_tokens=[],
|
| 358 |
-
message=f"Full model: '{token_text}'",
|
| 359 |
-
)
|
| 360 |
|
| 361 |
if (
|
| 362 |
validated_tokens
|
|
@@ -374,55 +392,81 @@ class DSSDecoder:
|
|
| 374 |
"""
|
| 375 |
Speculative decoding with early exit heads.
|
| 376 |
|
| 377 |
-
|
| 378 |
-
1.
|
| 379 |
-
2.
|
| 380 |
-
3.
|
|
|
|
|
|
|
| 381 |
"""
|
| 382 |
tokens = []
|
| 383 |
current_ids = input_ids.clone()
|
| 384 |
num_layers = self.adapter.get_num_layers()
|
| 385 |
-
head_layers = self.model_config.head_layer_indices
|
| 386 |
|
| 387 |
while len(tokens) < max_tokens:
|
| 388 |
# ============================================================
|
| 389 |
-
# DRAFT PHASE: Generate tokens
|
| 390 |
# ============================================================
|
| 391 |
drafted_tokens = [] # List of (token_id, exit_head, exit_layer, uncertainty)
|
| 392 |
draft_ids = current_ids.clone()
|
|
|
|
| 393 |
|
| 394 |
for _ in range(max_draft_length):
|
| 395 |
if len(tokens) + len(drafted_tokens) >= max_tokens:
|
| 396 |
break
|
| 397 |
|
| 398 |
-
#
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
# No head was confident enough - need to verify
|
| 403 |
-
break
|
| 404 |
-
|
| 405 |
-
token_id, exit_head, exit_layer, uncertainty = draft_result
|
| 406 |
|
| 407 |
if token_id == self.tokenizer.eos_token_id:
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
|
| 415 |
# ============================================================
|
| 416 |
-
# VERIFY PHASE:
|
| 417 |
# ============================================================
|
| 418 |
-
if drafted_tokens:
|
| 419 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
with torch.no_grad():
|
| 421 |
outputs = self.model(draft_ids, use_cache=False)
|
| 422 |
verify_logits = outputs.logits
|
| 423 |
|
| 424 |
-
|
| 425 |
-
start_pos = current_ids.shape[1] - 1 # Position before drafting
|
| 426 |
|
| 427 |
for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
|
| 428 |
drafted_tokens
|
|
@@ -433,7 +477,7 @@ class DSSDecoder:
|
|
| 433 |
).item()
|
| 434 |
|
| 435 |
if drafted_token == verified_token:
|
| 436 |
-
# Token matches - accept it
|
| 437 |
token_text = self.tokenizer.decode([drafted_token])
|
| 438 |
tokens.append(
|
| 439 |
TokenInfo(
|
|
@@ -472,30 +516,126 @@ class DSSDecoder:
|
|
| 472 |
)
|
| 473 |
# Stop - discard remaining drafted tokens
|
| 474 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
else:
|
| 476 |
-
#
|
| 477 |
with torch.no_grad():
|
| 478 |
-
outputs = self.model(
|
| 479 |
-
|
| 480 |
|
| 481 |
-
|
| 482 |
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
# Check for EOS in accepted tokens
|
| 501 |
if tokens and tokens[-1].token_id == self.tokenizer.eos_token_id:
|
|
@@ -507,15 +647,20 @@ class DSSDecoder:
|
|
| 507 |
self,
|
| 508 |
input_ids: torch.Tensor,
|
| 509 |
thresholds: Dict[int, float],
|
| 510 |
-
) ->
|
| 511 |
"""
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
Returns
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
"""
|
| 516 |
device = input_ids.device
|
| 517 |
seq_len = input_ids.shape[1]
|
| 518 |
head_layers = self.model_config.head_layer_indices
|
|
|
|
| 519 |
|
| 520 |
# Position IDs
|
| 521 |
position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(
|
|
@@ -570,8 +715,18 @@ class DSSDecoder:
|
|
| 570 |
token_id = torch.argmax(head_logits[0, -1, :]).item()
|
| 571 |
return (token_id, head_idx, layer_idx, uncertainty)
|
| 572 |
|
| 573 |
-
|
| 574 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
|
| 576 |
def _generate_full_model(
|
| 577 |
self,
|
|
|
|
| 218 |
validated_tokens = []
|
| 219 |
current_ids = input_ids.clone()
|
| 220 |
num_layers = self.adapter.get_num_layers()
|
|
|
|
| 221 |
|
| 222 |
while len(validated_tokens) < max_tokens:
|
| 223 |
# ============================================================
|
| 224 |
+
# DRAFT PHASE: Generate tokens using early exit or lm_head
|
| 225 |
# ============================================================
|
| 226 |
drafted_tokens = []
|
| 227 |
draft_ids = current_ids.clone()
|
| 228 |
+
got_lm_head_token = False
|
| 229 |
|
| 230 |
for _ in range(max_draft_length):
|
| 231 |
if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
|
| 232 |
break
|
| 233 |
|
| 234 |
+
# Generate a token (always returns a result)
|
| 235 |
+
token_id, exit_head, exit_layer, uncertainty = self._draft_single_token(
|
| 236 |
+
draft_ids, thresholds
|
| 237 |
+
)
|
|
|
|
|
|
|
| 238 |
|
| 239 |
if token_id == self.tokenizer.eos_token_id:
|
| 240 |
+
# EOS handling
|
| 241 |
+
if exit_head is not None and drafted_tokens:
|
| 242 |
+
break # Verify pending drafts first
|
| 243 |
+
return # Stop generation
|
| 244 |
|
| 245 |
token_text = self.tokenizer.decode([token_id])
|
| 246 |
drafted_token = TokenInfo(
|
|
|
|
| 255 |
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
|
| 256 |
)
|
| 257 |
|
| 258 |
+
if exit_head is None:
|
| 259 |
+
# Token from lm_head - triggers verification
|
| 260 |
+
got_lm_head_token = True
|
| 261 |
+
yield StreamEvent(
|
| 262 |
+
event_type="draft",
|
| 263 |
+
tokens=list(validated_tokens),
|
| 264 |
+
drafted_tokens=list(drafted_tokens),
|
| 265 |
+
message=f"Drafting token {len(drafted_tokens)} using Full Model",
|
| 266 |
+
)
|
| 267 |
+
break
|
| 268 |
+
else:
|
| 269 |
+
# Token from early exit head
|
| 270 |
+
yield StreamEvent(
|
| 271 |
+
event_type="draft",
|
| 272 |
+
tokens=list(validated_tokens),
|
| 273 |
+
drafted_tokens=list(drafted_tokens),
|
| 274 |
+
message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
|
| 275 |
+
)
|
| 276 |
|
| 277 |
# ============================================================
|
| 278 |
# VERIFY PHASE
|
| 279 |
# ============================================================
|
| 280 |
+
if not drafted_tokens:
|
| 281 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
+
yield StreamEvent(
|
| 284 |
+
event_type="verify_start",
|
| 285 |
+
tokens=list(validated_tokens),
|
| 286 |
+
drafted_tokens=list(drafted_tokens),
|
| 287 |
+
message=f"Verifying {len(drafted_tokens)} drafted tokens...",
|
| 288 |
+
)
|
| 289 |
|
| 290 |
+
with torch.no_grad():
|
| 291 |
+
outputs = self.model(draft_ids, use_cache=False)
|
| 292 |
+
verify_logits = outputs.logits
|
| 293 |
+
|
| 294 |
+
start_pos = current_ids.shape[1] - 1
|
| 295 |
+
all_accepted = True
|
| 296 |
+
|
| 297 |
+
for i, drafted_token in enumerate(drafted_tokens):
|
| 298 |
+
verify_pos = start_pos + i
|
| 299 |
+
verified_token_id = torch.argmax(
|
| 300 |
+
verify_logits[0, verify_pos, :]
|
| 301 |
+
).item()
|
| 302 |
+
|
| 303 |
+
if drafted_token.token_id == verified_token_id:
|
| 304 |
+
# Accept
|
| 305 |
+
validated_tokens.append(drafted_token)
|
| 306 |
+
current_ids = torch.cat(
|
| 307 |
+
[
|
| 308 |
+
current_ids,
|
| 309 |
+
torch.tensor(
|
| 310 |
+
[[drafted_token.token_id]], device=self.device
|
| 311 |
+
),
|
| 312 |
+
],
|
| 313 |
+
dim=1,
|
| 314 |
+
)
|
| 315 |
+
yield StreamEvent(
|
| 316 |
+
event_type="accept",
|
| 317 |
+
tokens=list(validated_tokens),
|
| 318 |
+
drafted_tokens=[],
|
| 319 |
+
message=f"✓ Accepted '{drafted_token.token_text}'",
|
| 320 |
+
)
|
| 321 |
+
else:
|
| 322 |
+
# Reject - use full model's token
|
| 323 |
+
all_accepted = False
|
| 324 |
+
token_text = self.tokenizer.decode([verified_token_id])
|
| 325 |
+
corrected_token = TokenInfo(
|
| 326 |
+
token_id=verified_token_id,
|
| 327 |
+
token_text=token_text,
|
| 328 |
+
exit_head=None,
|
| 329 |
+
exit_layer=num_layers,
|
| 330 |
+
uncertainty=0.0,
|
| 331 |
+
)
|
| 332 |
+
validated_tokens.append(corrected_token)
|
| 333 |
+
current_ids = torch.cat(
|
| 334 |
+
[
|
| 335 |
+
current_ids,
|
| 336 |
+
torch.tensor([[verified_token_id]], device=self.device),
|
| 337 |
+
],
|
| 338 |
+
dim=1,
|
| 339 |
+
)
|
| 340 |
+
yield StreamEvent(
|
| 341 |
+
event_type="reject",
|
| 342 |
+
tokens=list(validated_tokens),
|
| 343 |
+
drafted_tokens=[],
|
| 344 |
+
message=f"✗ Rejected '{drafted_token.token_text}' → '{token_text}'",
|
| 345 |
+
)
|
| 346 |
+
break
|
| 347 |
|
| 348 |
+
# BONUS TOKEN: If all tokens were accepted, get bonus from last position
|
| 349 |
+
if all_accepted and len(validated_tokens) < max_tokens:
|
| 350 |
+
bonus_pos = start_pos + len(drafted_tokens)
|
| 351 |
+
if bonus_pos < verify_logits.shape[1]:
|
| 352 |
+
bonus_token_id = torch.argmax(
|
| 353 |
+
verify_logits[0, bonus_pos, :]
|
| 354 |
).item()
|
| 355 |
+
if bonus_token_id != self.tokenizer.eos_token_id:
|
| 356 |
+
bonus_text = self.tokenizer.decode([bonus_token_id])
|
| 357 |
+
bonus_token = TokenInfo(
|
| 358 |
+
token_id=bonus_token_id,
|
| 359 |
+
token_text=bonus_text,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
exit_head=None,
|
| 361 |
exit_layer=num_layers,
|
| 362 |
uncertainty=0.0,
|
| 363 |
)
|
| 364 |
+
validated_tokens.append(bonus_token)
|
| 365 |
current_ids = torch.cat(
|
| 366 |
[
|
| 367 |
current_ids,
|
| 368 |
+
torch.tensor([[bonus_token_id]], device=self.device),
|
| 369 |
],
|
| 370 |
dim=1,
|
| 371 |
)
|
| 372 |
yield StreamEvent(
|
| 373 |
+
event_type="accept",
|
| 374 |
tokens=list(validated_tokens),
|
| 375 |
drafted_tokens=[],
|
| 376 |
+
message=f"✓ Bonus token '{bonus_text}'",
|
| 377 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
if (
|
| 380 |
validated_tokens
|
|
|
|
| 392 |
"""
|
| 393 |
Speculative decoding with early exit heads.
|
| 394 |
|
| 395 |
+
The flow:
|
| 396 |
+
1. Generate tokens using _draft_single_token (which may early exit or use lm_head)
|
| 397 |
+
2. Tokens from early exit heads are "drafts" that need verification
|
| 398 |
+
3. When we get a token from lm_head (exit_head=None), it triggers verification
|
| 399 |
+
of all pending drafts, and the lm_head token is accepted as verified
|
| 400 |
+
4. All accepted tokens are guaranteed to match full model output
|
| 401 |
"""
|
| 402 |
tokens = []
|
| 403 |
current_ids = input_ids.clone()
|
| 404 |
num_layers = self.adapter.get_num_layers()
|
|
|
|
| 405 |
|
| 406 |
while len(tokens) < max_tokens:
|
| 407 |
# ============================================================
|
| 408 |
+
# DRAFT PHASE: Generate tokens, collecting early exit drafts
|
| 409 |
# ============================================================
|
| 410 |
drafted_tokens = [] # List of (token_id, exit_head, exit_layer, uncertainty)
|
| 411 |
draft_ids = current_ids.clone()
|
| 412 |
+
got_lm_head_token = False
|
| 413 |
|
| 414 |
for _ in range(max_draft_length):
|
| 415 |
if len(tokens) + len(drafted_tokens) >= max_tokens:
|
| 416 |
break
|
| 417 |
|
| 418 |
+
# Generate a token (always returns a result, never None)
|
| 419 |
+
token_id, exit_head, exit_layer, uncertainty = self._draft_single_token(
|
| 420 |
+
draft_ids, thresholds
|
| 421 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
if token_id == self.tokenizer.eos_token_id:
|
| 424 |
+
# If EOS from early exit, we still need to verify pending drafts
|
| 425 |
+
if exit_head is not None and drafted_tokens:
|
| 426 |
+
# Don't add EOS to drafts, just break to verify
|
| 427 |
+
break
|
| 428 |
+
# If EOS from lm_head or no pending drafts, we're done
|
| 429 |
+
return tokens
|
| 430 |
+
|
| 431 |
+
if exit_head is None:
|
| 432 |
+
# Token from lm_head - this is verified, triggers verification of drafts
|
| 433 |
+
got_lm_head_token = True
|
| 434 |
+
# Add to drafts for unified handling, but mark as already verified
|
| 435 |
+
drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
|
| 436 |
+
draft_ids = torch.cat(
|
| 437 |
+
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
|
| 438 |
+
)
|
| 439 |
+
break # Stop drafting, go to verification
|
| 440 |
+
else:
|
| 441 |
+
# Token from early exit head - add to drafts for later verification
|
| 442 |
+
drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
|
| 443 |
+
draft_ids = torch.cat(
|
| 444 |
+
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
|
| 445 |
+
)
|
| 446 |
|
| 447 |
# ============================================================
|
| 448 |
+
# VERIFY PHASE: Verify drafted tokens with full model
|
| 449 |
# ============================================================
|
| 450 |
+
if not drafted_tokens:
|
| 451 |
+
# No tokens generated (shouldn't happen with the new logic)
|
| 452 |
+
break
|
| 453 |
+
|
| 454 |
+
# If the last token is from lm_head, we already have full model output
|
| 455 |
+
# for all positions. Use it for verification.
|
| 456 |
+
last_token = drafted_tokens[-1]
|
| 457 |
+
_, last_exit_head, _, _ = last_token
|
| 458 |
+
|
| 459 |
+
if last_exit_head is None:
|
| 460 |
+
# Last token is from lm_head - all earlier tokens need verification
|
| 461 |
+
# The lm_head pass already computed logits for all positions
|
| 462 |
+
# We can use the model output to verify
|
| 463 |
+
|
| 464 |
+
# Need to run full model to get logits for verification
|
| 465 |
with torch.no_grad():
|
| 466 |
outputs = self.model(draft_ids, use_cache=False)
|
| 467 |
verify_logits = outputs.logits
|
| 468 |
|
| 469 |
+
start_pos = current_ids.shape[1] - 1
|
|
|
|
| 470 |
|
| 471 |
for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
|
| 472 |
drafted_tokens
|
|
|
|
| 477 |
).item()
|
| 478 |
|
| 479 |
if drafted_token == verified_token:
|
| 480 |
+
# Token matches - accept it
|
| 481 |
token_text = self.tokenizer.decode([drafted_token])
|
| 482 |
tokens.append(
|
| 483 |
TokenInfo(
|
|
|
|
| 516 |
)
|
| 517 |
# Stop - discard remaining drafted tokens
|
| 518 |
break
|
| 519 |
+
|
| 520 |
+
# BONUS TOKEN: If all drafted tokens were accepted, use the last position
|
| 521 |
+
# to get an additional token (this is the "free" token from lm_head)
|
| 522 |
+
if len(tokens) >= len(drafted_tokens):
|
| 523 |
+
# All drafts were accepted, check for bonus token
|
| 524 |
+
bonus_pos = start_pos + len(drafted_tokens)
|
| 525 |
+
if bonus_pos < verify_logits.shape[1]:
|
| 526 |
+
bonus_token_id = torch.argmax(
|
| 527 |
+
verify_logits[0, bonus_pos, :]
|
| 528 |
+
).item()
|
| 529 |
+
if (
|
| 530 |
+
bonus_token_id != self.tokenizer.eos_token_id
|
| 531 |
+
and len(tokens) < max_tokens
|
| 532 |
+
):
|
| 533 |
+
bonus_text = self.tokenizer.decode([bonus_token_id])
|
| 534 |
+
tokens.append(
|
| 535 |
+
TokenInfo(
|
| 536 |
+
token_id=bonus_token_id,
|
| 537 |
+
token_text=bonus_text,
|
| 538 |
+
exit_head=None, # Full model
|
| 539 |
+
exit_layer=num_layers,
|
| 540 |
+
uncertainty=0.0,
|
| 541 |
+
)
|
| 542 |
+
)
|
| 543 |
+
current_ids = torch.cat(
|
| 544 |
+
[
|
| 545 |
+
current_ids,
|
| 546 |
+
torch.tensor(
|
| 547 |
+
[[bonus_token_id]], device=self.device
|
| 548 |
+
),
|
| 549 |
+
],
|
| 550 |
+
dim=1,
|
| 551 |
+
)
|
| 552 |
else:
|
| 553 |
+
# All tokens are from early exit heads - need to run full model for verification
|
| 554 |
with torch.no_grad():
|
| 555 |
+
outputs = self.model(draft_ids, use_cache=False)
|
| 556 |
+
verify_logits = outputs.logits
|
| 557 |
|
| 558 |
+
start_pos = current_ids.shape[1] - 1
|
| 559 |
|
| 560 |
+
for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
|
| 561 |
+
drafted_tokens
|
| 562 |
+
):
|
| 563 |
+
verify_pos = start_pos + i
|
| 564 |
+
verified_token = torch.argmax(
|
| 565 |
+
verify_logits[0, verify_pos, :]
|
| 566 |
+
).item()
|
| 567 |
|
| 568 |
+
if drafted_token == verified_token:
|
| 569 |
+
# Token matches - accept it with early exit info
|
| 570 |
+
token_text = self.tokenizer.decode([drafted_token])
|
| 571 |
+
tokens.append(
|
| 572 |
+
TokenInfo(
|
| 573 |
+
token_id=drafted_token,
|
| 574 |
+
token_text=token_text,
|
| 575 |
+
exit_head=exit_head,
|
| 576 |
+
exit_layer=exit_layer,
|
| 577 |
+
uncertainty=uncertainty,
|
| 578 |
+
)
|
| 579 |
+
)
|
| 580 |
+
current_ids = torch.cat(
|
| 581 |
+
[
|
| 582 |
+
current_ids,
|
| 583 |
+
torch.tensor([[drafted_token]], device=self.device),
|
| 584 |
+
],
|
| 585 |
+
dim=1,
|
| 586 |
+
)
|
| 587 |
+
else:
|
| 588 |
+
# Mismatch - use full model's token
|
| 589 |
+
token_text = self.tokenizer.decode([verified_token])
|
| 590 |
+
tokens.append(
|
| 591 |
+
TokenInfo(
|
| 592 |
+
token_id=verified_token,
|
| 593 |
+
token_text=token_text,
|
| 594 |
+
exit_head=None, # Full model
|
| 595 |
+
exit_layer=num_layers,
|
| 596 |
+
uncertainty=0.0,
|
| 597 |
+
)
|
| 598 |
+
)
|
| 599 |
+
current_ids = torch.cat(
|
| 600 |
+
[
|
| 601 |
+
current_ids,
|
| 602 |
+
torch.tensor([[verified_token]], device=self.device),
|
| 603 |
+
],
|
| 604 |
+
dim=1,
|
| 605 |
+
)
|
| 606 |
+
# Stop - discard remaining drafted tokens
|
| 607 |
+
break
|
| 608 |
+
|
| 609 |
+
# BONUS TOKEN from verification pass
|
| 610 |
+
if len(tokens) >= len(drafted_tokens):
|
| 611 |
+
bonus_pos = start_pos + len(drafted_tokens)
|
| 612 |
+
if bonus_pos < verify_logits.shape[1]:
|
| 613 |
+
bonus_token_id = torch.argmax(
|
| 614 |
+
verify_logits[0, bonus_pos, :]
|
| 615 |
+
).item()
|
| 616 |
+
if (
|
| 617 |
+
bonus_token_id != self.tokenizer.eos_token_id
|
| 618 |
+
and len(tokens) < max_tokens
|
| 619 |
+
):
|
| 620 |
+
bonus_text = self.tokenizer.decode([bonus_token_id])
|
| 621 |
+
tokens.append(
|
| 622 |
+
TokenInfo(
|
| 623 |
+
token_id=bonus_token_id,
|
| 624 |
+
token_text=bonus_text,
|
| 625 |
+
exit_head=None, # Full model
|
| 626 |
+
exit_layer=num_layers,
|
| 627 |
+
uncertainty=0.0,
|
| 628 |
+
)
|
| 629 |
+
)
|
| 630 |
+
current_ids = torch.cat(
|
| 631 |
+
[
|
| 632 |
+
current_ids,
|
| 633 |
+
torch.tensor(
|
| 634 |
+
[[bonus_token_id]], device=self.device
|
| 635 |
+
),
|
| 636 |
+
],
|
| 637 |
+
dim=1,
|
| 638 |
+
)
|
| 639 |
|
| 640 |
# Check for EOS in accepted tokens
|
| 641 |
if tokens and tokens[-1].token_id == self.tokenizer.eos_token_id:
|
|
|
|
| 647 |
self,
|
| 648 |
input_ids: torch.Tensor,
|
| 649 |
thresholds: Dict[int, float],
|
| 650 |
+
) -> Tuple[int, Optional[int], int, float]:
|
| 651 |
"""
|
| 652 |
+
Generate a single token using early exit or full model.
|
| 653 |
+
|
| 654 |
+
Returns (token_id, exit_head, exit_layer, uncertainty):
|
| 655 |
+
- If an early exit head is confident: returns token with that head's info
|
| 656 |
+
- If no head is confident: continues to lm_head and returns token from there
|
| 657 |
+
|
| 658 |
+
This function ALWAYS returns a token (never returns None).
|
| 659 |
"""
|
| 660 |
device = input_ids.device
|
| 661 |
seq_len = input_ids.shape[1]
|
| 662 |
head_layers = self.model_config.head_layer_indices
|
| 663 |
+
num_layers = self.adapter.get_num_layers()
|
| 664 |
|
| 665 |
# Position IDs
|
| 666 |
position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(
|
|
|
|
| 715 |
token_id = torch.argmax(head_logits[0, -1, :]).item()
|
| 716 |
return (token_id, head_idx, layer_idx, uncertainty)
|
| 717 |
|
| 718 |
+
# No head was confident - use lm_head to get the token
|
| 719 |
+
# Apply final norm and lm_head
|
| 720 |
+
final_hidden = self.adapter.apply_final_norm(hidden_states)
|
| 721 |
+
logits = self.adapter.get_lm_head_output(final_hidden)
|
| 722 |
+
|
| 723 |
+
# Get token from last position
|
| 724 |
+
token_id = torch.argmax(logits[0, -1, :]).item()
|
| 725 |
+
|
| 726 |
+
# Compute uncertainty for the lm_head output
|
| 727 |
+
uncertainty = self.uncertainty_fn(logits[0, -1, :].unsqueeze(0), dim=-1).item()
|
| 728 |
+
|
| 729 |
+
return (token_id, None, num_layers, uncertainty)
|
| 730 |
|
| 731 |
def _generate_full_model(
|
| 732 |
self,
|
tests/test_inference_loop.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the correct early exit inference loop behavior.
|
| 3 |
+
|
| 4 |
+
The inference loop should work as follows:
|
| 5 |
+
|
| 6 |
+
1. SINGLE FORWARD PASS per token attempt:
|
| 7 |
+
- Process layers sequentially
|
| 8 |
+
- At each head checkpoint, check if confident enough
|
| 9 |
+
- If confident: EARLY EXIT - return token immediately (save compute)
|
| 10 |
+
- If no head confident: continue to lm_head, return token from there
|
| 11 |
+
- NEVER return None - always produce exactly one token per forward pass
|
| 12 |
+
|
| 13 |
+
2. SPECULATIVE DECODING:
|
| 14 |
+
- Drafted tokens (from early exit heads) are unverified
|
| 15 |
+
- When we eventually run to lm_head (full model), we verify all pending drafts
|
| 16 |
+
- The lm_head pass also produces a BONUS token (the next prediction)
|
| 17 |
+
- On mismatch: use full model's token, discard remaining drafts
|
| 18 |
+
|
| 19 |
+
Key invariants:
|
| 20 |
+
- _draft_single_token NEVER returns None
|
| 21 |
+
- When all drafts are accepted, we get N+1 tokens (N verified + 1 bonus)
|
| 22 |
+
- No redundant computation (never run layers twice for same token)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import pytest
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from unittest.mock import Mock, MagicMock, patch
|
| 29 |
+
from typing import List, Tuple, Optional
|
| 30 |
+
import sys
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 34 |
+
|
| 35 |
+
from src.inference import DSSDecoder, TokenInfo, AuxiliaryHead, compute_entropy
|
| 36 |
+
from src.model_adapters import ModelAdapter
|
| 37 |
+
from src.model_config import ModelConfig, CalibrationResult
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MockAdapter(ModelAdapter):
    """Mock adapter for testing without a real model.

    Implements the ModelAdapter surface with identity layers so that tests can
    observe exactly which layers / heads the decoder touches, without running a
    real transformer. All tensors flow through unchanged (nn.Identity layers),
    and every call is recorded in the tracking counters below.

    NOTE(review): DSSDecoder's internals are not visible here; this mock assumes
    the decoder only uses the documented adapter methods — confirm if the
    decoder API changes.
    """

    def __init__(self, num_layers: int = 8, hidden_size: int = 64, vocab_size: int = 100):
        # Basic model geometry used by the decoder and by the dummy modules below.
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        # Identity layers: forward_layer returns hidden states unchanged, so the
        # only observable effect of "running" a layer is the tracking below.
        self._layers = nn.ModuleList([nn.Identity() for _ in range(num_layers)])
        self._embed = nn.Embedding(vocab_size, hidden_size)
        self._norm = nn.LayerNorm(hidden_size)
        self._lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Create a mapping from layer to index (keyed by module identity) so
        # forward_layer can record WHICH layer was invoked.
        self._layer_to_idx = {layer: idx for idx, layer in enumerate(self._layers)}

        # Track calls for verification by the tests:
        #   layer_calls      — ordered list of layer indices that were run
        #   final_norm_calls — number of apply_final_norm invocations
        #   lm_head_calls    — number of get_lm_head_output invocations
        self.layer_calls = []
        self.final_norm_calls = 0
        self.lm_head_calls = 0

    def get_embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids into hidden states via the dummy embedding table."""
        return self._embed(input_ids)

    def get_layers(self) -> nn.ModuleList:
        """Return the (identity) transformer layers."""
        return self._layers

    def get_num_layers(self) -> int:
        """Return the configured layer count."""
        return self.num_layers

    def forward_layer(
        self,
        layer: nn.Module,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Tuple],
        position_embeddings: Optional[Tuple],
        use_cache: bool = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """Record the layer invocation and pass hidden states through unchanged.

        Returns (hidden_states, None) — no KV cache is produced by the mock.
        """
        # -1 marks a layer object that does not belong to this adapter.
        layer_idx = self._layer_to_idx.get(layer, -1)
        self.layer_calls.append(layer_idx)
        return hidden_states, None

    def apply_final_norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dummy final LayerNorm, counting the call."""
        self.final_norm_calls += 1
        return self._norm(hidden_states)

    def get_lm_head_output(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocab logits, counting the call."""
        self.lm_head_calls += 1
        return self._lm_head(hidden_states)

    def get_position_embeddings(
        self, hidden_states: torch.Tensor, position_ids: torch.Tensor
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        # Return dummy cos/sin embeddings (ones/zeros) shaped like RoPE tables;
        # values are irrelevant since the layers are identities.
        seq_len = hidden_states.shape[1]
        cos = torch.ones(1, seq_len, self.hidden_size)
        sin = torch.zeros(1, seq_len, self.hidden_size)
        return (cos, sin)

    def reset_tracking(self):
        """Clear all call counters so a test can start from a clean slate."""
        self.layer_calls = []
        self.final_norm_calls = 0
        self.lm_head_calls = 0
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class MockTokenizer:
    """Tiny deterministic tokenizer stand-in for tests.

    Encoding maps each of the first ten characters to ``ord(c) % vocab_size``;
    decoding maps token ``t`` to the character ``chr(t + 65)``. No real
    vocabulary is involved.
    """

    def __init__(self, vocab_size: int = 100):
        self.vocab_size = vocab_size
        self.eos_token_id = 0
        self.pad_token = "<pad>"
        self.chat_template = None  # Disable chat template

    def encode(self, text: str, return_tensors: str = None) -> torch.Tensor:
        """Encode up to the first 10 characters of *text* into token ids."""
        ids = [ord(ch) % self.vocab_size for ch in text[:10]]
        if return_tensors != "pt":
            return ids
        # Batch dimension of 1, matching HF tokenizer conventions.
        return torch.tensor([ids])

    def decode(self, token_ids: List[int]) -> str:
        """Decode token ids (or a single id) back into a string."""
        ids = [token_ids] if isinstance(token_ids, int) else token_ids
        pieces = [chr(t + 65) for t in ids]
        return "".join(pieces)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@pytest.fixture
def mock_model_config():
    """Create a mock model config with 2 heads (at layers 2 and 5)."""
    # Keep the geometry in sync with MockAdapter defaults (8 layers, 64 hidden).
    cfg = dict(
        model_name="mock-model",
        num_heads=2,
        head_layer_indices=[2, 5],  # Heads at layers 2 and 5
        quantization="none",
        hidden_size=64,
        vocab_size=100,
        num_hidden_layers=8,
    )
    return ModelConfig(**cfg)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@pytest.fixture
def mock_calibration():
    """Create mock calibration with entropy thresholds for one accuracy level."""
    # Per-head entropy thresholds at the 0.75 accuracy level; keys are strings
    # to mirror the JSON-serialized calibration format.
    per_head = {
        "0": 0.5,  # Head 0 threshold
        "1": 0.7,  # Head 1 threshold
    }
    return CalibrationResult(
        model_config_path="mock",
        calibration_dataset="mock",
        calibration_samples=100,
        uncertainty_metric="entropy",
        accuracy_levels=[0.75],
        thresholds={"0.75": per_head},
    )
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@pytest.fixture
def mock_aux_heads():
    """Create two mock auxiliary heads matching the 64-dim / 100-vocab mock model."""
    return nn.ModuleList(
        [AuxiliaryHead(hidden_size=64, vocab_size=100) for _ in range(2)]
    )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class MockModel:
    """Mock model that can be configured to return specific outputs.

    Tests install a forward function via :meth:`set_forward`; until then,
    calling the model yields an object with all-zero logits of shape
    ``(1, seq_len, 100)``.
    """

    def __init__(self):
        self._forward_fn = None

    def parameters(self):
        """Yield a single dummy parameter (enough for device/dtype probing)."""
        return iter([torch.zeros(1)])

    def set_forward(self, fn):
        """Set the forward function to use."""
        self._forward_fn = fn

    def __call__(self, input_ids, **kwargs):
        fn = self._forward_fn
        if fn is not None:
            return fn(input_ids, **kwargs)

        # Default: an output object carrying zero logits over a 100-token vocab.
        class Output:
            def __init__(inner):
                inner.logits = torch.zeros(1, input_ids.shape[1], 100)

        return Output()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class MockOutput:
    """Minimal HF-style output wrapper exposing only a ``logits`` attribute."""

    def __init__(self, logits):
        # Stored verbatim; test forward mocks hand these back to the decoder.
        self.logits = logits
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@pytest.fixture
def mock_decoder(mock_model_config, mock_calibration, mock_aux_heads):
    """Create a DSSDecoder wired entirely to mocked components on CPU."""
    # A configurable mock model plus tracking adapter/tokenizer; tests can
    # swap the model's forward via decoder.model.set_forward(...).
    return DSSDecoder(
        model=MockModel(),
        adapter=MockAdapter(num_layers=8, hidden_size=64, vocab_size=100),
        aux_heads=mock_aux_heads,
        tokenizer=MockTokenizer(vocab_size=100),
        model_config=mock_model_config,
        calibration=mock_calibration,
        device="cpu",
    )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class TestDraftSingleTokenNeverReturnsNone:
    """
    _draft_single_token should NEVER return None.

    It should always return a token:
    - From an early exit head if confident, OR
    - From the lm_head if no head is confident

    The expected return shape is a 4-tuple
    (token_id, exit_head, exit_layer, uncertainty), with exit_head=None when
    the token came from the lm_head.
    """

    def test_returns_token_when_head_confident(self, mock_decoder):
        """When a head is confident, return token with that head's info."""
        # Make head 0 very confident (low entropy) by patching its forward.
        with patch.object(mock_decoder.aux_heads[0], 'forward') as mock_head:
            # Create logits with very peaked distribution (low entropy)
            logits = torch.zeros(1, 1, 100)
            logits[0, 0, 42] = 100.0  # Very confident about token 42
            mock_head.return_value = logits

            input_ids = torch.tensor([[1, 2, 3]])
            # Thresholds per head; head 0's peaked logits easily fall below 0.5.
            thresholds = {0: 0.5, 1: 0.7}

            result = mock_decoder._draft_single_token(input_ids, thresholds)

            assert result is not None, "_draft_single_token returned None!"
            token_id, exit_head, exit_layer, uncertainty = result
            assert token_id == 42
            assert exit_head == 0
            assert exit_layer == 2  # Head 0 is at layer 2

    def test_returns_token_from_lm_head_when_no_head_confident(self, mock_decoder):
        """
        When NO head is confident, should continue to lm_head and return token.
        This is the critical fix - currently the code returns None here.
        """
        # Make all heads NOT confident (high entropy) — random logits have
        # near-uniform softmax, so entropy stays above any strict threshold.
        def make_uncertain_logits(*args, **kwargs):
            logits = torch.randn(1, 1, 100)  # Random = high entropy
            return logits

        for head in mock_decoder.aux_heads:
            head.forward = make_uncertain_logits

        input_ids = torch.tensor([[1, 2, 3]])
        thresholds = {0: 0.001, 1: 0.001}  # Very strict thresholds

        result = mock_decoder._draft_single_token(input_ids, thresholds)

        # THIS IS THE KEY ASSERTION - currently fails!
        assert result is not None, (
            "_draft_single_token returned None when no head was confident. "
            "It should have continued to lm_head and returned a token."
        )

        token_id, exit_head, exit_layer, uncertainty = result
        assert exit_head is None, "Token should be from lm_head, not a head"
        # An lm_head exit is reported as having run all layers.
        assert exit_layer == mock_decoder.adapter.get_num_layers()

    def test_no_redundant_computation_when_lm_head_used(self, mock_decoder):
        """
        When falling back to lm_head, layers should only be computed ONCE.
        The current bug: layers are computed in _draft_single_token,
        then computed AGAIN in the fallback full model call.
        """
        adapter = mock_decoder.adapter
        adapter.reset_tracking()

        # Make all heads NOT confident
        def make_uncertain_logits(*args, **kwargs):
            return torch.randn(1, 1, 100)

        for head in mock_decoder.aux_heads:
            head.forward = make_uncertain_logits

        input_ids = torch.tensor([[1, 2, 3]])
        thresholds = {0: 0.001, 1: 0.001}

        result = mock_decoder._draft_single_token(input_ids, thresholds)

        # Count how many times each layer was called (MockAdapter records
        # every forward_layer invocation in layer_calls).
        layer_call_counts = {}
        for layer_idx in adapter.layer_calls:
            layer_call_counts[layer_idx] = layer_call_counts.get(layer_idx, 0) + 1

        # Each layer should be called exactly ONCE
        for layer_idx in range(adapter.num_layers):
            count = layer_call_counts.get(layer_idx, 0)
            assert count == 1, (
                f"Layer {layer_idx} was called {count} times. "
                "Should be exactly 1 (no redundant computation)."
            )
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class TestBonusTokenOnFullVerification:
    """
    When we run to lm_head (for verification or no confident head),
    we should get N+1 tokens: N verified drafts + 1 bonus.

    Fix applied here: the verification mock's loop previously carried a dead
    guard ``if i < len(drafted_sequence)`` inside
    ``enumerate(drafted_sequence)`` — always true, so it has been removed.
    """

    def test_bonus_token_when_all_drafts_accepted(self, mock_decoder):
        """
        If all drafted tokens are verified correct, we should get:
        - All drafted tokens (verified)
        - PLUS one bonus token from the last lm_head position
        """
        num_layers = mock_decoder.adapter.get_num_layers()

        # Scenario: 3 tokens drafted with early exit, then one from lm_head (triggers verify)
        # The lm_head token triggers verification of all previous drafts.
        # Tuples are (token_id, exit_head, exit_layer, uncertainty).
        drafted_sequence = [
            (10, 0, 2, 0.1),  # token 10, head 0, layer 2 (early exit)
            (20, 1, 5, 0.2),  # token 20, head 1, layer 5 (early exit)
            (30, 1, 5, 0.3),  # token 30, head 1, layer 5 (early exit)
            (40, None, num_layers, 0.0),  # token 40, lm_head (triggers verify)
        ]

        draft_call_count = [0]

        def mock_draft(*args, **kwargs):
            # Replay the scripted sequence, then emit EOS so generation stops.
            if draft_call_count[0] < len(drafted_sequence):
                result = drafted_sequence[draft_call_count[0]]
                draft_call_count[0] += 1
                return result
            # Return EOS to stop
            return (mock_decoder.tokenizer.eos_token_id, None, num_layers, 0.0)

        # Mock the full model verification pass.
        def mock_model_forward(input_ids, **kwargs):
            seq_len = input_ids.shape[1]
            logits = torch.zeros(1, seq_len, 100)

            # Make all drafted tokens verify correctly:
            # base_pos = prompt length - 1 = 3 - 1 = 2
            base_pos = 2
            for i, (token_id, _, _, _) in enumerate(drafted_sequence):
                logits[0, base_pos + i, token_id] = 100.0

            # Bonus token prediction at last position
            logits[0, -1, 99] = 100.0  # Predict token 99 as bonus

            return MockOutput(logits)

        mock_decoder.model.set_forward(mock_model_forward)

        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
            input_ids = torch.tensor([[1, 2, 3]])
            thresholds = {0: 0.5, 1: 0.7}

            tokens = mock_decoder._generate_with_early_exit(
                input_ids, max_tokens=10, thresholds=thresholds
            )

            # Should get 5 tokens: 4 drafted/lm_head + 1 bonus
            assert len(tokens) >= 5, (
                f"Expected at least 5 tokens (4 drafted + 1 bonus), got {len(tokens)}. "
                f"Tokens: {[(t.token_id, t.exit_head) for t in tokens]}"
            )

            # First 3 should be early exit tokens
            assert tokens[0].token_id == 10
            assert tokens[0].exit_head == 0
            assert tokens[1].token_id == 20
            assert tokens[1].exit_head == 1
            assert tokens[2].token_id == 30
            assert tokens[2].exit_head == 1

            # 4th is the lm_head token that triggered verification
            assert tokens[3].token_id == 40
            assert tokens[3].exit_head is None

            # 5th is the bonus token
            assert tokens[4].token_id == 99, (
                f"5th token should be bonus token 99, got {tokens[4].token_id}"
            )
            assert tokens[4].exit_head is None
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class TestVerificationOnMismatch:
    """Test that verification correctly handles mismatches.

    A rejected draft must be replaced by the full model's prediction, and all
    drafts after the rejection point must be discarded.
    """

    def test_rejected_draft_uses_full_model_token(self, mock_decoder):
        """
        When a draft is rejected (mismatch), we should:
        1. Use the full model's token instead
        2. Discard remaining drafted tokens
        """
        num_layers = mock_decoder.adapter.get_num_layers()

        # Scenario: 3 early exit tokens drafted, then lm_head token triggers verify
        # The second drafted token will NOT match.
        # Tuples are (token_id, exit_head, exit_layer, uncertainty).
        drafted_sequence = [
            (10, 0, 2, 0.1),  # Matches
            (20, 1, 5, 0.2),  # Will NOT match - full model says 25
            (30, 1, 5, 0.3),  # Should be discarded
            (40, None, num_layers, 0.0),  # lm_head triggers verification
        ]

        draft_call_count = [0]
        def mock_draft(*args, **kwargs):
            # Replay the scripted sequence, then emit EOS so generation stops.
            if draft_call_count[0] < len(drafted_sequence):
                result = drafted_sequence[draft_call_count[0]]
                draft_call_count[0] += 1
                return result
            # Return EOS to stop
            return (mock_decoder.tokenizer.eos_token_id, None, num_layers, 0.0)

        def mock_model_forward(input_ids, **kwargs):
            # Verification pass: accept draft 1, reject draft 2 (argmax is 25,
            # not the drafted 20); later positions are left at zero.
            seq_len = input_ids.shape[1]
            logits = torch.zeros(1, seq_len, 100)

            # base_pos = prompt_len - 1 = 3 - 1 = 2
            base_pos = 2

            # First draft matches
            logits[0, base_pos, 10] = 100.0

            # Second draft does NOT match - full model says 25
            logits[0, base_pos + 1, 25] = 100.0  # Different from drafted 20!

            return MockOutput(logits)

        mock_decoder.model.set_forward(mock_model_forward)

        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
            input_ids = torch.tensor([[1, 2, 3]])
            thresholds = {0: 0.5, 1: 0.7}

            tokens = mock_decoder._generate_with_early_exit(
                input_ids, max_tokens=10, thresholds=thresholds
            )

            # Should get exactly 2 tokens: first accepted, second corrected
            # Third drafted token should be discarded
            assert len(tokens) >= 2, f"Expected at least 2 tokens, got {len(tokens)}"

            # First token: accepted draft
            assert tokens[0].token_id == 10
            assert tokens[0].exit_head == 0

            # Second token: full model's correction
            assert tokens[1].token_id == 25, (
                f"Second token should be full model's 25, not drafted 20. Got {tokens[1].token_id}"
            )
            assert tokens[1].exit_head is None, "Corrected token should have exit_head=None"
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class TestEarlyExitSavesCompute:
    """Test that early exit actually skips layer computation.

    Uses MockAdapter's layer_calls log to prove that layers past the confident
    head's layer are never executed.
    """

    def test_early_exit_stops_at_confident_layer(self, mock_decoder):
        """When head 0 (layer 2) is confident, layers 3-7 should NOT be computed."""
        adapter = mock_decoder.adapter
        adapter.reset_tracking()

        # Make head 0 (at layer 2) very confident via peaked logits.
        with patch.object(mock_decoder.aux_heads[0], 'forward') as mock_head:
            logits = torch.zeros(1, 1, 100)
            logits[0, 0, 42] = 100.0
            mock_head.return_value = logits

            input_ids = torch.tensor([[1, 2, 3]])
            thresholds = {0: 10.0, 1: 10.0}  # High thresholds, easy to beat

            result = mock_decoder._draft_single_token(input_ids, thresholds)

            # Should have exited at layer 2
            assert result is not None
            _, exit_head, exit_layer, _ = result
            assert exit_layer == 2

            # Only layers 0, 1, 2 should have been called
            max_layer_called = max(adapter.layer_calls) if adapter.layer_calls else -1
            assert max_layer_called == 2, (
                f"Expected to stop at layer 2, but layers up to {max_layer_called} were called. "
                f"Layer calls: {adapter.layer_calls}"
            )
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class TestGenerationTermination:
    """Test that generation terminates correctly (EOS and max_tokens)."""

    def test_stops_on_eos_token_from_draft(self, mock_decoder):
        """Generation should stop when EOS token is produced during drafting."""
        # Return EOS token on first draft
        def mock_draft(input_ids, thresholds):
            return (mock_decoder.tokenizer.eos_token_id, 0, 2, 0.1)

        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
            input_ids = torch.tensor([[1, 2, 3]])
            thresholds = {0: 10.0, 1: 10.0}

            tokens = mock_decoder._generate_with_early_exit(
                input_ids, max_tokens=100, thresholds=thresholds
            )

            # Should stop immediately (0 tokens since EOS is not appended)
            assert len(tokens) == 0, f"Should stop on EOS, got {len(tokens)} tokens"

    def test_stops_at_max_tokens(self, mock_decoder):
        """Generation should stop at max_tokens limit."""
        num_layers = mock_decoder.adapter.get_num_layers()

        # Make draft return alternating early exit / lm_head tokens so that
        # verification is triggered regularly.
        draft_count = [0]

        def mock_draft(input_ids, thresholds):
            draft_count[0] += 1
            # Alternate between early exit and lm_head to trigger verification
            if draft_count[0] % 2 == 1:
                return (10 + draft_count[0], 0, 2, 0.1)  # early exit
            else:
                return (20 + draft_count[0], None, num_layers, 0.0)  # lm_head

        def mock_model_forward(input_ids, **kwargs):
            seq_len = input_ids.shape[1]
            # Return logits that match the drafted tokens so verification
            # always accepts (the token formula mirrors mock_draft above).
            logits = torch.zeros(1, seq_len, 100)
            # Match all positions to their drafted values
            for pos in range(seq_len):
                expected_token = 10 + (pos + 1) if (pos + 1) % 2 == 1 else 20 + (pos + 1)
                logits[0, pos, expected_token % 100] = 100.0
            return MockOutput(logits)

        mock_decoder.model.set_forward(mock_model_forward)

        with patch.object(mock_decoder, '_draft_single_token', side_effect=mock_draft):
            input_ids = torch.tensor([[1, 2, 3]])
            thresholds = {0: 10.0, 1: 10.0}

            tokens = mock_decoder._generate_with_early_exit(
                input_ids, max_tokens=5, thresholds=thresholds
            )

            assert len(tokens) <= 5, f"Should stop at max_tokens=5, got {len(tokens)} tokens"
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
if __name__ == "__main__":
    # Allow running this file directly without invoking pytest externally.
    pytest_args = [__file__, "-v", "--tb=short"]
    pytest.main(pytest_args)
|