Florian valade committed on
Commit
72b2f6d
·
0 Parent(s):

Initial commit of standalone DSSD demo for HF Spaces

Browse files
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DSSD Demo - Dynamic Self-Speculative Decoding
2
+
3
+ A Gradio demo showcasing early exit inference with color-coded token visualization.
4
+
5
+ ## Features
6
+
7
+ - **Color-coded tokens**: Each token shows which head/layer generated it
8
+ - **True early exit**: Actual speedup by stopping layer computation early
9
+ - **Compare mode**: Side-by-side comparison with full model
10
+ - **Model selection**: Switch between different DSSD models
11
+
12
+ ## Quick Start
13
+
14
+ ```bash
15
+ # Install dependencies
16
+ pip install -r requirements.txt
17
+
18
+ # Run the demo
19
+ python app.py
20
+ ```
21
+
22
+ Then open http://localhost:7860 in your browser.
23
+
24
+ ## Models
25
+
26
+ - **DSSD-Llama3-8B**: Llama 3 8B with 3 early exit heads at layers 8, 16, 24
27
+ - **DSSD-Qwen3-0.6B**: Qwen3 0.6B with 4 early exit heads at layers 5, 11, 16, 22
28
+
29
+ ## Color Legend
30
+
31
+ - 🔴 **Red**: Head 0 (earliest layer)
32
+ - 🟠 **Orange**: Head 1
33
+ - 🔵 **Teal/Blue**: Heads 2–3
34
+ - 🟢 **Light Green**: Full model (all layers)
__pycache__/app.cpython-310.pyc ADDED
Binary file (7.34 kB). View file
 
app.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DSSD Demo - Dynamic Self-Speculative Decoding Visualization
3
+ Showcases early exit inference with color-coded tokens showing which head generated each token.
4
+ """
5
+
6
+ import gradio as gr
7
+ from pathlib import Path
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent
11
+
12
+ # Available models configuration
13
+ AVAILABLE_MODELS = {
14
+ "DSSD-Llama3-8B": {
15
+ "model_name": "meta-llama/Meta-Llama-3-8B",
16
+ "repo_id": "valcore/DSSD-Llama3-8B",
17
+ "local_path": "../checkpoints/llama3-8b-4bit",
18
+ },
19
+ "DSSD-Qwen3-0.6B": {
20
+ "model_name": "Qwen/Qwen3-0.6B",
21
+ "repo_id": "valcore/DSSD-Qwen3-0.6B",
22
+ "local_path": "../checkpoints/qwen3-0.6b",
23
+ },
24
+ }
25
+
26
+ # Color palette for exit heads (colorblind-friendly)
27
+ HEAD_COLORS = [
28
+ "#E63946", # Red - Head 0 (earliest)
29
+ "#F4A261", # Orange - Head 1
30
+ "#2A9D8F", # Teal - Head 2
31
+ "#457B9D", # Blue - Head 3
32
+ "#8338EC", # Purple - Head 4
33
+ ]
34
+ FULL_MODEL_COLOR = "#95D5B2" # Light green - Full model
35
+
36
+ # Global decoder cache
37
+ _decoder_cache = {}
38
+
39
+
40
def get_decoder(model_key: str, use_local: bool = True) -> DSSDecoder:
    """Get or load a decoder for the specified model.

    Args:
        model_key: Key into AVAILABLE_MODELS.
        use_local: When True, try the local checkpoint directory before
            falling back to downloading from the HF Hub.  (Added with a
            default so existing call sites passing use_local keep working.)

    Returns:
        A cached (or freshly loaded) DSSDecoder instance.
    """
    global _decoder_cache

    # Serve from the cache when this model was already loaded.
    if model_key in _decoder_cache:
        return _decoder_cache[model_key]

    model_info = AVAILABLE_MODELS[model_key]

    # Candidate local paths (for development checkouts).
    local_dir = Path(__file__).parent / model_info["local_path"]
    heads_path = local_dir / "aux_heads.pt"
    config_path = local_dir / "config.json"
    calibration_path = local_dir / "calibration.json"

    if use_local and heads_path.exists() and config_path.exists():
        print(f"Loading model heads from local path: {local_dir}")
        # calibration.json is optional; drop it when absent so we do not
        # pass a non-existent path to load_dssd_model below.
        if not calibration_path.exists():
            calibration_path = None
    else:
        # Download from HF Hub
        repo_id = model_info["repo_id"]
        print(f"Downloading model heads from {repo_id}...")
        heads_path = hf_hub_download(repo_id=repo_id, filename="aux_heads.pt")
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        try:
            calibration_path = hf_hub_download(
                repo_id=repo_id, filename="calibration.json"
            )
        except Exception:
            calibration_path = None  # calibration.json is optional

    decoder, tokenizer = load_dssd_model(
        model_name=model_info["model_name"],
        heads_path=str(heads_path),
        config_path=str(config_path),
        calibration_path=str(calibration_path) if calibration_path else None,
        device="auto",
    )

    _decoder_cache[model_key] = decoder
    return decoder
81
+
82
+
83
def tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str:
    """Convert token info list to color-coded HTML.

    Each token is rendered as a colored <span>; the color encodes which
    early-exit head (or the full model) produced it, and the tooltip names
    the head/layer.
    """
    html_parts = []

    for token in tokens:
        if token.exit_head is not None:
            color = HEAD_COLORS[token.exit_head % len(HEAD_COLORS)]
            layer = head_layers[token.exit_head]
            title = f"Head {token.exit_head} (Layer {layer})"
        else:
            color = FULL_MODEL_COLOR
            title = f"Full Model (Layer {token.exit_layer})"

        # Escape HTML special chars.  '&' must be escaped FIRST and to
        # '&amp;' — the previous replace("&", "&") was a no-op, leaving raw
        # ampersands in the output HTML.
        text = (
            token.token_text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
        )
        text = text.replace("\n", "<br>").replace(" ", "&nbsp;")

        html_parts.append(
            f'<span style="background-color: {color}; padding: 2px 4px; '
            f'border-radius: 3px; margin: 1px; display: inline-block;" title="{title}">{text}</span>'
        )

    # Wrap in container with word-wrap to prevent overflow
    tokens_html = "".join(html_parts)
    return f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{tokens_html}</div>"""
112
+
113
+
114
def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str:
    """Render drafted (not-yet-verified) tokens as HTML.

    Pending tokens reuse the same head-based colors as validated tokens but
    get a dashed border and reduced opacity so they are visually distinct.
    """
    rendered = []

    for tok in tokens:
        head = tok.exit_head
        if head is None:
            bg = FULL_MODEL_COLOR
            tooltip = "PENDING - Full Model"
        else:
            bg = HEAD_COLORS[head % len(HEAD_COLORS)]
            tooltip = f"PENDING - Head {head} (Layer {head_layers[head]})"

        # Escape HTML metacharacters, then convert whitespace so the span
        # preserves the token's exact spacing and line breaks.
        safe = tok.token_text.replace("&", "&amp;")
        safe = safe.replace("<", "&lt;").replace(">", "&gt;")
        safe = safe.replace("\n", "<br>").replace(" ", "&nbsp;")

        rendered.append(
            f'<span style="background-color: {bg}; padding: 2px 4px; '
            f"border-radius: 3px; margin: 1px; display: inline-block; "
            f'border: 2px dashed #333; opacity: 0.7;" title="{tooltip}">{safe}</span>'
        )

    return "".join(rendered)
