File size: 20,972 Bytes
72b2f6d
 
 
 
 
 
33efa44
72b2f6d
33efa44
72b2f6d
 
33efa44
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
 
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
33efa44
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
33efa44
72b2f6d
45e00e6
72b2f6d
 
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
33efa44
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
72b2f6d
 
 
 
 
 
33efa44
 
 
 
 
 
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
 
 
 
 
72b2f6d
33efa44
 
 
 
 
 
 
72b2f6d
 
33efa44
72b2f6d
33efa44
72b2f6d
 
 
 
 
33efa44
 
72b2f6d
 
 
 
 
33efa44
 
 
 
 
 
 
72b2f6d
33efa44
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
33efa44
72b2f6d
 
 
33efa44
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
33efa44
 
 
72b2f6d
 
 
 
 
 
33efa44
 
 
 
 
 
 
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
 
 
 
72b2f6d
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
33efa44
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
33efa44
 
 
72b2f6d
 
 
 
 
33efa44
 
 
 
 
 
 
72b2f6d
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72b2f6d
 
 
 
 
 
 
33efa44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
"""
DSSD Demo - Dynamic Self-Speculative Decoding Visualization
Showcases early exit inference with color-coded tokens showing which head generated each token.
"""

import gradio as gr
from dataclasses import dataclass
from pathlib import Path
import time
from huggingface_hub import hf_hub_download

from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent, StreamingResult

# Available models configuration.
# Keys are the names offered in the UI model dropdown. Each entry holds:
#   model_name - base model identifier handed to load_dssd_model
#   repo_id    - Hugging Face Hub repo the head/config/calibration files are downloaded from
#   local_path - checkpoint directory, resolved relative to this file and tried before the Hub
AVAILABLE_MODELS = {
    "DSSD-Llama3-8B": {
        "model_name": "meta-llama/Meta-Llama-3-8B",
        "repo_id": "valcore/DSSD-Llama3-8B",
        "local_path": "../checkpoints/llama3-8b-4bit",
    },
    "DSSD-Qwen3-0.6B": {
        "model_name": "Qwen/Qwen3-0.6B",
        "repo_id": "valcore/DSSD-Qwen3-0.6B",
        "local_path": "../checkpoints/qwen3-0.6b",
    },
}

# Color palette for exit heads (colorblind-friendly).
# Indexed by exit-head number modulo the palette length, so heads beyond
# the fifth reuse colors from the start of the list.
HEAD_COLORS = [
    "#E63946",  # Red - Head 0 (earliest)
    "#F4A261",  # Orange - Head 1
    "#2A9D8F",  # Teal - Head 2
    "#457B9D",  # Blue - Head 3
    "#8338EC",  # Purple - Head 4
]
# Background for tokens that have no exit head (produced by the full model).
FULL_MODEL_COLOR = "#95D5B2"  # Light green - Full model

# Styling for drafted (not yet verified) tokens; the values are CSS variables
# so the border and text colors are resolved by the page's stylesheet.
PENDING_TOKEN_BORDER = "var(--border-color-primary)"
PENDING_TOKEN_TEXT = "var(--body-text-color)"
# Background for a drafted token that has no exit head assigned.
DRAFTED_FALLBACK_COLOR = "var(--neutral-200)"

# Global decoder cache: model key -> loaded decoder, filled by get_decoder().
_decoder_cache = {}


