"""
Speculative Decoding Module for MiniMind Max2

Uses a small, fast draft model to propose tokens that the large target
model verifies in a single forward pass, accelerating inference.
"""

import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    # Draft length per step; max_speculation_length is reserved and not
    # currently consulted by the decoders in this module.
    num_speculative_tokens: int = 5
    max_speculation_length: int = 8

    # Verification / sampling. Note: acceptance_method and top_p are also
    # not yet consulted below; only rejection sampling at `temperature`
    # is implemented.
    acceptance_method: str = "rejection"
    temperature: float = 1.0
    top_p: float = 0.95

    # Adaptive speculation: tune the draft length toward
    # target_acceptance_rate based on a moving window of recent steps.
    adaptive_speculation: bool = True
    min_speculative_tokens: int = 2
    max_speculative_tokens: int = 10
    target_acceptance_rate: float = 0.8

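
# Example (illustrative values, not a recommendation): a more conservative
# budget for a draft model that diverges from the target quickly:
#
#     config = SpeculativeConfig(
#         num_speculative_tokens=3,
#         max_speculative_tokens=6,
#         target_acceptance_rate=0.7,
#     )
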
class DraftModel:
    """
    Wrapper for the draft model in speculative decoding.
    Typically a smaller, faster model (e.g., max2-nano drafting for max2-pro).
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer=None,
        device: str = "cuda",
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def speculate(
        self,
        input_ids: torch.Tensor,
        num_tokens: int = 5,
        temperature: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate speculative tokens autoregressively.

        Args:
            input_ids: Current input sequence [batch, seq_len]
            num_tokens: Number of tokens to speculate
            temperature: Sampling temperature (must be > 0)

        Returns:
            Tuple of (speculated_tokens, speculated_probs), both [batch, num_tokens]
        """
        speculated_tokens = []
        speculated_probs = []

        current_ids = input_ids

        for _ in range(num_tokens):
            # The model is assumed to return (loss, logits, aux_loss, past).
            _, logits, _, _ = self.model(current_ids)
            next_logits = logits[:, -1, :] / temperature

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Record q(x), the draft probability of the sampled token;
            # the verifier needs it for rejection sampling.
            token_prob = probs.gather(1, next_token)

            speculated_tokens.append(next_token)
            speculated_probs.append(token_prob)

            current_ids = torch.cat([current_ids, next_token], dim=1)

        speculated_tokens = torch.cat(speculated_tokens, dim=1)
        speculated_probs = torch.cat(speculated_probs, dim=1)

        return speculated_tokens, speculated_probs

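
# Performance note: DraftModel.speculate() re-encodes the full prefix on
# every iteration, so drafting k tokens costs k full forward passes. A
# production implementation would thread a KV cache through the assumed
# (loss, logits, aux, past) interface so each draft step processes only
# one new token; that is omitted here for simplicity.
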
class SpeculativeDecoder:
    """
    Speculative decoding for accelerated generation.
    Uses a small draft model to propose tokens that the target model
    verifies in a single forward pass.
    """

    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        config: Optional[SpeculativeConfig] = None,
        device: str = "cuda",
    ):
        self.target = target_model
        self.draft = DraftModel(draft_model, device=device)
        self.config = config or SpeculativeConfig()
        self.device = device

        # Running statistics across generate() calls.
        self.total_generated = 0
        self.total_accepted = 0
        self.speculation_lengths = []

    def _rejection_sampling(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        """
        Rejection sampling for token acceptance.

        Args:
            draft_probs: Draft probabilities q(x) of each drafted token [batch, k]
            target_probs: Target probabilities p(x) of the same tokens [batch, k]
            draft_tokens: The drafted tokens themselves [batch, k]

        Returns:
            Tuple of (accepted_mask, num_accepted)
        """
        # Accept token x with probability min(1, p(x) / q(x)); the epsilon
        # guards against division by zero for near-impossible draft tokens.
        acceptance_probs = torch.min(
            torch.ones_like(draft_probs),
            target_probs / (draft_probs + 1e-10),
        )

        uniform = torch.rand_like(acceptance_probs)
        accepted = uniform < acceptance_probs

        # Only the prefix up to the first rejection is valid; cumprod zeroes
        # everything after it. Taking the batch-wise minimum keeps every
        # sequence in the batch aligned at the same accepted length.
        accepted_mask = torch.cumprod(accepted.float(), dim=1).bool()
        num_accepted = accepted_mask.sum(dim=1).min().item()

        return accepted_mask, num_accepted

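    # Why min(1, p/q) works: if the draft proposes x ~ q and we accept with
    # probability min(1, p(x)/q(x)), the accepted tokens follow the target
    # distribution p exactly, *provided* the first rejected position is
    # resampled from the residual distribution norm(max(0, p - q))
    # (Leviathan et al., 2023; Chen et al., 2023). generate_step() below
    # instead samples the bonus token directly from p, a common
    # simplification that deviates slightly from the provably lossless scheme.
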
    @torch.no_grad()
    def generate_step(
        self,
        input_ids: torch.Tensor,
        num_speculative: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Single speculative generation step.

        Args:
            input_ids: Current sequence [batch, seq_len]
            num_speculative: Number of tokens to speculate (uses config if None)

        Returns:
            New tokens and statistics
        """
        num_spec = num_speculative or self.config.num_speculative_tokens

        # 1. Draft: propose num_spec tokens with the small model.
        draft_tokens, draft_probs = self.draft.speculate(
            input_ids,
            num_tokens=num_spec,
            temperature=self.config.temperature,
        )

        # 2. Verify: score prompt + draft with the target in one forward pass.
        spec_input = torch.cat([input_ids, draft_tokens], dim=1)
        _, target_logits, _, _ = self.target(spec_input)

        # Logits at position i predict token i + 1, so the slice
        # [-num_spec-1 : -1] scores exactly the num_spec drafted tokens.
        target_probs = F.softmax(
            target_logits[:, -num_spec - 1:-1, :] / self.config.temperature, dim=-1
        )
        target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)

        # 3. Accept a prefix of the draft via rejection sampling.
        accepted_mask, num_accepted = self._rejection_sampling(
            draft_probs,
            target_probs_selected,
            draft_tokens,
        )

        if num_accepted > 0:
            new_tokens = draft_tokens[:, :num_accepted]
        else:
            new_tokens = torch.empty(input_ids.shape[0], 0, dtype=torch.long, device=self.device)

        # 4. On rejection, sample a replacement ("bonus") token from the
        # target at the first rejected position, so each step always makes
        # progress. (When every draft token is accepted, this implementation
        # skips the extra free token the target's final logits could supply.)
        if num_accepted < num_spec:
            next_logits = target_logits[:, input_ids.shape[1] + num_accepted - 1, :]
            next_probs = F.softmax(next_logits / self.config.temperature, dim=-1)
            bonus_token = torch.multinomial(next_probs, num_samples=1)
            new_tokens = torch.cat([new_tokens, bonus_token], dim=1)

        # Update running statistics.
        self.total_generated += new_tokens.shape[1]
        self.total_accepted += num_accepted
        self.speculation_lengths.append(num_spec)

        stats = {
            "num_speculated": num_spec,
            "num_accepted": num_accepted,
            "num_generated": new_tokens.shape[1],
            "acceptance_rate": num_accepted / num_spec if num_spec > 0 else 0.0,
        }

        return new_tokens, stats

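    # Worked example of the verification indexing: with prompt length L = 10
    # and k = 4 drafted tokens, the target forward sees 14 positions. Logits
    # at index i predict token i + 1, so indices 9..12 (the slice [-5:-1])
    # score the 4 drafted tokens at positions 10..13, and index
    # L + num_accepted - 1 yields the distribution for the first position
    # after the accepted prefix.
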
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Full speculative generation loop.

        Args:
            input_ids: Initial input [batch, seq_len]
            max_new_tokens: Maximum number of tokens to generate
            eos_token_id: EOS token id that stops generation

        Returns:
            Generated sequence and statistics
        """
        self.target.eval()

        generated = input_ids.clone()
        total_stats = {
            "steps": 0,
            "tokens_generated": 0,
            "acceptance_rates": [],
        }

        start_time = time.time()
        num_speculative = self.config.num_speculative_tokens

        while total_stats["tokens_generated"] < max_new_tokens:
            new_tokens, step_stats = self.generate_step(generated, num_speculative)

            if new_tokens.shape[1] == 0:
                break

            generated = torch.cat([generated, new_tokens], dim=1)

            total_stats["steps"] += 1
            total_stats["tokens_generated"] += new_tokens.shape[1]
            total_stats["acceptance_rates"].append(step_stats["acceptance_rate"])

            # Stop once any sequence emits EOS. (Output is not truncated at
            # the EOS position; callers that need that should post-process.)
            if eos_token_id is not None and (new_tokens == eos_token_id).any():
                break

            # Adapt the draft length toward the target acceptance rate,
            # using a moving window of the last 5 steps.
            if self.config.adaptive_speculation:
                recent = total_stats["acceptance_rates"][-5:]
                avg_acceptance = sum(recent) / len(recent)
                if avg_acceptance > self.config.target_acceptance_rate:
                    num_speculative = min(num_speculative + 1, self.config.max_speculative_tokens)
                elif avg_acceptance < self.config.target_acceptance_rate - 0.1:
                    num_speculative = max(num_speculative - 1, self.config.min_speculative_tokens)

        elapsed = time.time() - start_time

        total_stats["time_seconds"] = elapsed
        total_stats["tokens_per_second"] = total_stats["tokens_generated"] / max(elapsed, 1e-9)
        total_stats["avg_acceptance_rate"] = sum(total_stats["acceptance_rates"]) / max(1, len(total_stats["acceptance_rates"]))
        total_stats["avg_tokens_per_step"] = total_stats["tokens_generated"] / max(1, total_stats["steps"])

        return generated, total_stats

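    # Usage sketch (assumes both models follow the (loss, logits, aux, past)
    # forward signature used above; names are illustrative):
    #
    #     decoder = SpeculativeDecoder(target, draft, device="cuda")
    #     out, stats = decoder.generate(
    #         prompt_ids, max_new_tokens=128, eos_token_id=eos_id
    #     )
    #     print(stats["tokens_per_second"], stats["avg_acceptance_rate"])
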
    def get_statistics(self) -> Dict[str, float]:
        """Get overall statistics across all generate() calls."""
        return {
            "total_generated": self.total_generated,
            "total_accepted": self.total_accepted,
            # Note: the denominator counts all emitted tokens, including
            # bonus tokens sampled directly from the target on rejection.
            "overall_acceptance_rate": self.total_accepted / max(1, self.total_generated),
            "avg_speculation_length": sum(self.speculation_lengths) / max(1, len(self.speculation_lengths)),
        }

    def reset_statistics(self):
        """Reset statistics counters."""
        self.total_generated = 0
        self.total_accepted = 0
        self.speculation_lengths = []


class TreeSpeculativeDecoder(SpeculativeDecoder):
    """
    Tree-based speculative decoding for higher acceptance rates.
    Samples several independent speculation branches and keeps the one
    the target accepts the furthest.
    """

    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        num_branches: int = 3,
        config: Optional[SpeculativeConfig] = None,
        device: str = "cuda",
    ):
        super().__init__(target_model, draft_model, config, device)
        self.num_branches = num_branches

    @torch.no_grad()
    def generate_tree(
        self,
        input_ids: torch.Tensor,
        depth: int = 3,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Generate a tree of speculative tokens.

        Returns:
            List of (tokens, probs) tuples, one per branch
        """
        branches = []

        # Each branch is an independent sample from the draft model.
        for _ in range(self.num_branches):
            tokens, probs = self.draft.speculate(
                input_ids,
                num_tokens=depth,
                temperature=self.config.temperature,
            )
            branches.append((tokens, probs))

        return branches

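    # Because branches are sampled independently, they can coincide;
    # diversity relies on temperature > 0. As temperature approaches 0,
    # every branch collapses to the same greedy path and the tree
    # degenerates to the single-branch decoder.
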
    @torch.no_grad()
    def generate_step(
        self,
        input_ids: torch.Tensor,
        num_speculative: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Tree-based speculative step: verify each branch, keep the best."""
        num_spec = num_speculative or self.config.num_speculative_tokens

        branches = self.generate_tree(input_ids, num_spec)

        best_tokens = None
        best_accepted = 0

        # Verify every branch against the target and keep the branch with
        # the longest accepted prefix.
        for draft_tokens, draft_probs in branches:
            spec_input = torch.cat([input_ids, draft_tokens], dim=1)
            _, target_logits, _, _ = self.target(spec_input)

            target_probs = F.softmax(
                target_logits[:, -num_spec - 1:-1, :] / self.config.temperature, dim=-1
            )
            target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)

            _, num_accepted = self._rejection_sampling(
                draft_probs,
                target_probs_selected,
                draft_tokens,
            )

            if num_accepted > best_accepted:
                best_accepted = num_accepted
                best_tokens = draft_tokens[:, :num_accepted]

        # If no branch had an accepted token, fall back to one ordinary
        # sample from the target so the step still makes progress.
        if best_tokens is None or best_tokens.shape[1] == 0:
            _, logits, _, _ = self.target(input_ids)
            probs = F.softmax(logits[:, -1, :] / self.config.temperature, dim=-1)
            best_tokens = torch.multinomial(probs, num_samples=1)
            best_accepted = 0

        stats = {
            "num_speculated": num_spec * self.num_branches,
            "num_accepted": best_accepted,
            "num_generated": best_tokens.shape[1],
            "acceptance_rate": best_accepted / num_spec if num_spec > 0 else 0.0,
            "num_branches": self.num_branches,
        }

        return best_tokens, stats

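
# Note: the keep-best variant above runs one target forward per branch,
# trading extra target compute for a longer accepted prefix. Production
# tree decoders (e.g., SpecInfer, Medusa) verify all branches in a single
# forward pass using a tree-structured attention mask; that optimization
# is not implemented here.
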
def benchmark_speculative_decoding(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.Tensor,
    num_tokens: int = 100,
    device: str = "cuda",
) -> Dict[str, Any]:
    """
    Benchmark speculative decoding against standard autoregressive generation.
    Assumes `target_model` exposes a HuggingFace-style `generate()` method.
    """
    # Baseline: standard autoregressive generation with the target alone.
    target_model.eval()
    start = time.time()
    with torch.no_grad():
        standard_output = target_model.generate(
            input_ids,
            max_new_tokens=num_tokens,
        )
    standard_time = time.time() - start

    # Speculative decoding with the draft model.
    decoder = SpeculativeDecoder(target_model, draft_model, device=device)
    start = time.time()
    spec_output, spec_stats = decoder.generate(
        input_ids,
        max_new_tokens=num_tokens,
    )
    spec_time = time.time() - start

    return {
        "standard": {
            "time": standard_time,
            "tokens_per_second": num_tokens / standard_time,
        },
        "speculative": {
            "time": spec_time,
            "tokens_per_second": spec_stats["tokens_per_second"],
            "acceptance_rate": spec_stats["avg_acceptance_rate"],
            "avg_tokens_per_step": spec_stats["avg_tokens_per_step"],
        },
        "speedup": standard_time / spec_time,
    }

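
if __name__ == "__main__":
    # Minimal smoke test with randomly initialized toy models. This is a
    # sketch to exercise the control flow only: _ToyLM is a hypothetical
    # stand-in that mimics the (loss, logits, aux, past) forward signature
    # assumed throughout this module; real use pairs a small MiniMind draft
    # checkpoint with a large target checkpoint.
    class _ToyLM(nn.Module):
        def __init__(self, vocab_size: int = 128, dim: int = 32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, input_ids):
            logits = self.head(self.embed(input_ids))
            return None, logits, None, None

    torch.manual_seed(0)
    target, draft = _ToyLM(), _ToyLM()
    decoder = SpeculativeDecoder(target, draft, device="cpu")
    prompt = torch.randint(0, 128, (1, 8))
    out, stats = decoder.generate(prompt, max_new_tokens=32)
    print(
        f"generated {stats['tokens_generated']} tokens in {stats['steps']} steps, "
        f"avg acceptance {stats['avg_acceptance_rate']:.2f}"
    )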