141
+
142
+
143
def create_legend(head_layers: list[int]) -> str:
    """Build the HTML legend mapping span colors to exit heads / full model."""
    # One badge per early-exit head, colored like the corresponding tokens.
    items = [
        f'<span style="background-color: {HEAD_COLORS[i % len(HEAD_COLORS)]}; padding: 4px 8px; '
        f'border-radius: 4px; margin-right: 8px;">Head {i} (Layer {layer})</span>'
        for i, layer in enumerate(head_layers)
    ]
    # Trailing badge for tokens produced by the full model.
    items.append(
        f'<span style="background-color: {FULL_MODEL_COLOR}; padding: 4px 8px; '
        f'border-radius: 4px;">Full Model</span>'
    )
    return " ".join(items)
157
+
158
+
159
def create_stats_html(result, label: str) -> str:
    """Render a generation result's summary statistics as an HTML card.

    `result` is duck-typed: it only needs total_time, tokens_per_second,
    avg_exit_layer and exit_distribution attributes.
    """
    # Pre-format the individual stat rows for readability.
    time_row = f"<p><b>Time:</b> {result.total_time:.2f}s</p>"
    speed_row = f"<p><b>Tokens/sec:</b> {result.tokens_per_second:.2f}</p>"
    layer_row = f"<p><b>Avg Exit Layer:</b> {result.avg_exit_layer:.1f}</p>"
    dist_row = f"<p><b>Exit Distribution:</b> {result.exit_distribution}</p>"
    return f"""
    <div style="padding: 10px; background: #f5f5f5; border-radius: 8px; margin-top: 10px;">
    <h4 style="margin: 0 0 10px 0;">{label} Statistics</h4>
    {time_row}
    {speed_row}
    {layer_row}
    {dist_row}
    </div>
    """
170
+
171
+
172
def generate(
    prompt: str,
    model_key: str,
    use_early_exit: bool,
    accuracy_level: float,
    max_tokens: int,
    compare_mode: bool,
):
    """Main generation function for the Gradio interface (streaming generator).

    Yields 4-tuples of (early-exit HTML, full-model HTML, stats/status HTML,
    legend HTML) so the UI updates live while tokens stream in.

    Args:
        prompt: User prompt text.
        model_key: Key into AVAILABLE_MODELS selecting the decoder.
        use_early_exit: Enable early-exit (speculative) decoding.
        accuracy_level: Requested accuracy; snapped to the nearest calibrated
            level when calibration data is available.
        max_tokens: Maximum number of tokens to generate.
        compare_mode: Also run the full model for a side-by-side comparison.
    """
    try:
        # BUG FIX: get_decoder's visible signature takes only the model key;
        # the previous call passed use_local=True and raised a TypeError.
        decoder = get_decoder(model_key)
    except Exception as e:
        error_msg = f"<p style='color: red;'>Error loading model: {e}</p>"
        yield (error_msg, "", "", error_msg)
        return

    head_layers = decoder.model_config.head_layer_indices
    legend = create_legend(head_layers)

    # Snap the requested accuracy level to the closest calibrated one.
    if decoder.calibration:
        available_levels = decoder.calibration.accuracy_levels
        closest_level = min(available_levels, key=lambda x: abs(x - accuracy_level))
    else:
        closest_level = accuracy_level

    # Exact wrapper <div> emitted by tokens_to_html; stripped when we need to
    # splice validated and drafted spans into a single container.
    wrapper_open = '<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">'

    def _combined_stream_html(event) -> str:
        """Merge validated + drafted token spans into one wrapped container."""
        validated_html = ""
        if event.tokens:
            # BUG FIX: previous code used str.rstrip("</div>"), which strips
            # *characters* from the set {<,/,d,i,v,>} and corrupted the
            # trailing </span>; removeprefix/removesuffix drop exactly the
            # wrapper and nothing else.
            validated_html = (
                tokens_to_html(event.tokens, head_layers)
                .removeprefix(wrapper_open)
                .removesuffix("</div>")
            )
        drafted_html = ""
        if event.drafted_tokens:
            drafted_html = drafted_tokens_to_html(event.drafted_tokens, head_layers)
        return f"{wrapper_open}{validated_html}{drafted_html}</div>"

    if compare_mode:
        # --- Compare mode: stream early exit first, then the full model ---
        final_ee_tokens = []
        for event in decoder.generate_streaming(
            prompt=prompt,
            max_tokens=int(max_tokens),
            accuracy_level=closest_level,
            use_chat_template=True,
        ):
            combined_html = _combined_stream_html(event)
            status = f"""
            <div style="padding: 10px; background: #fff3cd; border-radius: 8px;">
                <b>Early Exit:</b> {event.message} | <b>Full Model:</b> Waiting...
            </div>
            """
            yield (
                combined_html,
                "<p style='color: #666;'>Waiting for early exit to complete...</p>",
                status,
                legend,
            )
            final_ee_tokens = event.tokens

        # Now stream the full model alongside the finished early-exit output.
        final_full_tokens = []
        for event in decoder.generate_full_model_streaming(
            prompt=prompt,
            max_tokens=int(max_tokens),
            use_chat_template=True,
        ):
            html_full = tokens_to_html(event.tokens, head_layers)
            status = f"""
            <div style="padding: 10px; background: #fff3cd; border-radius: 8px;">
                <b>Full Model:</b> {event.message}
            </div>
            """
            yield (
                tokens_to_html(final_ee_tokens, head_layers),
                html_full,
                status,
                legend,
            )
            final_full_tokens = event.tokens

        # Final stats: re-run both modes to collect timing information.
        result_ee = decoder.generate(
            prompt=prompt,
            max_tokens=int(max_tokens),
            use_early_exit=True,
            accuracy_level=closest_level,
            use_chat_template=True,
        )
        result_full = decoder.generate(
            prompt=prompt,
            max_tokens=int(max_tokens),
            use_early_exit=False,
            use_chat_template=True,
        )

        html_ee = tokens_to_html(result_ee.tokens, head_layers)
        html_full = tokens_to_html(result_full.tokens, head_layers)

        # Guard against a zero denominator on degenerate runs.
        speedup = (
            result_ee.tokens_per_second / result_full.tokens_per_second
            if result_full.tokens_per_second > 0
            else 0
        )
        stats = f"""
        <div style="padding: 15px; background: #e8f5e9; border-radius: 8px;">
            <h3 style="margin: 0 0 10px 0;">🚀 Speedup: {speedup:.2f}x</h3>
            <div style="display: flex; gap: 20px;">
                <div style="flex: 1; padding: 10px; background: white; border-radius: 8px;">
                    <h4>Early Exit</h4>
                    <p><b>Time:</b> {result_ee.total_time:.2f}s | <b>Tokens/sec:</b> {result_ee.tokens_per_second:.2f}</p>
                    <p><b>Avg Exit Layer:</b> {result_ee.avg_exit_layer:.1f}</p>
                </div>
                <div style="flex: 1; padding: 10px; background: white; border-radius: 8px;">
                    <h4>Full Model</h4>
                    <p><b>Time:</b> {result_full.total_time:.2f}s | <b>Tokens/sec:</b> {result_full.tokens_per_second:.2f}</p>
                    <p><b>Avg Exit Layer:</b> {result_full.avg_exit_layer:.1f}</p>
                </div>
            </div>
        </div>
        """
        yield (html_ee, html_full, stats, legend)

    elif use_early_exit:
        # --- Early-exit streaming: show the draft/verify process live ---
        for event in decoder.generate_streaming(
            prompt=prompt,
            max_tokens=int(max_tokens),
            accuracy_level=closest_level,
            use_chat_template=True,
        ):
            combined_html = _combined_stream_html(event)
            status = f"""
            <div style="padding: 10px; background: #fff3cd; border-radius: 8px; margin-top: 5px;">
                <b>Status:</b> {event.message}
            </div>
            """
            yield (combined_html, "", status, legend)

        # Final stats after streaming completes (re-run to collect timing).
        result = decoder.generate(
            prompt=prompt,
            max_tokens=int(max_tokens),
            use_early_exit=True,
            accuracy_level=closest_level,
            use_chat_template=True,
        )
        html = tokens_to_html(result.tokens, head_layers)
        stats = f"""
        <div style="padding: 15px; background: #f5f5f5; border-radius: 8px;">
            <h4 style="margin: 0 0 10px 0;">Early Exit Statistics (Final)</h4>
            <p><b>Tokens:</b> {len(result.tokens)} | <b>Tokens/sec:</b> {result.tokens_per_second:.2f} | <b>Avg Exit Layer:</b> {result.avg_exit_layer:.1f}</p>
            <p><b>Exit Distribution:</b> {result.exit_distribution}</p>
        </div>
        """
        yield (html, "", stats, legend)

    else:
        # --- Full model only (streaming) ---
        for event in decoder.generate_full_model_streaming(
            prompt=prompt,
            max_tokens=int(max_tokens),
            use_chat_template=True,
        ):
            html = tokens_to_html(event.tokens, head_layers)
            status = f"""
            <div style="padding: 10px; background: #fff3cd; border-radius: 8px;">
                <b>Full Model:</b> {event.message}
            </div>
            """
            yield (html, "", status, legend)

        # Final stats (re-run to collect timing).
        result = decoder.generate(
            prompt=prompt,
            max_tokens=int(max_tokens),
            use_early_exit=False,
            use_chat_template=True,
        )
        html = tokens_to_html(result.tokens, head_layers)
        stats = f"""
        <div style="padding: 15px; background: #f5f5f5; border-radius: 8px;">
            <h4 style="margin: 0 0 10px 0;">Full Model Statistics</h4>
            <p><b>Tokens:</b> {len(result.tokens)} | <b>Time:</b> {result.total_time:.2f}s | <b>Tokens/sec:</b> {result.tokens_per_second:.2f}</p>
        </div>
        """
        yield (html, "", stats, legend)
