fariasultana committed
Commit 00ecf49 · verified · 1 Parent(s): ef553ca

feat: Add capabilities/speculative.py

Files changed (1):
  1. capabilities/speculative.py (+439, -0)
capabilities/speculative.py ADDED
@@ -0,0 +1,439 @@
"""
Speculative Decoding Module for MiniMind Max2

Uses a small draft model to accelerate inference of the large target model.
"""

import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    # Speculation settings
    num_speculative_tokens: int = 5  # Number of tokens to speculate per step
    max_speculation_length: int = 8  # Hard cap on speculation length

    # Acceptance settings
    acceptance_method: str = "rejection"  # Only "rejection" is implemented below; "nucleus" is reserved
    temperature: float = 1.0
    top_p: float = 0.95  # Reserved for nucleus-based acceptance

    # Performance tuning
    adaptive_speculation: bool = True  # Adjust speculation length from the observed acceptance rate
    min_speculative_tokens: int = 2
    max_speculative_tokens: int = 10
    target_acceptance_rate: float = 0.8

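# Configuration sketch: a decoder tuned for a fast, well-matched draft model.
# The values here are illustrative, not recommendations from this module.
_EXAMPLE_CONFIG = SpeculativeConfig(
    num_speculative_tokens=8,
    adaptive_speculation=True,
    target_acceptance_rate=0.7,
)
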
class DraftModel:
    """
    Wrapper for the draft model in speculative decoding.

    Typically a smaller, faster model from the same family as the target
    (e.g., max2-nano drafting for max2-pro).
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer=None,
        device: str = "cuda",
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @torch.no_grad()
    def speculate(
        self,
        input_ids: torch.Tensor,
        num_tokens: int = 5,
        temperature: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate speculative tokens autoregressively with the draft model.

        Args:
            input_ids: Current input sequence [batch, seq_len]
            num_tokens: Number of tokens to speculate
            temperature: Sampling temperature

        Returns:
            Tuple of (speculated_tokens, speculated_probs), each [batch, num_tokens]
        """
        speculated_tokens = []
        speculated_probs = []

        current_ids = input_ids

        for _ in range(num_tokens):
            # Forward pass; the model is assumed to return (loss, logits, aux_loss, past_kv)
            _, logits, _, _ = self.model(current_ids)
            next_logits = logits[:, -1, :] / temperature

            # Sample the next token from the draft distribution
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Record the draft probability of the sampled token (needed for verification)
            token_prob = probs.gather(1, next_token)

            speculated_tokens.append(next_token)
            speculated_probs.append(token_prob)

            # Append to the running sequence
            current_ids = torch.cat([current_ids, next_token], dim=1)

        # Stack results
        speculated_tokens = torch.cat(speculated_tokens, dim=1)  # [batch, num_tokens]
        speculated_probs = torch.cat(speculated_probs, dim=1)    # [batch, num_tokens]

        return speculated_tokens, speculated_probs

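# Minimal usage sketch for DraftModel.speculate. `_ToyLM` is a hypothetical
# stand-in whose only job is to mimic the (loss, logits, aux_loss, past_kv)
# forward signature assumed above; a real MiniMind checkpoint would replace it.
def _example_draft_speculation() -> None:
    class _ToyLM(nn.Module):
        def __init__(self, vocab: int = 100, dim: int = 32):
            super().__init__()
            self.emb = nn.Embedding(vocab, dim)
            self.head = nn.Linear(dim, vocab)

        def forward(self, ids: torch.Tensor):
            return None, self.head(self.emb(ids)), None, None

    draft = DraftModel(_ToyLM(), device="cpu")
    prompt = torch.randint(0, 100, (1, 4))
    tokens, probs = draft.speculate(prompt, num_tokens=3)
    # tokens: [1, 3] proposed ids; probs: [1, 3] draft probabilities of those ids
    assert tokens.shape == (1, 3) and probs.shape == (1, 3)
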
class SpeculativeDecoder:
    """
    Speculative decoding for accelerated generation.

    A small draft model proposes tokens, which the target model verifies
    in a single forward pass.
    """

    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        config: Optional[SpeculativeConfig] = None,
        device: str = "cuda",
    ):
        self.target = target_model
        self.draft = DraftModel(draft_model, device=device)
        self.config = config or SpeculativeConfig()
        self.device = device

        # Statistics
        self.total_generated = 0
        self.total_accepted = 0
        self.speculation_lengths = []

    def _rejection_sampling(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        draft_tokens: torch.Tensor,
    ) -> Tuple[torch.Tensor, int]:
        """
        Rejection sampling for token acceptance.

        Each draft token is kept with probability min(1, p_target / p_draft);
        everything from the first rejection onward is discarded.

        Returns:
            Tuple of (accepted_mask, num_accepted)
        """
        # Compute acceptance probability: min(1, target_p / draft_p)
        acceptance_probs = torch.min(
            torch.ones_like(draft_probs),
            target_probs / (draft_probs + 1e-10),
        )

        # Uniform draws for the rejection test
        uniform = torch.rand_like(acceptance_probs)
        accepted = uniform < acceptance_probs

        # Keep only the prefix before the first rejection: the cumulative
        # product zeroes the mask at and after the first False.
        accepted_mask = torch.cumprod(accepted.float(), dim=1).bool()
        # Use the shortest accepted prefix across the batch so rows stay aligned
        num_accepted = accepted_mask.sum(dim=1).min().item()

        return accepted_mask, num_accepted

    @torch.no_grad()
    def generate_step(
        self,
        input_ids: torch.Tensor,
        num_speculative: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Single speculative generation step.

        Args:
            input_ids: Current sequence [batch, seq_len]
            num_speculative: Number of tokens to speculate (uses config if None)

        Returns:
            New tokens and step statistics
        """
        num_spec = num_speculative or self.config.num_speculative_tokens

        # Phase 1: Draft model speculation
        draft_tokens, draft_probs = self.draft.speculate(
            input_ids,
            num_tokens=num_spec,
            temperature=self.config.temperature,
        )

        # Phase 2: Target model verification (single forward pass over the
        # prompt plus all draft tokens)
        spec_input = torch.cat([input_ids, draft_tokens], dim=1)
        _, target_logits, _, _ = self.target(spec_input)

        # Target probabilities at the positions that predict each draft token
        target_probs = F.softmax(
            target_logits[:, -num_spec - 1:-1, :] / self.config.temperature, dim=-1
        )
        target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)

        # Phase 3: Rejection sampling
        accepted_mask, num_accepted = self._rejection_sampling(
            draft_probs,
            target_probs_selected,
            draft_tokens,
        )

        # Keep the verified prefix
        if num_accepted > 0:
            new_tokens = draft_tokens[:, :num_accepted]
        else:
            new_tokens = torch.empty(input_ids.shape[0], 0, dtype=torch.long, device=self.device)

        # If not every draft token was accepted, sample one more token from the
        # target at the rejection point. Note: exact speculative sampling
        # (Leviathan et al., 2023) resamples from the normalized residual
        # max(0, p_target - p_draft); sampling directly from the target is a
        # simplification that slightly biases the output distribution.
        if num_accepted < num_spec:
            next_logits = target_logits[:, input_ids.shape[1] + num_accepted - 1, :]
            next_probs = F.softmax(next_logits / self.config.temperature, dim=-1)
            bonus_token = torch.multinomial(next_probs, num_samples=1)
            new_tokens = torch.cat([new_tokens, bonus_token], dim=1)

        # Statistics
        self.total_generated += new_tokens.shape[1]
        self.total_accepted += num_accepted
        self.speculation_lengths.append(num_spec)

        stats = {
            "num_speculated": num_spec,
            "num_accepted": num_accepted,
            "num_generated": new_tokens.shape[1],
            "acceptance_rate": num_accepted / num_spec if num_spec > 0 else 0,
        }

        return new_tokens, stats

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Full speculative generation loop.

        Args:
            input_ids: Initial input [batch, seq_len]
            max_new_tokens: Maximum number of new tokens to generate
            eos_token_id: EOS token id that stops generation

        Returns:
            Generated sequence and aggregate statistics
        """
        self.target.eval()

        generated = input_ids.clone()
        total_stats = {
            "steps": 0,
            "tokens_generated": 0,
            "acceptance_rates": [],
        }

        start_time = time.time()
        num_speculative = self.config.num_speculative_tokens

        while total_stats["tokens_generated"] < max_new_tokens:
            # Speculative step
            new_tokens, step_stats = self.generate_step(generated, num_speculative)

            if new_tokens.shape[1] == 0:
                break

            generated = torch.cat([generated, new_tokens], dim=1)

            # Update stats
            total_stats["steps"] += 1
            total_stats["tokens_generated"] += new_tokens.shape[1]
            total_stats["acceptance_rates"].append(step_stats["acceptance_rate"])

            # Check for EOS
            if eos_token_id is not None and (new_tokens == eos_token_id).any():
                break

            # Adaptive speculation: lengthen when recent acceptance is high,
            # shorten when it drops well below the target rate
            if self.config.adaptive_speculation:
                recent = total_stats["acceptance_rates"][-5:]
                avg_acceptance = sum(recent) / len(recent)
                if avg_acceptance > self.config.target_acceptance_rate:
                    num_speculative = min(num_speculative + 1, self.config.max_speculative_tokens)
                elif avg_acceptance < self.config.target_acceptance_rate - 0.1:
                    num_speculative = max(num_speculative - 1, self.config.min_speculative_tokens)

        end_time = time.time()

        total_stats["time_seconds"] = end_time - start_time
        total_stats["tokens_per_second"] = (
            total_stats["tokens_generated"] / max(total_stats["time_seconds"], 1e-9)
        )
        total_stats["avg_acceptance_rate"] = (
            sum(total_stats["acceptance_rates"]) / max(1, len(total_stats["acceptance_rates"]))
        )
        total_stats["avg_tokens_per_step"] = (
            total_stats["tokens_generated"] / max(1, total_stats["steps"])
        )

        return generated, total_stats

    def get_statistics(self) -> Dict[str, float]:
        """Get decoder-lifetime statistics."""
        # Acceptance rate is accepted / speculated; total_generated also
        # counts bonus tokens, so it is not the right denominator.
        total_speculated = sum(self.speculation_lengths)
        return {
            "total_generated": self.total_generated,
            "total_accepted": self.total_accepted,
            "overall_acceptance_rate": self.total_accepted / max(1, total_speculated),
            "avg_speculation_length": total_speculated / max(1, len(self.speculation_lengths)),
        }

    def reset_statistics(self):
        """Reset statistics counters."""
        self.total_generated = 0
        self.total_accepted = 0
        self.speculation_lengths = []

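# End-to-end sketch: speculative generation with two toy models standing in
# for a (target, draft) pair. `_ToyLM` is hypothetical, as in the DraftModel
# example above; in practice the target would be the large model and the
# draft a much smaller one.
def _example_speculative_generate() -> None:
    class _ToyLM(nn.Module):
        def __init__(self, vocab: int = 100, dim: int = 32):
            super().__init__()
            self.emb = nn.Embedding(vocab, dim)
            self.head = nn.Linear(dim, vocab)

        def forward(self, ids: torch.Tensor):
            return None, self.head(self.emb(ids)), None, None

    decoder = SpeculativeDecoder(_ToyLM(dim=64), _ToyLM(dim=16), device="cpu")
    prompt = torch.randint(0, 100, (1, 8))
    output, stats = decoder.generate(prompt, max_new_tokens=32)
    # With unrelated random models the acceptance rate will be low; with a
    # well-matched draft it would sit near the configured 0.8 target.
    print(output.shape, stats["avg_acceptance_rate"], stats["tokens_per_second"])
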
class TreeSpeculativeDecoder(SpeculativeDecoder):
    """
    Tree-based speculative decoding for higher acceptance rates.

    This simplified variant samples several independent draft branches and
    keeps the one the target verifies furthest. Each branch costs its own
    target forward pass, unlike tree-attention schemes that verify a shared
    token tree in a single pass.
    """

    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        num_branches: int = 3,
        config: Optional[SpeculativeConfig] = None,
        device: str = "cuda",
    ):
        super().__init__(target_model, draft_model, config, device)
        self.num_branches = num_branches

    @torch.no_grad()
    def generate_tree(
        self,
        input_ids: torch.Tensor,
        depth: int = 3,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Generate a set of speculative branches.

        Returns:
            List of (tokens, probs) tuples, one per branch
        """
        branches = []

        # Sample multiple branches independently from the draft model
        for _ in range(self.num_branches):
            tokens, probs = self.draft.speculate(
                input_ids,
                num_tokens=depth,
                temperature=self.config.temperature,
            )
            branches.append((tokens, probs))

        return branches

    @torch.no_grad()
    def generate_step(
        self,
        input_ids: torch.Tensor,
        num_speculative: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Tree-based speculative step: verify every branch and keep the best."""
        num_spec = num_speculative or self.config.num_speculative_tokens

        # Generate the speculation branches
        branches = self.generate_tree(input_ids, num_spec)

        best_tokens = None
        best_accepted = 0

        # Verify each branch against the target and keep the longest accepted prefix
        for draft_tokens, draft_probs in branches:
            spec_input = torch.cat([input_ids, draft_tokens], dim=1)
            _, target_logits, _, _ = self.target(spec_input)

            target_probs = F.softmax(
                target_logits[:, -num_spec - 1:-1, :] / self.config.temperature, dim=-1
            )
            target_probs_selected = target_probs.gather(2, draft_tokens.unsqueeze(-1)).squeeze(-1)

            _, num_accepted = self._rejection_sampling(
                draft_probs,
                target_probs_selected,
                draft_tokens,
            )

            if num_accepted > best_accepted:
                best_accepted = num_accepted
                best_tokens = draft_tokens[:, :num_accepted]

        if best_tokens is None or best_tokens.shape[1] == 0:
            # Fallback: no branch was accepted, so sample one token from the target
            _, logits, _, _ = self.target(input_ids)
            probs = F.softmax(logits[:, -1, :] / self.config.temperature, dim=-1)
            best_tokens = torch.multinomial(probs, num_samples=1)
            best_accepted = 0

        stats = {
            "num_speculated": num_spec * self.num_branches,
            "num_accepted": best_accepted,
            "num_generated": best_tokens.shape[1],
            "acceptance_rate": best_accepted / num_spec if num_spec > 0 else 0,
            "num_branches": self.num_branches,
        }

        return best_tokens, stats

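# Sketch of a single tree-variant step: with several independent branches,
# the step keeps whichever branch the target accepts most of. The models are
# the same hypothetical toy stand-ins used in the earlier examples.
def _example_tree_step() -> None:
    class _ToyLM(nn.Module):
        def __init__(self, vocab: int = 100, dim: int = 32):
            super().__init__()
            self.emb = nn.Embedding(vocab, dim)
            self.head = nn.Linear(dim, vocab)

        def forward(self, ids: torch.Tensor):
            return None, self.head(self.emb(ids)), None, None

    decoder = TreeSpeculativeDecoder(_ToyLM(), _ToyLM(), num_branches=3, device="cpu")
    tokens, stats = decoder.generate_step(torch.randint(0, 100, (1, 8)))
    print(stats["num_branches"], stats["num_accepted"], tokens.shape)
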
def benchmark_speculative_decoding(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.Tensor,
    num_tokens: int = 100,
    device: str = "cuda",
) -> Dict[str, Any]:
    """
    Benchmark speculative decoding against standard autoregressive generation.

    Assumes the target model exposes a `generate(input_ids, max_new_tokens=...)`
    method for the standard baseline.
    """
    # Standard generation
    target_model.eval()
    start = time.time()
    with torch.no_grad():
        standard_output = target_model.generate(
            input_ids,
            max_new_tokens=num_tokens,
        )
    standard_time = time.time() - start

    # Speculative generation
    decoder = SpeculativeDecoder(target_model, draft_model, device=device)
    start = time.time()
    spec_output, spec_stats = decoder.generate(
        input_ids,
        max_new_tokens=num_tokens,
    )
    spec_time = time.time() - start

    return {
        "standard": {
            "time": standard_time,
            "tokens_per_second": num_tokens / standard_time,
        },
        "speculative": {
            "time": spec_time,
            "tokens_per_second": spec_stats["tokens_per_second"],
            "acceptance_rate": spec_stats["avg_acceptance_rate"],
            "avg_tokens_per_step": spec_stats["avg_tokens_per_step"],
        },
        "speedup": standard_time / spec_time,
    }
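

# Usage sketch for the benchmark helper. The caller supplies real models;
# per the assumption noted above, the target must expose a HF-style
# `generate()` for the baseline measurement.
def _example_benchmark(target: nn.Module, draft: nn.Module) -> None:
    prompt = torch.randint(0, 100, (1, 16))
    report = benchmark_speculative_decoding(target, draft, prompt, num_tokens=64, device="cpu")
    print(
        f"speedup: {report['speedup']:.2f}x, "
        f"acceptance: {report['speculative']['acceptance_rate']:.2f}"
    )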