def get_decoder(model_key: str) -> DSSDecoder:
    """Return the decoder for ``model_key``, loading and caching it on first use.

    The auxiliary head weights and config are read from the model's local
    checkpoint directory when both ``aux_heads.pt`` and ``config.json`` exist
    there; otherwise they are downloaded from the model's Hugging Face Hub
    repo. ``calibration.json`` is optional in both cases: when it is not
    available, ``None`` is passed to ``load_dssd_model``.

    Args:
        model_key: A key of ``AVAILABLE_MODELS``.

    Returns:
        The cached or freshly loaded ``DSSDecoder``.

    Raises:
        KeyError: If ``model_key`` is not a key of ``AVAILABLE_MODELS``.
    """
    if model_key in _decoder_cache:
        return _decoder_cache[model_key]

    model_info = AVAILABLE_MODELS[model_key]

    # Try local path first (for development)
    local_dir = Path(__file__).parent / model_info["local_path"]
    heads_path = local_dir / "aux_heads.pt"
    config_path = local_dir / "config.json"

    if heads_path.exists() and config_path.exists():
        print(f"Loading model heads from local path: {local_dir}")
        # calibration.json is optional: hand it over only when it is really
        # there, mirroring the Hub branch which passes None when it is missing.
        calibration_path = local_dir / "calibration.json"
        if not calibration_path.exists():
            calibration_path = None
    else:
        # Download from HF Hub
        repo_id = model_info["repo_id"]
        print(f"Downloading model heads from {repo_id}...")
        heads_path = hf_hub_download(repo_id=repo_id, filename="aux_heads.pt")
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        try:
            calibration_path = hf_hub_download(
                repo_id=repo_id, filename="calibration.json"
            )
        except Exception as exc:
            # Deliberate best-effort: calibration.json is optional, so any
            # download failure just means "run without calibration".
            print(f"No calibration.json available from {repo_id} ({exc}); continuing without it.")
            calibration_path = None

    # load_dssd_model also returns a tokenizer, which this module does not use.
    decoder, _ = load_dssd_model(
        model_name=model_info["model_name"],
        heads_path=str(heads_path),
        config_path=str(config_path),
        calibration_path=str(calibration_path) if calibration_path else None,
        device="auto",
    )

    _decoder_cache[model_key] = decoder
    return decoder


def tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str:
    """Convert token info list to color-coded HTML.

    Each token becomes a ``<span>`` whose background identifies the exit head
    that produced it (``HEAD_COLORS``, wrapping around when there are more
    heads than colors) or ``FULL_MODEL_COLOR`` when ``exit_head`` is ``None``.
    The span's ``title`` tooltip names the head and its layer.

    Args:
        tokens: Tokens to render; ``token_text``, ``exit_head`` and
            ``exit_layer`` are read from each one.
        head_layers: Layer index of every exit head, indexed by ``exit_head``.

    Returns:
        A single ``<div>`` wrapping all token spans, styled so long output
        wraps instead of overflowing.
    """
    html_parts = []

    for token in tokens:
        if token.exit_head is not None:
            color = HEAD_COLORS[token.exit_head % len(HEAD_COLORS)]
            layer = head_layers[token.exit_head]
            title = f"Head {token.exit_head} (Layer {layer})"
        else:
            color = FULL_MODEL_COLOR
            title = f"Full Model (Layer {token.exit_layer})"

        # Escape HTML special chars. "&" must be escaped first, otherwise the
        # ampersands introduced by "&lt;" / "&gt;" would be escaped again.
        text = (
            token.token_text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
        )
        # Keep newlines and runs of spaces visible in the rendered HTML.
        text = text.replace("\n", "<br>").replace(" ", "&nbsp;")

        html_parts.append(
            f'<span style="background-color: {color}; padding: 2px 4px; '
            f'border-radius: 3px; margin: 1px; display: inline-block; color: #111827;" title="{title}">{text}</span>'
        )

    # Wrap in container with word-wrap to prevent overflow.
    # NOTE: generate() strips this exact wrapper string when it merges the
    # validated tokens with drafted ones, so keep the two in sync.
    tokens_html = "".join(html_parts)
    return f"""<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">{tokens_html}</div>"""