382
+
383
+
384
def build_demo():
    """Build the Gradio Blocks interface for the DSSD demo.

    Layout: a prompt/settings column, a full-width legend row, one or two
    output columns (the second shown only in compare mode), and a stats
    panel.  Returns the gr.Blocks app, ready for .launch().
    """
    with gr.Blocks(title="DSSD Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🚀 Dynamic Self-Speculative Decoding (DSSD) Demo

        This demo showcases **early exit inference** where tokens can be generated from intermediate
        layers when the model is confident, resulting in faster generation.

        **Colors indicate which layer generated each token** - earlier layers = faster!
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Prompt input with a sensible default question.
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="What is machine learning in simple terms?",
                )

                # Model picker; defaults to the first configured model.
                model_selector = gr.Dropdown(
                    label="Model",
                    choices=list(AVAILABLE_MODELS.keys()),
                    value=list(AVAILABLE_MODELS.keys())[0],
                )

                with gr.Row():
                    use_early_exit = gr.Checkbox(label="Enable Early Exit", value=True)
                    compare_mode = gr.Checkbox(label="Compare Mode", value=False)

                # Target accuracy for calibrated early-exit thresholds.
                accuracy_level = gr.Slider(
                    label="Accuracy Level",
                    minimum=0.6,
                    maximum=0.99,
                    step=0.05,
                    value=0.75,
                    info="Higher = more accurate but slower",
                )

                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=10,
                    maximum=200,
                    step=10,
                    value=50,
                )

                generate_btn = gr.Button("Generate", variant="primary")

        # Legend (full width, above outputs)
        legend_html = gr.HTML()

        # Outputs section - dynamic based on compare mode
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Generated Output")
                output_ee = gr.HTML()

            # Hidden unless compare mode is enabled (see update_visibility).
            with gr.Column(scale=1, visible=False) as compare_col:
                gr.Markdown("### Full Model (Comparison)")
                output_full = gr.HTML()

        # Stats (full width)
        stats_html = gr.HTML()

        def update_visibility(compare):
            # Show/hide the comparison column when the checkbox toggles.
            return gr.update(visible=compare)

        compare_mode.change(
            fn=update_visibility,
            inputs=[compare_mode],
            outputs=[compare_col],
        )

        # Wire the generate button to the streaming generator; outputs map
        # to (early-exit HTML, full-model HTML, stats, legend).
        generate_btn.click(
            fn=generate,
            inputs=[
                prompt,
                model_selector,
                use_early_exit,
                accuracy_level,
                max_tokens,
                compare_mode,
            ],
            outputs=[output_ee, output_full, stats_html, legend_html],
        )

    return demo
473
+
474
+
475
+ if __name__ == "__main__":
476
+ demo = build_demo()
477
+ demo.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.37.0
3
+ gradio>=4.0.0
4
+ bitsandbytes>=0.41.0
5
+ accelerate>=0.25.0
6
+ huggingface_hub>=0.19.0
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Demo package
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
src/__pycache__/inference.cpython-310.pyc ADDED
Binary file (15 kB). View file
 
src/__pycache__/model_adapters.cpython-310.pyc ADDED
Binary file (5.12 kB). View file
 
src/__pycache__/model_config.cpython-310.pyc ADDED
Binary file (3.05 kB). View file
 
