# True Early Exit Inference with Dynamic Self-Speculative Decoding
# Provides actual speedup by stopping layer computation early
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from .model_adapters import get_adapter, ModelAdapter
from .model_config import ModelConfig, CalibrationResult
def compute_entropy(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""Compute entropy - lower = more confident."""
probs = F.softmax(logits, dim=dim)
log_probs = F.log_softmax(logits, dim=dim)
return -torch.sum(probs * log_probs, dim=dim)
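# Reference: for probabilities p over V tokens, H(p) = -sum_i p_i * ln(p_i), computed
# here in nats. A uniform distribution gives the maximum ln(V) (roughly 10.4 for an
# illustrative 32k vocabulary), while a near-one-hot distribution gives a value near 0,
# so thresholding this entropy acts as a confidence test for early exit.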
class AuxiliaryHead(nn.Module):
"""Auxiliary head for early exit prediction."""
def __init__(
self, hidden_size: int, vocab_size: int, norm_layer: Optional[nn.Module] = None
):
super().__init__()
self.norm = norm_layer if norm_layer is not None else nn.Identity()
self.linear = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm(hidden_states))
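# Shape sketch (illustrative sizes): hidden_states [batch, seq, hidden_size] ->
# logits [batch, seq, vocab_size], e.g. [1, 1, 2048] -> [1, 1, 32000] for a
# hypothetical hidden_size=2048, vocab_size=32000 model.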
@dataclass
class TokenInfo:
"""Information about a generated token for visualization."""
token_id: int
token_text: str
exit_head: Optional[int] # None = full model
exit_layer: int
uncertainty: float
@dataclass
class StreamingResult:
"""Result from streaming generation with accumulated metrics."""
tokens: List[TokenInfo]
total_time: float
tokens_per_second: float
avg_exit_layer: float
exit_distribution: Dict[str, int]
@classmethod
def from_tokens(cls, tokens: List[TokenInfo], total_time: float, num_layers: int) -> "StreamingResult":
"""Build a StreamingResult from a list of tokens and timing info."""
exit_dist: Dict[str, int] = {}
layer_sum = 0
for t in tokens:
key = str(t.exit_head) if t.exit_head is not None else "full"
exit_dist[key] = exit_dist.get(key, 0) + 1
layer_sum += t.exit_layer
avg_layer = layer_sum / len(tokens) if tokens else num_layers
return cls(
tokens=tokens,
total_time=total_time,
tokens_per_second=len(tokens) / total_time if total_time > 0 else 0,
avg_exit_layer=avg_layer,
exit_distribution=exit_dist,
)
@dataclass
class StreamEvent:
"""Event for streaming generation updates."""
event_type: str # "draft", "verify_start", "accept", "reject", "full_model", "complete"
tokens: List[TokenInfo] # All tokens so far (validated)
drafted_tokens: List[TokenInfo] # Currently drafted (pending verification)
message: str # Human-readable status
result: Optional[StreamingResult] = None # Set on final "complete" event
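# Typical event order for one draft/verify round (illustrative):
#   "draft" (one per drafted token) -> "verify_start" -> "accept"/"reject" per draft
#   -> optional bonus "accept" -> ... -> final "complete" with `result` populated.
# ("full_model" events are only emitted by generate_full_model_streaming.)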
@dataclass
class GenerationResult:
"""Complete generation result with token-level information."""
text: str
tokens: List[TokenInfo]
total_time: float
tokens_per_second: float
avg_exit_layer: float
exit_distribution: Dict[str, int]
class DSSDecoder:
"""
Dynamic Self-Speculative Decoder with true early exit.
Layer computation genuinely stops at an intermediate layer when an auxiliary head is confident, which is where the speedup comes from.
"""
def __init__(
self,
model: AutoModelForCausalLM,
adapter: ModelAdapter,
aux_heads: nn.ModuleList,
tokenizer: AutoTokenizer,
model_config: ModelConfig,
calibration: Optional[CalibrationResult] = None,
device: str = "cuda",
):
self.model = model
self.adapter = adapter
self.aux_heads = aux_heads
self.tokenizer = tokenizer
self.model_config = model_config
self.calibration = calibration
self.device = device
self.uncertainty_fn = compute_entropy
def _format_and_encode_prompt(self, prompt: str, use_chat_template: bool) -> torch.Tensor:
"""Format prompt with optional chat template and return input_ids tensor."""
if (
use_chat_template
and hasattr(self.tokenizer, "chat_template")
and self.tokenizer.chat_template is not None
):
try:
messages = [{"role": "user", "content": prompt}]
formatted = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return self.tokenizer.encode(formatted, return_tensors="pt").to(
self.device
)
except Exception:
pass # Fall through to raw prompt encoding
return self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
def generate(
self,
prompt: str,
max_tokens: int = 100,
use_early_exit: bool = True,
accuracy_level: float = 0.75,
use_chat_template: bool = True,
) -> GenerationResult:
"""
Generate text with optional early exit.
Returns detailed token-level information for visualization.
"""
input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
# Get thresholds
thresholds = {}
if use_early_exit and self.calibration:
thresholds = self.calibration.get_thresholds_for_level(accuracy_level)
# Generate
start_time = time.time()
if use_early_exit:
tokens = self._generate_with_early_exit(input_ids, max_tokens, thresholds)
else:
tokens = self._generate_full_model(input_ids, max_tokens)
end_time = time.time()
total_time = end_time - start_time
# Build result
text = "".join(t.token_text for t in tokens)
exit_dist = defaultdict(int)
layer_sum = 0
for t in tokens:
key = str(t.exit_head) if t.exit_head is not None else "full"
exit_dist[key] += 1
layer_sum += t.exit_layer
avg_layer = (
layer_sum / len(tokens) if tokens else self.model_config.num_hidden_layers
)
return GenerationResult(
text=text,
tokens=tokens,
total_time=total_time,
tokens_per_second=len(tokens) / total_time if total_time > 0 else 0,
avg_exit_layer=avg_layer,
exit_distribution=dict(exit_dist),
)
def generate_streaming(
self,
prompt: str,
max_tokens: int = 100,
accuracy_level: float = 0.75,
use_chat_template: bool = True,
max_draft_length: int = 5,
):
"""
Generate with streaming - yields events showing draft/verify process.
Each event shows current validated tokens and pending drafted tokens.
Yields a final "complete" event with StreamingResult containing metrics.
"""
input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
# Get thresholds
thresholds = {}
if self.calibration:
thresholds = self.calibration.get_thresholds_for_level(accuracy_level)
validated_tokens = []
current_ids = input_ids.clone()
num_layers = self.adapter.get_num_layers()
start_time = time.time()
while len(validated_tokens) < max_tokens:
# ============================================================
# DRAFT PHASE: Generate tokens using early exit or lm_head
# ============================================================
drafted_tokens = []
draft_ids = current_ids.clone()
should_stop = False
for _ in range(max_draft_length):
if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
break
# Generate a token (always returns a result)
token_id, exit_head, exit_layer, uncertainty = self._draft_single_token(
draft_ids, thresholds
)
if token_id == self.tokenizer.eos_token_id:
# EOS handling
if exit_head is not None and drafted_tokens:
break # Verify pending drafts first
should_stop = True
break # Stop generation
token_text = self.tokenizer.decode([token_id])
drafted_token = TokenInfo(
token_id=token_id,
token_text=token_text,
exit_head=exit_head,
exit_layer=exit_layer,
uncertainty=uncertainty,
)
drafted_tokens.append(drafted_token)
draft_ids = torch.cat(
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
)
if exit_head is None:
# Token from lm_head - triggers verification
yield StreamEvent(
event_type="draft",
tokens=list(validated_tokens),
drafted_tokens=list(drafted_tokens),
message=f"Drafting token {len(drafted_tokens)} using Full Model",
)
break
else:
# Token from early exit head
yield StreamEvent(
event_type="draft",
tokens=list(validated_tokens),
drafted_tokens=list(drafted_tokens),
message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
)
# Stop if EOS came from the full model, or from a head with no pending drafts
if should_stop:
break
# ============================================================
# VERIFY PHASE
# ============================================================
if not drafted_tokens:
break
yield StreamEvent(
event_type="verify_start",
tokens=list(validated_tokens),
drafted_tokens=list(drafted_tokens),
message=f"Verifying {len(drafted_tokens)} drafted tokens...",
)
with torch.no_grad():
outputs = self.model(draft_ids, use_cache=False)
verify_logits = outputs.logits
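# In a causal LM, the logits at position p predict the token at position p + 1, so
# the logits at the last already-validated position score the first drafted token.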
start_pos = current_ids.shape[1] - 1
all_accepted = True
for i, drafted_token in enumerate(drafted_tokens):
verify_pos = start_pos + i
verified_token_id = torch.argmax(
verify_logits[0, verify_pos, :]
).item()
if drafted_token.token_id == verified_token_id:
# Accept
validated_tokens.append(drafted_token)
current_ids = torch.cat(
[
current_ids,
torch.tensor(
[[drafted_token.token_id]], device=self.device
),
],
dim=1,
)
yield StreamEvent(
event_type="accept",
tokens=list(validated_tokens),
drafted_tokens=[],
message=f"✓ Accepted '{drafted_token.token_text}'",
)
else:
# Reject - use full model's token
all_accepted = False
token_text = self.tokenizer.decode([verified_token_id])
corrected_token = TokenInfo(
token_id=verified_token_id,
token_text=token_text,
exit_head=None,
exit_layer=num_layers,
uncertainty=0.0,
)
validated_tokens.append(corrected_token)
current_ids = torch.cat(
[
current_ids,
torch.tensor([[verified_token_id]], device=self.device),
],
dim=1,
)
yield StreamEvent(
event_type="reject",
tokens=list(validated_tokens),
drafted_tokens=[],
message=f"✗ Rejected '{drafted_token.token_text}' → '{token_text}'",
)
break
# BONUS TOKEN: If all tokens were accepted, get bonus from last position
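# The bonus token is "free": the verification pass already produced logits at the
# position right after the last accepted draft, so no extra forward pass is needed.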
if all_accepted and len(validated_tokens) < max_tokens:
bonus_pos = start_pos + len(drafted_tokens)
if bonus_pos < verify_logits.shape[1]:
bonus_token_id = torch.argmax(
verify_logits[0, bonus_pos, :]
).item()
if bonus_token_id != self.tokenizer.eos_token_id:
bonus_text = self.tokenizer.decode([bonus_token_id])
bonus_token = TokenInfo(
token_id=bonus_token_id,
token_text=bonus_text,
exit_head=None,
exit_layer=num_layers,
uncertainty=0.0,
)
validated_tokens.append(bonus_token)
current_ids = torch.cat(
[
current_ids,
torch.tensor([[bonus_token_id]], device=self.device),
],
dim=1,
)
yield StreamEvent(
event_type="accept",
tokens=list(validated_tokens),
drafted_tokens=[],
message=f"✓ Bonus token '{bonus_text}'",
)
if (
validated_tokens
and validated_tokens[-1].token_id == self.tokenizer.eos_token_id
):
break
# Yield final "complete" event with metrics
total_time = time.time() - start_time
result = StreamingResult.from_tokens(validated_tokens, total_time, num_layers)
yield StreamEvent(
event_type="complete",
tokens=list(validated_tokens),
drafted_tokens=[],
message="Generation complete",
result=result,
)
def _generate_with_early_exit(
self,
input_ids: torch.Tensor,
max_tokens: int,
thresholds: Dict[int, float],
max_draft_length: int = 5,
) -> List[TokenInfo]:
"""
Speculative decoding with early exit heads.
The flow:
1. Generate tokens using _draft_single_token (which may early exit or use lm_head)
2. Tokens from early exit heads are "drafts" that need verification
3. When we get a token from lm_head (exit_head=None), it triggers verification
of all pending drafts, and the lm_head token is accepted as verified
4. All accepted tokens are guaranteed to match full model output
"""
tokens = []
current_ids = input_ids.clone()
num_layers = self.adapter.get_num_layers()
while len(tokens) < max_tokens:
# ============================================================
# DRAFT PHASE: Generate tokens, collecting early exit drafts
# ============================================================
drafted_tokens = [] # List of (token_id, exit_head, exit_layer, uncertainty)
draft_ids = current_ids.clone()
for _ in range(max_draft_length):
if len(tokens) + len(drafted_tokens) >= max_tokens:
break
# Generate a token (always returns a result, never None)
token_id, exit_head, exit_layer, uncertainty = self._draft_single_token(
draft_ids, thresholds
)
if token_id == self.tokenizer.eos_token_id:
# If EOS from early exit, we still need to verify pending drafts
if exit_head is not None and drafted_tokens:
# Don't add EOS to drafts, just break to verify
break
# If EOS from lm_head or no pending drafts, we're done
return tokens
if exit_head is None:
# Token from lm_head - this is verified, triggers verification of drafts
# Add to drafts for unified handling, but mark as already verified
drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
draft_ids = torch.cat(
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
)
break # Stop drafting, go to verification
else:
# Token from early exit head - add to drafts for later verification
drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
draft_ids = torch.cat(
[draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
)
# ============================================================
# VERIFY PHASE: Verify drafted tokens with full model
# ============================================================
if not drafted_tokens:
# No tokens were drafted (should not happen: _draft_single_token always returns a token)
break
# If the last token came from lm_head, it is already a full-model prediction,
# but the earlier early-exit drafts still have to be checked against the full model.
last_token = drafted_tokens[-1]
_, last_exit_head, _, _ = last_token
if last_exit_head is None:
# Last token is from lm_head - the earlier tokens still need verification.
# _draft_single_token does not return per-position logits, so run the full
# model once over draft_ids to get verification logits at every drafted position.
with torch.no_grad():
outputs = self.model(draft_ids, use_cache=False)
verify_logits = outputs.logits
start_pos = current_ids.shape[1] - 1
all_accepted = True  # becomes False as soon as a drafted token is rejected
for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
drafted_tokens
):
verify_pos = start_pos + i
verified_token = torch.argmax(
verify_logits[0, verify_pos, :]
).item()
if drafted_token == verified_token:
# Token matches - accept it
token_text = self.tokenizer.decode([drafted_token])
tokens.append(
TokenInfo(
token_id=drafted_token,
token_text=token_text,
exit_head=exit_head,
exit_layer=exit_layer,
uncertainty=uncertainty,
)
)
current_ids = torch.cat(
[
current_ids,
torch.tensor([[drafted_token]], device=self.device),
],
dim=1,
)
else:
# Mismatch - use full model's token and stop accepting the remaining drafts
all_accepted = False
token_text = self.tokenizer.decode([verified_token])
tokens.append(
TokenInfo(
token_id=verified_token,
token_text=token_text,
exit_head=None, # Full model
exit_layer=num_layers,
uncertainty=0.0,
)
)
current_ids = torch.cat(
[
current_ids,
torch.tensor([[verified_token]], device=self.device),
],
dim=1,
)
# Stop - discard remaining drafted tokens
break
# BONUS TOKEN: If all drafted tokens were accepted, use the last position
# to get an additional token (this is the "free" token from lm_head)
if all_accepted:
# All drafts were accepted, check for bonus token
bonus_pos = start_pos + len(drafted_tokens)
if bonus_pos < verify_logits.shape[1]:
bonus_token_id = torch.argmax(
verify_logits[0, bonus_pos, :]
).item()
if (
bonus_token_id != self.tokenizer.eos_token_id
and len(tokens) < max_tokens
):
bonus_text = self.tokenizer.decode([bonus_token_id])
tokens.append(
TokenInfo(
token_id=bonus_token_id,
token_text=bonus_text,
exit_head=None, # Full model
exit_layer=num_layers,
uncertainty=0.0,
)
)
current_ids = torch.cat(
[
current_ids,
torch.tensor(
[[bonus_token_id]], device=self.device
),
],
dim=1,
)
else:
# All tokens are from early exit heads - need to run full model for verification
with torch.no_grad():
outputs = self.model(draft_ids, use_cache=False)
verify_logits = outputs.logits
start_pos = current_ids.shape[1] - 1
all_accepted = True  # becomes False as soon as a drafted token is rejected
for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
drafted_tokens
):
verify_pos = start_pos + i
verified_token = torch.argmax(
verify_logits[0, verify_pos, :]
).item()
if drafted_token == verified_token:
# Token matches - accept it with early exit info
token_text = self.tokenizer.decode([drafted_token])
tokens.append(
TokenInfo(
token_id=drafted_token,
token_text=token_text,
exit_head=exit_head,
exit_layer=exit_layer,
uncertainty=uncertainty,
)
)
current_ids = torch.cat(
[
current_ids,
torch.tensor([[drafted_token]], device=self.device),
],
dim=1,
)
else:
# Mismatch - use full model's token and stop accepting the remaining drafts
all_accepted = False
token_text = self.tokenizer.decode([verified_token])
tokens.append(
TokenInfo(
token_id=verified_token,
token_text=token_text,
exit_head=None, # Full model
exit_layer=num_layers,
uncertainty=0.0,
)
)
current_ids = torch.cat(
[
current_ids,
torch.tensor([[verified_token]], device=self.device),
],
dim=1,
)
# Stop - discard remaining drafted tokens
break
# BONUS TOKEN from verification pass
if all_accepted:
bonus_pos = start_pos + len(drafted_tokens)
if bonus_pos < verify_logits.shape[1]:
bonus_token_id = torch.argmax(
verify_logits[0, bonus_pos, :]
).item()
if (
bonus_token_id != self.tokenizer.eos_token_id
and len(tokens) < max_tokens
):
bonus_text = self.tokenizer.decode([bonus_token_id])
tokens.append(
TokenInfo(
token_id=bonus_token_id,
token_text=bonus_text,
exit_head=None, # Full model
exit_layer=num_layers,
uncertainty=0.0,
)
)
current_ids = torch.cat(
[
current_ids,
torch.tensor(
[[bonus_token_id]], device=self.device
),
],
dim=1,
)
# Check for EOS in accepted tokens
if tokens and tokens[-1].token_id == self.tokenizer.eos_token_id:
break
return tokens
def _draft_single_token(
self,
input_ids: torch.Tensor,
thresholds: Dict[int, float],
) -> Tuple[int, Optional[int], int, float]:
"""
Generate a single token using early exit or full model.
Returns (token_id, exit_head, exit_layer, uncertainty):
- If an early exit head is confident: returns token with that head's info
- If no head is confident: continues to lm_head and returns token from there
This function ALWAYS returns a token (never returns None).
"""
device = input_ids.device
seq_len = input_ids.shape[1]
head_layers = self.model_config.head_layer_indices
num_layers = self.adapter.get_num_layers()
# Position IDs
position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(
0
)
# Cache position (required by newer transformers for Qwen3)
cache_position = torch.arange(seq_len, dtype=torch.long, device=device)
# Get embeddings
hidden_states = self.adapter.get_embed_tokens(input_ids)
# Get rotary embeddings
position_embeddings = self.adapter.get_position_embeddings(
hidden_states, position_ids
)
# Sort heads by layer
sorted_heads = sorted(enumerate(head_layers), key=lambda x: x[1])
# Iterate through layers
with torch.no_grad():
for layer_idx, layer in enumerate(self.adapter.get_layers()):
hidden_states, _ = self.adapter.forward_layer(
layer=layer,
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=None,
past_key_values=None,
position_embeddings=position_embeddings,
use_cache=False,
cache_position=cache_position,
)
# Check if this is a head checkpoint
for head_idx, head_layer in sorted_heads:
if layer_idx == head_layer:
# Run aux head on last position
aux_head = self.aux_heads[head_idx]
head_device = next(aux_head.parameters()).device
head_input = hidden_states[:, -1:, :].to(head_device)
head_logits = aux_head(head_input)
uncertainty = self.uncertainty_fn(
head_logits[:, -1, :], dim=-1
).item()
# Check threshold - if confident, return drafted token
if (
head_idx in thresholds
and uncertainty < thresholds[head_idx]
):
token_id = torch.argmax(head_logits[0, -1, :]).item()
return (token_id, head_idx, layer_idx, uncertainty)
# No head was confident - use lm_head to get the token
# Apply final norm and lm_head
final_hidden = self.adapter.apply_final_norm(hidden_states)
logits = self.adapter.get_lm_head_output(final_hidden)
# Get token from last position
token_id = torch.argmax(logits[0, -1, :]).item()
# Compute uncertainty for the lm_head output
uncertainty = self.uncertainty_fn(logits[0, -1, :].unsqueeze(0), dim=-1).item()
return (token_id, None, num_layers, uncertainty)
def _generate_full_model(
self,
input_ids: torch.Tensor,
max_tokens: int,
) -> List[TokenInfo]:
"""Generate using full model (no early exit)."""
tokens = []
current_ids = input_ids.clone()
num_layers = self.adapter.get_num_layers()
for _ in range(max_tokens):
with torch.no_grad():
outputs = self.model(current_ids, use_cache=False)
logits = outputs.logits
token_id = torch.argmax(logits[0, -1, :]).item()
if token_id == self.tokenizer.eos_token_id:
break
token_text = self.tokenizer.decode([token_id])
tokens.append(
TokenInfo(
token_id=token_id,
token_text=token_text,
exit_head=None,
exit_layer=num_layers,
uncertainty=0.0,
)
)
current_ids = torch.cat(
[current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
)
return tokens
def generate_full_model_streaming(
self,
prompt: str,
max_tokens: int = 100,
use_chat_template: bool = True,
):
"""
Generate with full model in streaming mode - yields each token as generated.
Yields a final "complete" event with StreamingResult containing metrics.
"""
input_ids = self._format_and_encode_prompt(prompt, use_chat_template)
tokens = []
current_ids = input_ids.clone()
num_layers = self.adapter.get_num_layers()
start_time = time.time()
for i in range(max_tokens):
with torch.no_grad():
outputs = self.model(current_ids, use_cache=False)
logits = outputs.logits
token_id = torch.argmax(logits[0, -1, :]).item()
if token_id == self.tokenizer.eos_token_id:
break
token_text = self.tokenizer.decode([token_id])
token_info = TokenInfo(
token_id=token_id,
token_text=token_text,
exit_head=None,
exit_layer=num_layers,
uncertainty=0.0,
)
tokens.append(token_info)
current_ids = torch.cat(
[current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
)
yield StreamEvent(
event_type="full_model",
tokens=list(tokens),
drafted_tokens=[],
message=f"Token {i + 1}: '{token_text}'",
)
# Yield final "complete" event with metrics
total_time = time.time() - start_time
result = StreamingResult.from_tokens(tokens, total_time, num_layers)
yield StreamEvent(
event_type="complete",
tokens=list(tokens),
drafted_tokens=[],
message="Generation complete",
result=result,
)
def load_dssd_model(
model_name: str,
heads_path: str,
config_path: str,
calibration_path: Optional[str] = None,
device: str = "auto",
) -> Tuple[DSSDecoder, AutoTokenizer]:
"""
Load a DSSD model from HuggingFace Hub or local paths.
Args:
model_name: HuggingFace model name (e.g., "meta-llama/Meta-Llama-3-8B")
heads_path: Path to aux_heads.pt
config_path: Path to config.json
calibration_path: Optional path to calibration.json
device: Device to load on
Returns:
decoder: DSSDecoder ready for generation
tokenizer: Tokenizer for the model
"""
# Load config
model_config = ModelConfig.from_json(config_path)
# Load calibration if provided
calibration = None
if calibration_path:
calibration = CalibrationResult.from_json(calibration_path)
# Quantization config
quant_config = None
if model_config.quantization == "4bit":
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16
if torch.cuda.is_bf16_supported()
else torch.float32,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
elif model_config.quantization == "8bit":
quant_config = BitsAndBytesConfig(load_in_8bit=True)
# Load base model
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quant_config,
torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
device_map=device,
)
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Get adapter
adapter = get_adapter(model)
# Determine the norm type and create aux heads WITHOUT deepcopy (to avoid accelerate hooks)
aux_heads = nn.ModuleList()
# Get norm config from model
norm_eps = 1e-6
if hasattr(model.config, "rms_norm_eps"):
norm_eps = model.config.rms_norm_eps
elif hasattr(model.config, "layer_norm_eps"):
norm_eps = model.config.layer_norm_eps
for _ in range(model_config.num_heads):
# Create a fresh RMSNorm without accelerate hooks (note: nn.RMSNorm requires PyTorch >= 2.4)
norm_layer = nn.RMSNorm(model_config.hidden_size, eps=norm_eps)
head = AuxiliaryHead(
model_config.hidden_size,
model_config.vocab_size,
norm_layer,
)
aux_heads.append(head)
# Load trained weights (this will properly set the norm weights)
state_dict = torch.load(heads_path, map_location="cpu")
aux_heads.load_state_dict(state_dict)
# Move to device - use cuda:0 to keep on single device
model_device = (
torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
model_dtype = next(model.parameters()).dtype
aux_heads = aux_heads.to(device=model_device, dtype=model_dtype)
aux_heads.eval()
# Create decoder
decoder = DSSDecoder(
model=model,
adapter=adapter,
aux_heads=aux_heads,
tokenizer=tokenizer,
model_config=model_config,
calibration=calibration,
device=str(model_device),
)
return decoder, tokenizer
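# Minimal usage sketch (illustrative; the model id and file paths below are
# placeholders, and this must be run as a module, e.g. with `python -m ...`,
# so that the relative imports above resolve):
if __name__ == "__main__":
    decoder, tok = load_dssd_model(
        model_name="meta-llama/Meta-Llama-3-8B",  # placeholder HF model id
        heads_path="aux_heads.pt",                # placeholder: trained auxiliary heads
        config_path="config.json",                # placeholder: ModelConfig json
        calibration_path="calibration.json",      # optional placeholder
    )
    out = decoder.generate("Explain early exit in one sentence.", max_tokens=32)
    print(out.text)
    print(f"{out.tokens_per_second:.1f} tok/s, avg exit layer {out.avg_exit_layer:.1f}")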