def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str:
    """Render drafted (not yet verified) tokens as dashed-border HTML spans.

    Args:
        tokens: Drafted tokens; ``token_text`` and ``exit_head`` are read.
        head_layers: Layer index of every exit head, indexed by ``exit_head``.

    Returns:
        The concatenated ``<span>`` elements, without any wrapping container.
    """
    # Applied in this order: "&" first so later entities are not re-escaped,
    # then the markup characters, then whitespace that HTML would collapse.
    substitutions = (
        ("&", "&amp;"),
        ("<", "&lt;"),
        (">", "&gt;"),
        ("\n", "<br>"),
        (" ", "&nbsp;"),
    )

    spans = []
    for tok in tokens:
        head = tok.exit_head
        if head is None:
            background = DRAFTED_FALLBACK_COLOR
            tooltip = "PENDING - Unassigned"
        else:
            background = HEAD_COLORS[head % len(HEAD_COLORS)]
            tooltip = f"PENDING - Head {head} (Layer {head_layers[head]})"

        escaped = tok.token_text
        for raw, entity in substitutions:
            escaped = escaped.replace(raw, entity)

        style = (
            f"background-color: {background}; padding: 2px 4px; "
            f"border-radius: 3px; margin: 1px; display: inline-block; "
            f"border: 2px dashed {PENDING_TOKEN_BORDER}; color: {PENDING_TOKEN_TEXT}; "
            f"opacity: 0.75;"
        )
        spans.append(f'<span style="{style}" title="{tooltip}">{escaped}</span>')

    return "".join(spans)


def create_legend(head_layers: list[int]) -> str:
    """Build the HTML color legend: one swatch per exit head plus the full model.

    Args:
        head_layers: Layer index of every exit head, in head order.

    Returns:
        The swatch ``<span>`` elements joined by single spaces.
    """
    head_swatches = [
        f'<span style="background-color: {HEAD_COLORS[idx % len(HEAD_COLORS)]}; padding: 4px 8px; '
        f'border-radius: 4px; margin-right: 8px;">Head {idx} (Layer {layer_no})</span>'
        for idx, layer_no in enumerate(head_layers)
    ]
    full_model_swatch = (
        f'<span style="background-color: {FULL_MODEL_COLOR}; padding: 4px 8px; '
        f'border-radius: 4px;">Full Model</span>'
    )
    return " ".join([*head_swatches, full_model_swatch])



@dataclass
class StatsPayload:
    """Values for the "Speedup Recap" panel, produced by build_stats_outputs.

    stats_payload_to_outputs turns an instance into the tuple of Gradio
    outputs; the field order here is also the positional constructor order.
    """

    # Timestamp for this payload; build_stats_outputs falls back to time.time().
    generated_at: float
    # Markdown line describing the speedup (or why it is not available).
    speedup_text: str
    # Early-exit metrics, already formatted as strings; None when there is no early-exit result.
    ee_time: str | None
    ee_tps: str | None
    ee_avg: str | None
    # Full-model metrics, already formatted as strings; None when there is no full-model result.
    full_time: str | None
    full_tps: str | None
    full_avg: str | None
    # Visibility flags applied to the early-exit and full-model stats columns.
    show_ee: bool
    show_full: bool


def build_stats_outputs(
    result_ee,
    result_full,
    use_early_exit: bool,
    compare_mode: bool,
    generated_at: float | None = None,
):
    """Assemble the ``StatsPayload`` shown in the "Speedup Recap" panel.

    Args:
        result_ee: Early-exit generation result, or ``None`` if not available.
            ``total_time``, ``tokens_per_second`` and ``avg_exit_layer`` are read.
        result_full: Full-model generation result, or ``None`` if not
            available. The same three attributes are read.
        use_early_exit: Whether early exit is enabled in the UI.
        compare_mode: Whether both runs are displayed side by side.
        generated_at: Timestamp to record; defaults to ``time.time()``.

    Returns:
        A ``StatsPayload`` with formatted metrics (``None`` for a missing
        result), the speedup text, and the visibility of each stats column.
    """
    if result_ee and result_full:
        if result_full.tokens_per_second > 0:
            speedup = result_ee.tokens_per_second / result_full.tokens_per_second
            speedup_text = f"**Speedup:** {speedup:.2f}x"
        else:
            # Both runs finished but the ratio is undefined; do not claim
            # that the full model was not run.
            speedup_text = "**Speedup:** N/A (full model throughput is zero)"
    elif result_ee:
        speedup_text = "**Speedup:** N/A (full model not run)"
    elif result_full:
        speedup_text = "**Speedup:** N/A (early exit disabled)"
    else:
        speedup_text = "**Speedup:** N/A"

    ee_time = f"{result_ee.total_time:.2f}" if result_ee else None
    ee_tps = f"{result_ee.tokens_per_second:.2f}" if result_ee else None
    ee_avg = f"{result_ee.avg_exit_layer:.1f}" if result_ee else None

    full_time = f"{result_full.total_time:.2f}" if result_full else None
    full_tps = f"{result_full.tokens_per_second:.2f}" if result_full else None
    full_avg = f"{result_full.avg_exit_layer:.1f}" if result_full else None

    # Compare mode shows both columns; otherwise only the column of the
    # mode that is actually running.
    show_ee = compare_mode or use_early_exit
    show_full = compare_mode or not use_early_exit

    return StatsPayload(
        generated_at=generated_at if generated_at is not None else time.time(),
        speedup_text=speedup_text,
        ee_time=ee_time,
        ee_tps=ee_tps,
        ee_avg=ee_avg,
        full_time=full_time,
        full_tps=full_tps,
        full_avg=full_avg,
        show_ee=show_ee,
        show_full=show_full,
    )