src/inference.py ADDED
@@ -0,0 +1,781 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # True Early Exit Inference with Dynamic Self-Speculative Decoding
2
+ # Provides actual speedup by stopping layer computation early
3
+
4
+ from dataclasses import dataclass, asdict
5
+ from typing import Dict, List, Optional, Tuple, Callable
6
+ from collections import defaultdict
7
+ import time
8
+ import copy
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from transformers import (
14
+ AutoModelForCausalLM,
15
+ AutoTokenizer,
16
+ AutoConfig,
17
+ BitsAndBytesConfig,
18
+ )
19
+
20
+ from .model_adapters import get_adapter, ModelAdapter
21
+ from .model_config import ModelConfig, CalibrationResult
22
+
23
+
24
def compute_entropy(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Shannon entropy (in nats) of the softmax distribution over `dim`.

    Lower entropy means a more peaked distribution, i.e. higher confidence.
    """
    # exp(log_softmax) recovers the probabilities without a second softmax.
    log_probs = F.log_softmax(logits, dim=dim)
    return -(log_probs.exp() * log_probs).sum(dim=dim)
29
+
30
+
31
class AuxiliaryHead(nn.Module):
    """Auxiliary head for early exit prediction.

    Applies an optional normalization layer followed by an unbiased linear
    projection from hidden states to vocabulary logits.
    """

    def __init__(
        self, hidden_size: int, vocab_size: int, norm_layer: Optional[nn.Module] = None
    ):
        super().__init__()
        # Fall back to a pass-through when no norm layer is supplied.
        self.norm = nn.Identity() if norm_layer is None else norm_layer
        self.linear = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normalized = self.norm(hidden_states)
        return self.linear(normalized)
43
+
44
+
45
@dataclass
class TokenInfo:
    """Information about a generated token for visualization."""

    token_id: int  # tokenizer vocabulary id of the token
    token_text: str  # decoded text of this single token
    exit_head: Optional[int]  # None = full model
    exit_layer: int  # transformer layer at which the token was produced
    uncertainty: float  # uncertainty score; 0.0 for full-model tokens
54
+
55
+
56
@dataclass
class StreamEvent:
    """Event for streaming generation updates.

    Snapshots the validated and pending token lists plus a human-readable
    status message at each step of the draft/verify loop.
    """

    event_type: str  # "draft", "verify_start", "accept", "reject", "full_model"
    tokens: List[TokenInfo]  # All tokens so far (validated)
    drafted_tokens: List[TokenInfo]  # Currently drafted (pending verification)
    message: str  # Human-readable status
64
+
65
+
66
@dataclass
class GenerationResult:
    """Complete generation result with token-level information."""

    text: str  # concatenated token_text of all generated tokens
    tokens: List[TokenInfo]  # per-token exit metadata
    total_time: float  # wall-clock generation time in seconds
    tokens_per_second: float  # len(tokens) / total_time (0 when total_time == 0)
    avg_exit_layer: float  # mean exit layer across generated tokens
    exit_distribution: Dict[str, int]  # head index (as str) or "full" -> count
76
+
77
+
78
+ class DSSDecoder:
79
+ """
80
+ Dynamic Self-Speculative Decoder with TRUE early exit.
81
+ Actually stops computation at intermediate layers for speedup.
82
+ """
83
+
84
    def __init__(
        self,
        model: AutoModelForCausalLM,
        adapter: ModelAdapter,
        aux_heads: nn.ModuleList,
        tokenizer: AutoTokenizer,
        model_config: ModelConfig,
        calibration: Optional[CalibrationResult] = None,
        device: str = "cuda",
    ):
        """Store the model, adapter, heads, tokenizer and configuration.

        Args:
            model: Causal LM used for full passes and draft verification.
            adapter: Architecture-specific adapter (layer count/access).
            aux_heads: Early-exit auxiliary heads, one per exit layer.
            tokenizer: Tokenizer matching `model`.
            model_config: DSSD config (head layer indices, layer count, ...).
            calibration: Optional calibrated per-head uncertainty thresholds.
            device: Device string for tensors created during decoding.
        """
        self.model = model
        self.adapter = adapter
        self.aux_heads = aux_heads
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.calibration = calibration
        self.device = device
        # Entropy is the confidence/uncertainty measure used for early exits.
        self.uncertainty_fn = compute_entropy
102
+
103
+ def generate(
104
+ self,
105
+ prompt: str,
106
+ max_tokens: int = 100,
107
+ use_early_exit: bool = True,
108
+ accuracy_level: float = 0.75,
109
+ use_chat_template: bool = True,
110
+ ) -> GenerationResult:
111
+ """
112
+ Generate text with optional early exit.
113
+ Returns detailed token-level information for visualization.
114
+ """
115
+ # Format prompt - check if tokenizer has a chat template set
116
+ if (
117
+ use_chat_template
118
+ and hasattr(self.tokenizer, "chat_template")
119
+ and self.tokenizer.chat_template is not None
120
+ ):
121
+ try:
122
+ messages = [{"role": "user", "content": prompt}]
123
+ formatted = self.tokenizer.apply_chat_template(
124
+ messages, add_generation_prompt=True, tokenize=False
125
+ )
126
+ input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
127
+ self.device
128
+ )
129
+ except Exception:
130
+ # Fallback to raw prompt if chat template fails
131
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
132
+ self.device
133
+ )
134
+ else:
135
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
136
+ self.device
137
+ )
138
+
139
+ # Get thresholds
140
+ thresholds = {}
141
+ if use_early_exit and self.calibration:
142
+ thresholds = self.calibration.get_thresholds_for_level(accuracy_level)
143
+
144
+ # Generate
145
+ start_time = time.time()
146
+
147
+ if use_early_exit:
148
+ tokens = self._generate_with_early_exit(input_ids, max_tokens, thresholds)
149
+ else:
150
+ tokens = self._generate_full_model(input_ids, max_tokens)
151
+
152
+ end_time = time.time()
153
+ total_time = end_time - start_time
154
+
155
+ # Build result
156
+ text = "".join(t.token_text for t in tokens)
157
+ exit_dist = defaultdict(int)
158
+ layer_sum = 0
159
+
160
+ for t in tokens:
161
+ key = str(t.exit_head) if t.exit_head is not None else "full"
162
+ exit_dist[key] += 1
163
+ layer_sum += t.exit_layer
164
+
165
+ avg_layer = (
166
+ layer_sum / len(tokens) if tokens else self.model_config.num_hidden_layers
167
+ )
168
+
169
+ return GenerationResult(
170
+ text=text,
171
+ tokens=tokens,
172
+ total_time=total_time,
173
+ tokens_per_second=len(tokens) / total_time if total_time > 0 else 0,
174
+ avg_exit_layer=avg_layer,
175
+ exit_distribution=dict(exit_dist),
176
+ )
177
+
178
    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 100,
        accuracy_level: float = 0.75,
        use_chat_template: bool = True,
        max_draft_length: int = 5,
    ):
        """
        Generate with streaming - yields events showing draft/verify process.
        Each event shows current validated tokens and pending drafted tokens.

        Args:
            prompt: Raw user prompt (chat-templated when available).
            max_tokens: Upper bound on the number of validated tokens.
            accuracy_level: Calibration level selecting per-head thresholds.
            use_chat_template: Apply the tokenizer's chat template when set.
            max_draft_length: Max tokens drafted per draft/verify round.

        Yields:
            StreamEvent with event_type in {"draft", "verify_start",
            "accept", "reject", "full_model"}.
        """
        # Format prompt - prefer the tokenizer's chat template when present.
        if (
            use_chat_template
            and hasattr(self.tokenizer, "chat_template")
            and self.tokenizer.chat_template is not None
        ):
            try:
                messages = [{"role": "user", "content": prompt}]
                formatted = self.tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=False
                )
                input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
                    self.device
                )
            except Exception:
                # Fall back to encoding the raw prompt if templating fails.
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                    self.device
                )
        else:
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
                self.device
            )

        # Per-head uncertainty thresholds from calibration (empty = none).
        thresholds = {}
        if self.calibration:
            thresholds = self.calibration.get_thresholds_for_level(accuracy_level)

        validated_tokens = []
        current_ids = input_ids.clone()
        num_layers = self.adapter.get_num_layers()
        head_layers = self.model_config.head_layer_indices

        while len(validated_tokens) < max_tokens:
            # ============================================================
            # DRAFT PHASE: Generate tokens using early exit heads
            # ============================================================
            drafted_tokens = []
            draft_ids = current_ids.clone()

            for _ in range(max_draft_length):
                if len(validated_tokens) + len(drafted_tokens) >= max_tokens:
                    break

                draft_result = self._draft_single_token(draft_ids, thresholds)

                if draft_result is None:
                    # No head was confident enough - stop drafting, verify.
                    break

                token_id, exit_head, exit_layer, uncertainty = draft_result

                if token_id == self.tokenizer.eos_token_id:
                    break

                token_text = self.tokenizer.decode([token_id])
                drafted_token = TokenInfo(
                    token_id=token_id,
                    token_text=token_text,
                    exit_head=exit_head,
                    exit_layer=exit_layer,
                    uncertainty=uncertainty,
                )
                drafted_tokens.append(drafted_token)
                draft_ids = torch.cat(
                    [draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
                )

                # Yield draft event
                yield StreamEvent(
                    event_type="draft",
                    tokens=list(validated_tokens),
                    drafted_tokens=list(drafted_tokens),
                    message=f"Drafting token {len(drafted_tokens)} using Head {exit_head}",
                )

            # ============================================================
            # VERIFY PHASE
            # ============================================================
            if drafted_tokens:
                yield StreamEvent(
                    event_type="verify_start",
                    tokens=list(validated_tokens),
                    drafted_tokens=list(drafted_tokens),
                    message=f"Verifying {len(drafted_tokens)} drafted tokens...",
                )

                # One full-model pass over prompt + all drafts verifies them
                # in parallel (no KV cache: use_cache=False).
                with torch.no_grad():
                    outputs = self.model(draft_ids, use_cache=False)
                    verify_logits = outputs.logits

                # Logit position that predicts the first drafted token.
                start_pos = current_ids.shape[1] - 1

                for i, drafted_token in enumerate(drafted_tokens):
                    verify_pos = start_pos + i
                    verified_token_id = torch.argmax(
                        verify_logits[0, verify_pos, :]
                    ).item()

                    if drafted_token.token_id == verified_token_id:
                        # Accept: draft matches the full model's greedy pick.
                        validated_tokens.append(drafted_token)
                        current_ids = torch.cat(
                            [
                                current_ids,
                                torch.tensor(
                                    [[drafted_token.token_id]], device=self.device
                                ),
                            ],
                            dim=1,
                        )
                        yield StreamEvent(
                            event_type="accept",
                            tokens=list(validated_tokens),
                            drafted_tokens=[],
                            message=f"✓ Accepted '{drafted_token.token_text}'",
                        )
                    else:
                        # Reject - use full model's token; remaining drafts
                        # are discarded (conditioned on an invalid prefix).
                        token_text = self.tokenizer.decode([verified_token_id])
                        corrected_token = TokenInfo(
                            token_id=verified_token_id,
                            token_text=token_text,
                            exit_head=None,
                            exit_layer=num_layers,
                            uncertainty=0.0,
                        )
                        validated_tokens.append(corrected_token)
                        current_ids = torch.cat(
                            [
                                current_ids,
                                torch.tensor([[verified_token_id]], device=self.device),
                            ],
                            dim=1,
                        )
                        yield StreamEvent(
                            event_type="reject",
                            tokens=list(validated_tokens),
                            drafted_tokens=[],
                            message=f"✗ Rejected '{drafted_token.token_text}' → '{token_text}'",
                        )
                        break
            else:
                # No drafts - generate one token with the full model.
                with torch.no_grad():
                    outputs = self.model(current_ids, use_cache=False)
                    logits = outputs.logits

                token_id = torch.argmax(logits[0, -1, :]).item()

                if token_id == self.tokenizer.eos_token_id:
                    break

                token_text = self.tokenizer.decode([token_id])
                full_token = TokenInfo(
                    token_id=token_id,
                    token_text=token_text,
                    exit_head=None,
                    exit_layer=num_layers,
                    uncertainty=0.0,
                )
                validated_tokens.append(full_token)
                current_ids = torch.cat(
                    [current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
                )
                yield StreamEvent(
                    event_type="full_model",
                    tokens=list(validated_tokens),
                    drafted_tokens=[],
                    message=f"Full model: '{token_text}'",
                )

            # An EOS token can be appended via the reject path above; stop
            # generation if the last validated token is EOS.
            if (
                validated_tokens
                and validated_tokens[-1].token_id == self.tokenizer.eos_token_id
            ):
                break
366
+
367
+ def _generate_with_early_exit(
368
+ self,
369
+ input_ids: torch.Tensor,
370
+ max_tokens: int,
371
+ thresholds: Dict[int, float],
372
+ max_draft_length: int = 5,
373
+ ) -> List[TokenInfo]:
374
+ """
375
+ Speculative decoding with early exit heads.
376
+
377
+ GUARANTEES same output as full model by:
378
+ 1. DRAFT: Generate tokens using early exit heads (fast, partial compute)
379
+ 2. VERIFY: When full model needed, verify ALL drafted tokens
380
+ 3. ACCEPT: Keep matching tokens, take model's token at first mismatch
381
+ """
382
+ tokens = []
383
+ current_ids = input_ids.clone()
384
+ num_layers = self.adapter.get_num_layers()
385
+ head_layers = self.model_config.head_layer_indices
386
+
387
+ while len(tokens) < max_tokens:
388
+ # ============================================================
389
+ # DRAFT PHASE: Generate tokens using early exit heads
390
+ # ============================================================
391
+ drafted_tokens = [] # List of (token_id, exit_head, exit_layer, uncertainty)
392
+ draft_ids = current_ids.clone()
393
+
394
+ for _ in range(max_draft_length):
395
+ if len(tokens) + len(drafted_tokens) >= max_tokens:
396
+ break
397
+
398
+ # Try to draft a token using early exit
399
+ draft_result = self._draft_single_token(draft_ids, thresholds)
400
+
401
+ if draft_result is None:
402
+ # No head was confident enough - need to verify
403
+ break
404
+
405
+ token_id, exit_head, exit_layer, uncertainty = draft_result
406
+
407
+ if token_id == self.tokenizer.eos_token_id:
408
+ break
409
+
410
+ drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
411
+ draft_ids = torch.cat(
412
+ [draft_ids, torch.tensor([[token_id]], device=self.device)], dim=1
413
+ )
414
+
415
+ # ============================================================
416
+ # VERIFY PHASE: Run full model to verify drafted tokens
417
+ # ============================================================
418
+ if drafted_tokens:
419
+ # Run full model on current_ids + all drafted tokens
420
+ with torch.no_grad():
421
+ outputs = self.model(draft_ids, use_cache=False)
422
+ verify_logits = outputs.logits
423
+
424
+ # Verify each drafted token
425
+ start_pos = current_ids.shape[1] - 1 # Position before drafting
426
+
427
+ for i, (drafted_token, exit_head, exit_layer, uncertainty) in enumerate(
428
+ drafted_tokens
429
+ ):
430
+ verify_pos = start_pos + i
431
+ verified_token = torch.argmax(
432
+ verify_logits[0, verify_pos, :]
433
+ ).item()
434
+
435
+ if drafted_token == verified_token:
436
+ # Token matches - accept it with early exit info
437
+ token_text = self.tokenizer.decode([drafted_token])
438
+ tokens.append(
439
+ TokenInfo(
440
+ token_id=drafted_token,
441
+ token_text=token_text,
442
+ exit_head=exit_head,
443
+ exit_layer=exit_layer,
444
+ uncertainty=uncertainty,
445
+ )
446
+ )
447
+ current_ids = torch.cat(
448
+ [
449
+ current_ids,
450
+ torch.tensor([[drafted_token]], device=self.device),
451
+ ],
452
+ dim=1,
453
+ )
454
+ else:
455
+ # Mismatch - use full model's token
456
+ token_text = self.tokenizer.decode([verified_token])
457
+ tokens.append(
458
+ TokenInfo(
459
+ token_id=verified_token,
460
+ token_text=token_text,
461
+ exit_head=None, # Full model
462
+ exit_layer=num_layers,
463
+ uncertainty=0.0,
464
+ )
465
+ )
466
+ current_ids = torch.cat(
467
+ [
468
+ current_ids,
469
+ torch.tensor([[verified_token]], device=self.device),
470
+ ],
471
+ dim=1,
472
+ )
473
+ # Stop - discard remaining drafted tokens
474
+ break
475
+ else:
476
+ # No tokens drafted - generate one with full model
477
+ with torch.no_grad():
478
+ outputs = self.model(current_ids, use_cache=False)
479
+ logits = outputs.logits
480
+
481
+ token_id = torch.argmax(logits[0, -1, :]).item()
482
+
483
+ if token_id == self.tokenizer.eos_token_id:
484
+ break
485
+
486
+ token_text = self.tokenizer.decode([token_id])
487
+ tokens.append(
488
+ TokenInfo(
489
+ token_id=token_id,
490
+ token_text=token_text,
491
+ exit_head=None,
492
+ exit_layer=num_layers,
493
+ uncertainty=0.0,
494
+ )
495
+ )
496
+ current_ids = torch.cat(
497
+ [current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
498
+ )
499
+
500
+ # Check for EOS in accepted tokens
501
+ if tokens and tokens[-1].token_id == self.tokenizer.eos_token_id:
502
+ break
503
+
504
+ return tokens
505
+
506
+ def _draft_single_token(
507
+ self,
508
+ input_ids: torch.Tensor,
509
+ thresholds: Dict[int, float],
510
+ ) -> Optional[Tuple[int, int, int, float]]:
511
+ """
512
+ Try to draft a single token using early exit heads.
513
+ Returns (token_id, exit_head, exit_layer, uncertainty) if confident enough.
514
+ Returns None if no head is confident enough (need full model verification).
515
+ """
516
+ device = input_ids.device
517
+ seq_len = input_ids.shape[1]
518
+ head_layers = self.model_config.head_layer_indices
519
+
520
+ # Position IDs
521
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(
522
+ 0
523
+ )
524
+
525
+ # Get embeddings
526
+ hidden_states = self.adapter.get_embed_tokens(input_ids)
527
+
528
+ # Get rotary embeddings
529
+ position_embeddings = self.adapter.get_position_embeddings(
530
+ hidden_states, position_ids
531
+ )
532
+
533
+ # Sort heads by layer
534
+ sorted_heads = sorted(enumerate(head_layers), key=lambda x: x[1])
535
+
536
+ # Iterate through layers
537
+ with torch.no_grad():
538
+ for layer_idx, layer in enumerate(self.adapter.get_layers()):
539
+ hidden_states, _ = self.adapter.forward_layer(
540
+ layer=layer,
541
+ hidden_states=hidden_states,
542
+ position_ids=position_ids,
543
+ attention_mask=None,
544
+ past_key_value=None,
545
+ position_embeddings=position_embeddings,
546
+ use_cache=False,
547
+ )
548
+
549
+ # Check if this is a head checkpoint
550
+ for head_idx, head_layer in sorted_heads:
551
+ if layer_idx == head_layer:
552
+ # Run aux head on last position
553
+ aux_head = self.aux_heads[head_idx]
554
+ head_device = next(aux_head.parameters()).device
555
+ head_input = hidden_states[:, -1:, :].to(head_device)
556
+ head_logits = aux_head(head_input)
557
+ uncertainty = self.uncertainty_fn(
558
+ head_logits[:, -1, :], dim=-1
559
+ ).item()
560
+
561
+ # Check threshold - if confident, return drafted token
562
+ if (
563
+ head_idx in thresholds
564
+ and uncertainty < thresholds[head_idx]
565
+ ):
566
+ token_id = torch.argmax(head_logits[0, -1, :]).item()
567
+ return (token_id, head_idx, layer_idx, uncertainty)
568
+
569
+ # No head was confident enough - need full model verification
570
+ return None
571
+
572
+ def _generate_full_model(
573
+ self,
574
+ input_ids: torch.Tensor,
575
+ max_tokens: int,
576
+ ) -> List[TokenInfo]:
577
+ """Generate using full model (no early exit)."""
578
+ tokens = []
579
+ current_ids = input_ids.clone()
580
+ num_layers = self.adapter.get_num_layers()
581
+
582
+ for _ in range(max_tokens):
583
+ with torch.no_grad():
584
+ outputs = self.model(current_ids, use_cache=False)
585
+ logits = outputs.logits
586
+
587
+ token_id = torch.argmax(logits[0, -1, :]).item()
588
+
589
+ if token_id == self.tokenizer.eos_token_id:
590
+ break
591
+
592
+ token_text = self.tokenizer.decode([token_id])
593
+ tokens.append(
594
+ TokenInfo(
595
+ token_id=token_id,
596
+ token_text=token_text,
597
+ exit_head=None,
598
+ exit_layer=num_layers,
599
+ uncertainty=0.0,
600
+ )
601
+ )
602
+
603
+ current_ids = torch.cat(
604
+ [current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
605
+ )
606
+
607
+ return tokens
608
+
609
+ def generate_full_model_streaming(
610
+ self,
611
+ prompt: str,
612
+ max_tokens: int = 100,
613
+ use_chat_template: bool = True,
614
+ ):
615
+ """
616
+ Generate with full model in streaming mode - yields each token as generated.
617
+ """
618
+ # Format prompt
619
+ if (
620
+ use_chat_template
621
+ and hasattr(self.tokenizer, "chat_template")
622
+ and self.tokenizer.chat_template is not None
623
+ ):
624
+ try:
625
+ messages = [{"role": "user", "content": prompt}]
626
+ formatted = self.tokenizer.apply_chat_template(
627
+ messages, add_generation_prompt=True, tokenize=False
628
+ )
629
+ input_ids = self.tokenizer.encode(formatted, return_tensors="pt").to(
630
+ self.device
631
+ )
632
+ except Exception:
633
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
634
+ self.device
635
+ )
636
+ else:
637
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
638
+ self.device
639
+ )
640
+
641
+ tokens = []
642
+ current_ids = input_ids.clone()
643
+ num_layers = self.adapter.get_num_layers()
644
+
645
+ for i in range(max_tokens):
646
+ with torch.no_grad():
647
+ outputs = self.model(current_ids, use_cache=False)
648
+ logits = outputs.logits
649
+
650
+ token_id = torch.argmax(logits[0, -1, :]).item()
651
+
652
+ if token_id == self.tokenizer.eos_token_id:
653
+ break
654
+
655
+ token_text = self.tokenizer.decode([token_id])
656
+ token_info = TokenInfo(
657
+ token_id=token_id,
658
+ token_text=token_text,
659
+ exit_head=None,
660
+ exit_layer=num_layers,
661
+ uncertainty=0.0,
662
+ )
663
+ tokens.append(token_info)
664
+
665
+ current_ids = torch.cat(
666
+ [current_ids, torch.tensor([[token_id]], device=self.device)], dim=1
667
+ )
668
+
669
+ yield StreamEvent(
670
+ event_type="full_model",
671
+ tokens=list(tokens),
672
+ drafted_tokens=[],
673
+ message=f"Token {i + 1}: '{token_text}'",
674
+ )
675
+
676
+
677
def load_dssd_model(
    model_name: str,
    heads_path: str,
    config_path: str,
    calibration_path: Optional[str] = None,
    device: str = "auto",
) -> Tuple[DSSDecoder, AutoTokenizer]:
    """
    Load a DSSD model from HuggingFace Hub or local paths.

    Args:
        model_name: HuggingFace model name (e.g., "meta-llama/Meta-Llama-3-8B")
        heads_path: Path to aux_heads.pt
        config_path: Path to config.json
        calibration_path: Optional path to calibration.json
        device: Device map passed to `from_pretrained` (e.g. "auto")

    Returns:
        decoder: DSSDecoder ready for generation
        tokenizer: Tokenizer for the model
    """
    # Load config
    model_config = ModelConfig.from_json(config_path)

    # Load calibration if provided
    calibration = None
    if calibration_path:
        calibration = CalibrationResult.from_json(calibration_path)

    # Pick the compute dtype once. Guard the bf16 probe behind is_available():
    # torch.cuda.is_bf16_supported() can fail on CPU-only machines in some
    # torch versions, and the answer is only meaningful with CUDA anyway.
    compute_dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float32
    )

    # Quantization config
    quant_config = None
    if model_config.quantization == "4bit":
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    elif model_config.quantization == "8bit":
        quant_config = BitsAndBytesConfig(load_in_8bit=True)

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quant_config,
        torch_dtype=compute_dtype,
        device_map=device,
    )
    model.eval()

    # Load tokenizer (ensure a pad token exists)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Get adapter
    adapter = get_adapter(model)

    # Norm epsilon follows the base model's config when available.
    norm_eps = 1e-6
    if hasattr(model.config, "rms_norm_eps"):
        norm_eps = model.config.rms_norm_eps
    elif hasattr(model.config, "layer_norm_eps"):
        norm_eps = model.config.layer_norm_eps

    # Create aux heads WITHOUT deepcopy of model modules (avoids accelerate hooks).
    aux_heads = nn.ModuleList()
    for _ in range(model_config.num_heads):
        # Fresh RMSNorm without accelerate hooks (nn.RMSNorm needs torch >= 2.4)
        norm_layer = nn.RMSNorm(model_config.hidden_size, eps=norm_eps)
        aux_heads.append(
            AuxiliaryHead(
                model_config.hidden_size,
                model_config.vocab_size,
                norm_layer,
            )
        )

    # Load trained weights (this properly sets the norm weights too)
    state_dict = torch.load(heads_path, map_location="cpu")
    aux_heads.load_state_dict(state_dict)

    # Move to a single device matching the model's dtype - cuda:0 keeps the
    # heads together even when the base model is sharded by device_map.
    model_device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )
    model_dtype = next(model.parameters()).dtype
    aux_heads = aux_heads.to(device=model_device, dtype=model_dtype)
    aux_heads.eval()

    # Create decoder
    decoder = DSSDecoder(
        model=model,
        adapter=adapter,
        aux_heads=aux_heads,
        tokenizer=tokenizer,
        model_config=model_config,
        calibration=calibration,
        device=str(model_device),
    )

    return decoder, tokenizer
src/model_adapters.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Adapters for True Early Exit
2
+ # Abstract interface to stop layer computation early across architectures
3
+
4
+ from abc import ABC, abstractmethod
5
+ from typing import Tuple, Optional, List, Dict, Callable
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch import Tensor
9
+
10
+
11
class ModelAdapter(ABC):
    """
    Abstract interface over a causal LM's internals.

    Exposing embeddings, individual decoder layers, the final norm and the
    LM head lets the decoder stop layer computation early (true early exit)
    without depending on a specific architecture.
    """

    @abstractmethod
    def get_embed_tokens(self, input_ids: Tensor) -> Tensor:
        """Embed input token ids into hidden states."""

    @abstractmethod
    def get_layers(self) -> nn.ModuleList:
        """Return the model's decoder layers as a ModuleList."""

    @abstractmethod
    def get_num_layers(self) -> int:
        """Return the total number of decoder layers."""

    @abstractmethod
    def forward_layer(
        self,
        layer: nn.Module,
        hidden_states: Tensor,
        position_ids: Tensor,
        attention_mask: Optional[Tensor],
        past_key_value: Optional[Tuple],
        position_embeddings: Optional[Tuple],
        use_cache: bool = True,
    ) -> Tuple[Tensor, Optional[Tuple]]:
        """Run one decoder layer; return (hidden_states, optional KV cache)."""

    @abstractmethod
    def apply_final_norm(self, hidden_states: Tensor) -> Tensor:
        """Apply the final normalization that precedes the lm_head."""

    @abstractmethod
    def get_lm_head_output(self, hidden_states: Tensor) -> Tensor:
        """Project hidden states to vocabulary logits via the lm_head."""

    @abstractmethod
    def get_position_embeddings(
        self, hidden_states: Tensor, position_ids: Tensor
    ) -> Optional[Tuple[Tensor, Tensor]]:
        """Return rotary position embeddings (cos, sin), or None if unused."""
59
+
60
+
61
class LlamaStyleAdapter(ModelAdapter):
    """
    Adapter for Llama-style decoder-only architectures.

    Works for: Llama, Llama2, Llama3, Qwen, Qwen2, Qwen3, Mistral, Gemma

    These models share the same internal layout:
    - model.model.embed_tokens
    - model.model.layers (ModuleList of decoder layers)
    - model.model.norm (final RMSNorm)
    - model.lm_head
    - model.model.rotary_emb (RoPE embeddings)
    """

    def __init__(self, model):
        base = model.model
        self.model = model
        self._base = base
        self._layers = base.layers
        self._embed = base.embed_tokens
        self._norm = base.norm
        self._lm_head = model.lm_head
        # Rotary module is optional - some architectures compute RoPE elsewhere.
        self._rotary = getattr(base, "rotary_emb", None)
        self._num_layers = len(base.layers)

    def get_embed_tokens(self, input_ids: Tensor) -> Tensor:
        """Embed token ids via the model's embedding table."""
        return self._embed(input_ids)

    def get_layers(self) -> nn.ModuleList:
        """Return the decoder layer stack."""
        return self._layers

    def get_num_layers(self) -> int:
        """Return the number of decoder layers."""
        return self._num_layers

    def forward_layer(
        self,
        layer: nn.Module,
        hidden_states: Tensor,
        position_ids: Tensor,
        attention_mask: Optional[Tensor],
        past_key_value: Optional[Tuple],
        position_embeddings: Optional[Tuple],
        use_cache: bool = True,
    ) -> Tuple[Tensor, Optional[Tuple]]:
        """Run one decoder layer; return (hidden_states, kv_cache_or_None)."""
        outputs = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )
        if len(outputs) > 1:
            return outputs[0], outputs[1]
        return outputs[0], None

    def apply_final_norm(self, hidden_states: Tensor) -> Tensor:
        """Apply the model's final norm before the lm_head."""
        return self._norm(hidden_states)

    def get_lm_head_output(self, hidden_states: Tensor) -> Tensor:
        """Project hidden states to vocabulary logits."""
        return self._lm_head(hidden_states)

    def get_position_embeddings(
        self, hidden_states: Tensor, position_ids: Tensor
    ) -> Optional[Tuple[Tensor, Tensor]]:
        """Compute RoPE (cos, sin) if the model carries a rotary module."""
        if self._rotary is None:
            return None
        cos, sin = self._rotary(hidden_states, position_ids)
        return (cos, sin)
