"""
Speculative Decoding Module for MiniMind Max2
Use small draft model to accelerate large model inference.
"""
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
# Speculation settings
num_speculative_tokens: int = 5 # Number of tokens to speculate
max_speculation_length: int = 8
# Acceptance settings
    acceptance_method: str = "rejection"  # "rejection" or "nucleus"; only rejection sampling is implemented below
temperature: float = 1.0
top_p: float = 0.95
# Performance tuning
adaptive_speculation: bool = True # Adjust speculation based on acceptance rate
min_speculative_tokens: int = 2
max_speculative_tokens: int = 10
target_acceptance_rate: float = 0.8
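
# Illustrative configuration (values are examples, not tuned recommendations):
#
#     config = SpeculativeConfig(
#         num_speculative_tokens=4,
#         temperature=0.8,
#         adaptive_speculation=True,
#         target_acceptance_rate=0.85,
#     )
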
class DraftModel:
"""
Wrapper for draft model in speculative decoding.
Typically a smaller, faster model (e.g., max2-nano for max2-pro).
"""
def __init__(
self,
model: nn.Module,
tokenizer = None,
device: str = "cuda",
):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.model.eval()
@torch.no_grad()
def speculate(
self,
input_ids: torch.Tensor,
num_tokens: int = 5,
temperature: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate speculative tokens.
Args:
input_ids: Current input sequence [batch, seq_len]
num_tokens: Number of tokens to speculate
temperature: Sampling temperature
Returns:
Tuple of (speculated_tokens, speculated_probs)
"""
speculated_tokens = []
speculated_probs = []
current_ids = input_ids
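        # Each iteration below re-encodes the full prefix (no KV cache is
        # threaded through), so drafting costs num_tokens sequential forward
        # passes over a growing sequence.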
for _ in range(num_tokens):
# Forward pass
_, logits, _, _ = self.model(current_ids)
next_logits = logits[:, -1, :] / temperature
# Sample
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Get probability of selected token
token_prob = probs.gather(1, next_token)
speculated_tokens.append(next_token)
speculated_probs.append(token_prob)
# Append to sequence
current_ids = torch.cat([current_ids, next_token], dim=1)
# Stack results
speculated_tokens = torch.cat(speculated_tokens, dim=1) # [batch, num_tokens]
speculated_probs = torch.cat(speculated_probs, dim=1) # [batch, num_tokens]
return speculated_tokens, speculated_probs
class SpeculativeDecoder:
"""
Speculative decoding for accelerated generation.
Uses a small draft model to propose tokens, verified by target model.
"""
def __init__(
self,
target_model: nn.Module,
draft_model: nn.Module,
config: Optional[SpeculativeConfig] = None,
device: str = "cuda",
):
self.target = target_model
self.draft = DraftModel(draft_model, device=device)
self.config = config or SpeculativeConfig()
self.device = device
# Statistics
self.total_generated = 0
self.total_accepted = 0
self.speculation_lengths = []
def _rejection_sampling(
self,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
draft_tokens: torch.Tensor,
) -> Tuple[torch.Tensor, int]:
"""
Rejection sampling for token acceptance.
Returns:
Tuple of (accepted_mask, num_accepted)
"""
batch_size, num_tokens = draft_tokens.shape
# Compute acceptance probability: min(1, target_p / draft_p)
acceptance_probs = torch.min(
torch.ones_like(draft_probs),
target_probs / (draft_probs + 1e-10),
)
# Sample uniform for rejection test
uniform = torch.rand_like(acceptance_probs)
accepted = uniform < acceptance_probs
        # Accept only the prefix before the first rejection in each row; the
        # min across the batch keeps all rows in sync (conservative for batch > 1)
accepted_mask = torch.cumprod(accepted.float(), dim=1).bool()
num_accepted = accepted_mask.sum(dim=1).min().item()
return accepted_mask, num_accepted
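
    # Worked example of the acceptance rule min(1, p/q): if the draft assigned
    # probability q = 0.5 to the token it sampled and the target assigns
    # p = 0.4 to that token, it is kept with probability 0.8; whenever p >= q
    # it is always kept. This is the standard acceptance test from
    # speculative sampling.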
@torch.no_grad()
def generate_step(
self,
input_ids: torch.Tensor,
num_speculative: Optional[int] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""
Single speculative generation step.
Args:
input_ids: Current sequence [batch, seq_len]
num_speculative: Number of tokens to speculate (uses config if None)
Returns:
New tokens and statistics
"""
        num_spec = num_speculative if num_speculative is not None else self.config.num_speculative_tokens
# Phase 1: Draft model speculation
draft_tokens, draft_probs = self.draft.speculate(
input_ids,
num_tokens=num_spec,
temperature=self.config.temperature,
)
# Phase 2: Target model verification (single forward pass)
spec_input = torch.cat([input_ids, draft_tokens], dim=1)
_, target_logits, _, _ = self.target(spec_input)
# Get target probabilities for draft tokens
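        # Alignment: logits at index i predict token i + 1, so the slice
        # [-num_spec-1:-1] gives the target's distributions for exactly the
        # num_spec drafted positions.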
target_probs = F.softmax(target_logits[:, -num_spec-1:-1, :] / self.config.temperature, dim=-1)
target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)
# Phase 3: Rejection sampling
accepted_mask, num_accepted = self._rejection_sampling(
draft_probs,
target_probs_selected,
draft_tokens,
)
# Accept verified tokens
if num_accepted > 0:
new_tokens = draft_tokens[:, :num_accepted]
else:
new_tokens = torch.empty(input_ids.shape[0], 0, dtype=torch.long, device=self.device)
# Sample one more token from target if not all accepted
if num_accepted < num_spec:
# Resample from target distribution at rejection point
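            # Note: exact speculative sampling would draw this bonus token from
            # the normalized residual distribution max(0, p_target - p_draft);
            # sampling from the full target distribution here is a
            # simplification, so outputs can deviate slightly from pure
            # target-model sampling.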
next_logits = target_logits[:, input_ids.shape[1] + num_accepted - 1, :]
next_probs = F.softmax(next_logits / self.config.temperature, dim=-1)
bonus_token = torch.multinomial(next_probs, num_samples=1)
new_tokens = torch.cat([new_tokens, bonus_token], dim=1)
# Statistics
self.total_generated += new_tokens.shape[1]
self.total_accepted += num_accepted
self.speculation_lengths.append(num_spec)
stats = {
"num_speculated": num_spec,
"num_accepted": num_accepted,
"num_generated": new_tokens.shape[1],
"acceptance_rate": num_accepted / num_spec if num_spec > 0 else 0,
}
return new_tokens, stats
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
eos_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""
Full speculative generation.
Args:
input_ids: Initial input [batch, seq_len]
max_new_tokens: Maximum tokens to generate
eos_token_id: EOS token to stop generation
Returns:
Generated sequence and statistics
"""
self.target.eval()
generated = input_ids.clone()
total_stats = {
"steps": 0,
"tokens_generated": 0,
"acceptance_rates": [],
}
start_time = time.time()
num_speculative = self.config.num_speculative_tokens
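        # Each step can emit up to num_speculative + 1 tokens, so the final
        # sequence may overshoot max_new_tokens by a few tokens.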
while total_stats["tokens_generated"] < max_new_tokens:
# Speculative step
new_tokens, step_stats = self.generate_step(generated, num_speculative)
if new_tokens.shape[1] == 0:
break
generated = torch.cat([generated, new_tokens], dim=1)
# Update stats
total_stats["steps"] += 1
total_stats["tokens_generated"] += new_tokens.shape[1]
total_stats["acceptance_rates"].append(step_stats["acceptance_rate"])
            # Stop on EOS (tokens emitted after EOS in the same step remain in the output)
if eos_token_id is not None and (new_tokens == eos_token_id).any():
break
# Adaptive speculation
if self.config.adaptive_speculation:
avg_acceptance = sum(total_stats["acceptance_rates"][-5:]) / min(5, len(total_stats["acceptance_rates"]))
if avg_acceptance > self.config.target_acceptance_rate:
num_speculative = min(num_speculative + 1, self.config.max_speculative_tokens)
elif avg_acceptance < self.config.target_acceptance_rate - 0.1:
num_speculative = max(num_speculative - 1, self.config.min_speculative_tokens)
end_time = time.time()
total_stats["time_seconds"] = end_time - start_time
total_stats["tokens_per_second"] = total_stats["tokens_generated"] / total_stats["time_seconds"]
total_stats["avg_acceptance_rate"] = sum(total_stats["acceptance_rates"]) / max(1, len(total_stats["acceptance_rates"]))
total_stats["avg_tokens_per_step"] = total_stats["tokens_generated"] / max(1, total_stats["steps"])
return generated, total_stats
def get_statistics(self) -> Dict[str, float]:
"""Get overall statistics."""
return {
"total_generated": self.total_generated,
"total_accepted": self.total_accepted,
"overall_acceptance_rate": self.total_accepted / max(1, self.total_generated),
"avg_speculation_length": sum(self.speculation_lengths) / max(1, len(self.speculation_lengths)),
}
def reset_statistics(self):
"""Reset statistics counters."""
self.total_generated = 0
self.total_accepted = 0
self.speculation_lengths = []
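
# Minimal usage sketch (model names are hypothetical; both models must share a
# vocabulary and follow the 4-tuple forward interface described above):
#
#     decoder = SpeculativeDecoder(target_model=pro_model, draft_model=nano_model)
#     output_ids, stats = decoder.generate(input_ids, max_new_tokens=128)
#     print(stats["avg_acceptance_rate"], stats["tokens_per_second"])
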
class TreeSpeculativeDecoder(SpeculativeDecoder):
"""
Tree-based speculative decoding for higher acceptance rates.
Generates multiple speculation branches.
"""
def __init__(
self,
target_model: nn.Module,
draft_model: nn.Module,
num_branches: int = 3,
config: Optional[SpeculativeConfig] = None,
device: str = "cuda",
):
super().__init__(target_model, draft_model, config, device)
self.num_branches = num_branches
@torch.no_grad()
def generate_tree(
self,
input_ids: torch.Tensor,
depth: int = 3,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Generate tree of speculative tokens.
Returns:
List of (tokens, probs) tuples for each branch
"""
branches = []
# Generate multiple branches from draft model
for _ in range(self.num_branches):
tokens, probs = self.draft.speculate(
input_ids,
num_tokens=depth,
temperature=self.config.temperature,
)
branches.append((tokens, probs))
return branches
@torch.no_grad()
def generate_step(
self,
input_ids: torch.Tensor,
num_speculative: Optional[int] = None,
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Tree-based speculative step."""
        num_spec = num_speculative if num_speculative is not None else self.config.num_speculative_tokens
# Generate tree of speculations
branches = self.generate_tree(input_ids, num_spec)
best_tokens = None
best_accepted = 0
# Verify each branch and pick best
for draft_tokens, draft_probs in branches:
spec_input = torch.cat([input_ids, draft_tokens], dim=1)
_, target_logits, _, _ = self.target(spec_input)
target_probs = F.softmax(
target_logits[:, -num_spec-1:-1, :] / self.config.temperature, dim=-1
)
target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)
_, num_accepted = self._rejection_sampling(
draft_probs,
target_probs_selected,
draft_tokens,
)
if num_accepted > best_accepted:
best_accepted = num_accepted
best_tokens = draft_tokens[:, :num_accepted]
if best_tokens is None or best_tokens.shape[1] == 0:
# Fallback: sample from target
_, logits, _, _ = self.target(input_ids)
probs = F.softmax(logits[:, -1, :] / self.config.temperature, dim=-1)
best_tokens = torch.multinomial(probs, num_samples=1)
best_accepted = 0
        # Keep running statistics consistent with the base class
        self.total_generated += best_tokens.shape[1]
        self.total_accepted += best_accepted
        self.speculation_lengths.append(num_spec)
        stats = {
"num_speculated": num_spec * self.num_branches,
"num_accepted": best_accepted,
"num_generated": best_tokens.shape[1],
"acceptance_rate": best_accepted / num_spec if num_spec > 0 else 0,
"num_branches": self.num_branches,
}
return best_tokens, stats
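
# Note: the tree decoder verifies each branch with a separate target forward
# pass. A known optimization (not implemented here) is to batch all branches
# into a single verification pass using a tree-structured attention mask.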
def benchmark_speculative_decoding(
target_model: nn.Module,
draft_model: nn.Module,
input_ids: torch.Tensor,
num_tokens: int = 100,
device: str = "cuda",
) -> Dict[str, Any]:
"""
Benchmark speculative decoding vs standard generation.
"""
    # Baseline: standard autoregressive decoding (assumes the target model
    # exposes a .generate() method)
target_model.eval()
start = time.time()
with torch.no_grad():
standard_output = target_model.generate(
input_ids,
max_new_tokens=num_tokens,
)
standard_time = time.time() - start
# Speculative generation
decoder = SpeculativeDecoder(target_model, draft_model, device=device)
start = time.time()
spec_output, spec_stats = decoder.generate(
input_ids,
max_new_tokens=num_tokens,
)
spec_time = time.time() - start
    # Count what was actually produced; generate() may stop early at EOS
    standard_new_tokens = standard_output.shape[1] - input_ids.shape[1]
    return {
        "standard": {
            "time": standard_time,
            "tokens_per_second": standard_new_tokens / standard_time,
        },
"speculative": {
"time": spec_time,
"tokens_per_second": spec_stats["tokens_per_second"],
"acceptance_rate": spec_stats["avg_acceptance_rate"],
"avg_tokens_per_step": spec_stats["avg_tokens_per_step"],
},
"speedup": standard_time / spec_time,
}
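

if __name__ == "__main__":
    # Smoke test on CPU with a tiny random model standing in for both target
    # and draft. This only exercises the tensor plumbing; a random model gives
    # no real speedup. _TinyLM is a stand-in defined here, not part of
    # MiniMind; it mirrors the 4-tuple forward interface assumed above.
    class _TinyLM(nn.Module):
        def __init__(self, vocab_size: int = 128, dim: int = 32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, input_ids):
            logits = self.head(self.embed(input_ids))
            return None, logits, None, None

    torch.manual_seed(0)
    decoder = SpeculativeDecoder(_TinyLM(), _TinyLM(), device="cpu")
    prompt = torch.randint(0, 128, (1, 8))
    _, stats = decoder.generate(prompt, max_new_tokens=32)
    print(
        f"generated {stats['tokens_generated']} tokens in {stats['steps']} steps, "
        f"avg acceptance {stats['avg_acceptance_rate']:.2f}"
    )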