def stats_payload_to_outputs(payload: StatsPayload):
    """Flatten ``payload`` into the tuple of values for the stats components.

    The order is: speedup text, the three early-exit metrics, the three
    full-model metrics, then visibility updates for the early-exit and
    full-model stats columns.
    """
    metric_values = (
        payload.speedup_text,
        payload.ee_time,
        payload.ee_tps,
        payload.ee_avg,
        payload.full_time,
        payload.full_tps,
        payload.full_avg,
    )
    column_visibility = (
        gr.update(visible=payload.show_ee),
        gr.update(visible=payload.show_full),
    )
    return metric_values + column_visibility



def _pending_stats_outputs(use_early_exit: bool, compare_mode: bool, generated_at: float):
    """Return the stats-panel outputs used while no final metrics exist yet."""
    return stats_payload_to_outputs(
        build_stats_outputs(
            None,
            None,
            use_early_exit,
            compare_mode,
            generated_at=generated_at,
        )
    )


def _render_draft_verify_html(event, head_layers: list[int]) -> str:
    """Render an event's validated tokens followed by its drafted tokens in one wrapper div."""
    wrapper_open = '<div style="word-wrap: break-word; overflow-wrap: break-word; max-width: 100%; line-height: 1.8;">'
    wrapper_close = "</div>"

    validated_html = ""
    if event.tokens:
        # Remove the outer div so the drafted spans can share the same
        # container. removesuffix drops exactly "</div>"; rstrip("</div>")
        # would treat the argument as a character set and also eat the ">"
        # of the last "</span>".
        validated_html = (
            tokens_to_html(event.tokens, head_layers)
            .removeprefix(wrapper_open)
            .removesuffix(wrapper_close)
        )

    drafted_html = ""
    if event.drafted_tokens:
        drafted_html = drafted_tokens_to_html(event.drafted_tokens, head_layers)

    return f"{wrapper_open}{validated_html}{drafted_html}{wrapper_close}"