129
+
130
+
131
def get_adapter(model) -> ModelAdapter:
    """
    Factory: pick the adapter matching a model's architecture.

    Currently only Llama-style models (Llama, Qwen, Mistral, Gemma) are
    supported; GPT-2 style models are detected but not yet implemented.

    Raises:
        NotImplementedError: for GPT-2 style models (transformer.h).
        ValueError: for any other architecture.
    """

    def nested_attr_exists(root, first, second):
        # True when `root.first.second` resolves.
        return hasattr(root, first) and hasattr(getattr(root, first), second)

    # Llama-style: decoder stack lives at model.model.layers
    if nested_attr_exists(model, "model", "layers"):
        return LlamaStyleAdapter(model)

    # GPT-2 style: decoder stack lives at model.transformer.h
    if nested_attr_exists(model, "transformer", "h"):
        raise NotImplementedError("GPT-2 style models not yet supported")

    raise ValueError(f"Unsupported model architecture: {type(model)}")
src/model_config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model configuration and calibration dataclasses
2
+ # Re-exported from the main package for demo use
3
+
4
+ import json
5
+ from dataclasses import dataclass, field, asdict
6
+ from typing import Dict, List, Optional
7
+
8
+
9
@dataclass
class ModelConfig:
    """Configuration for a trained early exit model."""

    model_name: str  # base HF model name
    num_heads: int  # number of auxiliary exit heads
    head_layer_indices: List[int]  # decoder layer index hosting each head
    quantization: str  # "none", "4bit", "8bit"
    hidden_size: int
    vocab_size: int
    num_hidden_layers: int
    training_config: Optional[Dict] = None

    @classmethod
    def from_json(cls, path: str) -> "ModelConfig":
        """Load a config from a JSON file produced by `to_json`."""
        with open(path, "r") as f:
            raw = json.load(f)
        required = (
            "model_name",
            "num_heads",
            "head_layer_indices",
            "quantization",
            "hidden_size",
            "vocab_size",
            "num_hidden_layers",
        )
        # Required keys raise KeyError if absent; training_config is optional.
        kwargs = {name: raw[name] for name in required}
        kwargs["training_config"] = raw.get("training_config")
        return cls(**kwargs)

    def to_json(self, path: str) -> None:
        """Serialize this config as pretty-printed JSON at `path`."""
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)
40
+
41
+
42
@dataclass
class CalibrationResult:
    """Calibration results with thresholds per head per accuracy level."""

    model_config_path: str
    calibration_dataset: str
    calibration_samples: int
    uncertainty_metric: str  # "entropy" or "confidence"
    accuracy_levels: List[float]
    # thresholds["0.90"]["2"] -> uncertainty threshold for head 2 at 90% accuracy
    thresholds: Dict[str, Dict[str, float]] = field(default_factory=dict)
    statistics: Dict[str, Dict] = field(default_factory=dict)

    @classmethod
    def from_json(cls, path: str) -> "CalibrationResult":
        """Load calibration results from a JSON file produced by `to_json`."""
        with open(path, "r") as f:
            return cls(**json.load(f))

    def to_json(self, path: str) -> None:
        """Serialize this calibration result as pretty-printed JSON at `path`."""
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)

    def get_threshold(self, accuracy_level: float, head_idx: int) -> float:
        """Threshold for one head at a given accuracy level."""
        return self.thresholds[f"{accuracy_level:.2f}"][str(head_idx)]

    def get_thresholds_for_level(self, accuracy_level: float) -> Dict[int, float]:
        """All head thresholds for a given accuracy level, keyed by head index."""
        per_head = self.thresholds[f"{accuracy_level:.2f}"]
        return {int(head): value for head, value in per_head.items()}
+ return {int(k): v for k, v in self.thresholds[level_key].items()}