def generate(
    prompt: str,
    model_key: str,
    use_early_exit: bool,
    accuracy_level: float,
    max_tokens: int,
    compare_mode: bool,
):
    """Main generation function for Gradio interface with streaming.

    Runs one of three modes: compare mode (early exit first, then the full
    model), early exit only, or full model only. Intermediate UI states are
    yielded for every streaming event, followed by one final state that
    carries the measured metrics.

    Args:
        prompt: User prompt.
        model_key: Key of ``AVAILABLE_MODELS`` selecting the decoder.
        use_early_exit: Run early-exit decoding (ignored in compare mode).
        accuracy_level: Requested accuracy; when the decoder has calibration
            data, the closest calibrated level is used instead.
        max_tokens: Maximum number of tokens to generate (coerced to ``int``).
        compare_mode: Run early exit and full model one after the other.

    Yields:
        13-tuples matching the click handler's outputs: primary output HTML,
        comparison output HTML, status markdown, the nine stats-panel values
        from ``stats_payload_to_outputs``, and the legend HTML. If the model
        cannot be loaded, a single error tuple is yielded instead.
    """
    initial_stats_timestamp = time.time()

    def pending_stats():
        # Stats are unknown until a run completes; rebuilt for every yield.
        return _pending_stats_outputs(use_early_exit, compare_mode, initial_stats_timestamp)

    try:
        decoder = get_decoder(model_key)
    except Exception as e:
        # UI boundary: report any load failure in the page instead of raising.
        yield (
            f"<p style='color: red;'>Error loading model: {e}</p>",
            "",
            f"**Error loading model:** {e}",
            *pending_stats(),
            "",
        )
        return

    head_layers = decoder.model_config.head_layer_indices
    legend = create_legend(head_layers)
    token_budget = int(max_tokens)

    # Snap the requested accuracy to the nearest calibrated level, if any.
    if decoder.calibration:
        available_levels = decoder.calibration.accuracy_levels
        closest_level = min(available_levels, key=lambda x: abs(x - accuracy_level))
    else:
        closest_level = accuracy_level

    if compare_mode:
        # Compare mode: stream the early exit generation first.
        final_ee_tokens = []
        ee_streaming_result = None

        for event in decoder.generate_streaming(
            prompt=prompt,
            max_tokens=token_budget,
            accuracy_level=closest_level,
            use_chat_template=True,
        ):
            # The "complete" event carries the final result; nothing to render.
            if event.event_type == "complete":
                ee_streaming_result = event.result
                final_ee_tokens = event.tokens
                break

            final_ee_tokens = event.tokens
            yield (
                _render_draft_verify_html(event, head_layers),
                "<p style='color: var(--body-text-color-subdued);'>Waiting for early exit to complete...</p>",
                f"**Early Exit:** {event.message}  \n**Full Model:** Waiting...",
                *pending_stats(),
                legend,
            )

        # The early-exit output no longer changes, so render it once and
        # reuse it for every full-model update below.
        html_ee = tokens_to_html(final_ee_tokens, head_layers)

        # Now stream full model
        final_full_tokens = []
        full_streaming_result = None

        for event in decoder.generate_full_model_streaming(
            prompt=prompt,
            max_tokens=token_budget,
            use_chat_template=True,
        ):
            if event.event_type == "complete":
                full_streaming_result = event.result
                final_full_tokens = event.tokens
                break

            final_full_tokens = event.tokens
            yield (
                html_ee,
                tokens_to_html(event.tokens, head_layers),
                f"**Full Model:** {event.message}",
                *pending_stats(),
                legend,
            )

        # Final output with metrics from streaming results (no re-run needed)
        stats_payload = build_stats_outputs(
            ee_streaming_result, full_streaming_result, use_early_exit, compare_mode
        )
        yield (
            html_ee,
            tokens_to_html(final_full_tokens, head_layers),
            "",
            *stats_payload_to_outputs(stats_payload),
            legend,
        )

    elif use_early_exit:
        # STREAMING mode for early exit - show draft/verify process
        streaming_result = None
        final_tokens = []

        for event in decoder.generate_streaming(
            prompt=prompt,
            max_tokens=token_budget,
            accuracy_level=closest_level,
            use_chat_template=True,
        ):
            if event.event_type == "complete":
                streaming_result = event.result
                final_tokens = event.tokens
                break

            final_tokens = event.tokens
            yield (
                _render_draft_verify_html(event, head_layers),
                "",
                f"**Status:** {event.message}",
                *pending_stats(),
                legend,
            )

        # Final output with metrics from streaming result (no re-run needed)
        stats_payload = build_stats_outputs(streaming_result, None, use_early_exit, compare_mode)
        yield (
            tokens_to_html(final_tokens, head_layers),
            "",
            "",
            *stats_payload_to_outputs(stats_payload),
            legend,
        )

    else:
        # Full model mode (streaming)
        streaming_result = None
        final_tokens = []

        for event in decoder.generate_full_model_streaming(
            prompt=prompt,
            max_tokens=token_budget,
            use_chat_template=True,
        ):
            if event.event_type == "complete":
                streaming_result = event.result
                final_tokens = event.tokens
                break

            final_tokens = event.tokens
            yield (
                tokens_to_html(event.tokens, head_layers),
                "",
                f"**Full Model:** {event.message}",
                *pending_stats(),
                legend,
            )

        # Final output with metrics from streaming result (no re-run needed)
        stats_payload = build_stats_outputs(None, streaming_result, use_early_exit, compare_mode)
        yield (
            tokens_to_html(final_tokens, head_layers),
            "",
            "",
            *stats_payload_to_outputs(stats_payload),
            legend,
        )


def build_demo():
    """Build the Gradio demo interface.

    Lays out the input controls, the legend, the two output panes, the status
    line and the "Speedup Recap" stats panel, then wires the Compare Mode
    checkbox and the Generate button to their handlers.

    Returns:
        The assembled ``gr.Blocks`` app; it is not launched here.
    """
    with gr.Blocks(title="DSSD Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🚀 Dynamic Self-Speculative Decoding (DSSD) Demo
        
        This demo showcases **early exit inference** where tokens can be generated from intermediate 
        layers when the model is confident, resulting in faster generation.
        
        **Colors indicate which layer generated each token** - earlier layers = faster!
        """)

        # Input controls
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="What is machine learning in simple terms?",
                )

                # Defaults to the first entry of AVAILABLE_MODELS.
                model_selector = gr.Dropdown(
                    label="Model",
                    choices=list(AVAILABLE_MODELS.keys()),
                    value=list(AVAILABLE_MODELS.keys())[0],
                )

                with gr.Row():
                    use_early_exit = gr.Checkbox(label="Enable Early Exit", value=True)
                    compare_mode = gr.Checkbox(label="Compare Mode", value=False)

                accuracy_level = gr.Slider(
                    label="Accuracy Level",
                    minimum=0.6,
                    maximum=0.99,
                    step=0.05,
                    value=0.75,
                    info="Higher = more accurate but slower",
                )

                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=10,
                    maximum=200,
                    step=10,
                    value=50,
                )

                generate_btn = gr.Button("Generate", variant="primary")

        # Legend (full width, above outputs)
        legend_html = gr.HTML()

        # Outputs section - dynamic based on compare mode
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Generated Output")
                output_ee = gr.HTML()

            # Hidden until the Compare Mode checkbox is ticked (see update_visibility).
            with gr.Column(scale=1, visible=False) as compare_col:
                gr.Markdown("### Full Model (Comparison)")
                output_full = gr.HTML()

        status_html = gr.Markdown()

        # Stats panel; generate() sets the visibility of both columns on every update.
        with gr.Group():
            gr.Markdown("### Speedup Recap")
            speedup_md = gr.Markdown()
            with gr.Row():
                with gr.Column(visible=True) as ee_stats_col:
                    gr.Markdown("#### Early Exit")
                    ee_time = gr.Label(label="Time (s)")
                    ee_tps = gr.Label(label="Tokens/sec")
                    ee_avg = gr.Label(label="Avg Exit Layer")
                with gr.Column(visible=False) as full_stats_col:
                    gr.Markdown("#### Full Model")
                    full_time = gr.Label(label="Time (s)")
                    full_tps = gr.Label(label="Tokens/sec")
                    full_avg = gr.Label(label="Avg Exit Layer")

        def update_visibility(compare):
            """Show the comparison output column only while Compare Mode is on."""
            return gr.update(visible=compare)

        compare_mode.change(
            fn=update_visibility,
            inputs=[compare_mode],
            outputs=[compare_col],
        )

        # The outputs list must stay in the same order as the tuples yielded
        # by generate(): two output panes, status, nine stats values, legend.
        generate_btn.click(
            fn=generate,
            inputs=[
                prompt,
                model_selector,
                use_early_exit,
                accuracy_level,
                max_tokens,
                compare_mode,
            ],
            outputs=[
                output_ee,
                output_full,
                status_html,
                speedup_md,
                ee_time,
                ee_tps,
                ee_avg,
                full_time,
                full_tps,
                full_avg,
                ee_stats_col,
                full_stats_col,
                legend_html,
            ],
        )

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI and launch it with share=False and debug=True.
    demo = build_demo()
    demo.launch(share=False, debug=True)