""" DSSD Demo - Dynamic Self-Speculative Decoding Visualization Showcases early exit inference with color-coded tokens showing which head generated each token. """ import gradio as gr from dataclasses import dataclass from pathlib import Path import time from huggingface_hub import hf_hub_download from src.inference import load_dssd_model, DSSDecoder, TokenInfo, StreamEvent, StreamingResult # Available models configuration AVAILABLE_MODELS = { "DSSD-Llama3-8B": { "model_name": "meta-llama/Meta-Llama-3-8B", "repo_id": "valcore/DSSD-Llama3-8B", "local_path": "../checkpoints/llama3-8b-4bit", }, "DSSD-Qwen3-0.6B": { "model_name": "Qwen/Qwen3-0.6B", "repo_id": "valcore/DSSD-Qwen3-0.6B", "local_path": "../checkpoints/qwen3-0.6b", }, } # Color palette for exit heads (colorblind-friendly) HEAD_COLORS = [ "#E63946", # Red - Head 0 (earliest) "#F4A261", # Orange - Head 1 "#2A9D8F", # Teal - Head 2 "#457B9D", # Blue - Head 3 "#8338EC", # Purple - Head 4 ] FULL_MODEL_COLOR = "#95D5B2" # Light green - Full model PENDING_TOKEN_BORDER = "var(--border-color-primary)" PENDING_TOKEN_TEXT = "var(--body-text-color)" DRAFTED_FALLBACK_COLOR = "var(--neutral-200)" # Global decoder cache _decoder_cache = {} def get_decoder(model_key: str) -> DSSDecoder: """Get or load a decoder for the specified model.""" global _decoder_cache if model_key in _decoder_cache: return _decoder_cache[model_key] model_info = AVAILABLE_MODELS[model_key] # Try local path first (for development) local_dir = Path(__file__).parent / model_info["local_path"] heads_path = local_dir / "aux_heads.pt" config_path = local_dir / "config.json" calibration_path = local_dir / "calibration.json" if heads_path.exists() and config_path.exists(): print(f"Loading model heads from local path: {local_dir}") # calibration_path is optional, so no need to check its existence here else: # Download from HF Hub repo_id = model_info["repo_id"] print(f"Downloading model heads from {repo_id}...") heads_path = hf_hub_download(repo_id=repo_id, filename="aux_heads.pt") config_path = hf_hub_download(repo_id=repo_id, filename="config.json") try: calibration_path = hf_hub_download( repo_id=repo_id, filename="calibration.json" ) except Exception: calibration_path = None # calibration.json is optional decoder, tokenizer = load_dssd_model( model_name=model_info["model_name"], heads_path=str(heads_path), config_path=str(config_path), calibration_path=str(calibration_path) if calibration_path else None, device="auto", ) _decoder_cache[model_key] = decoder return decoder def tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str: """Convert token info list to color-coded HTML.""" html_parts = [] for token in tokens: if token.exit_head is not None: color = HEAD_COLORS[token.exit_head % len(HEAD_COLORS)] layer = head_layers[token.exit_head] title = f"Head {token.exit_head} (Layer {layer})" else: color = FULL_MODEL_COLOR title = f"Full Model (Layer {token.exit_layer})" # Escape HTML special chars text = ( token.token_text.replace("&", "&") .replace("<", "<") .replace(">", ">") ) text = text.replace("\n", "
").replace(" ", " ") html_parts.append( f'{text}' ) # Wrap in container with word-wrap to prevent overflow tokens_html = "".join(html_parts) return f"""
{tokens_html}
""" def drafted_tokens_to_html(tokens: list[TokenInfo], head_layers: list[int]) -> str: """Convert drafted (pending) tokens to HTML with dashed border style.""" html_parts = [] for token in tokens: if token.exit_head is not None: color = HEAD_COLORS[token.exit_head % len(HEAD_COLORS)] layer = head_layers[token.exit_head] title = f"PENDING - Head {token.exit_head} (Layer {layer})" else: color = DRAFTED_FALLBACK_COLOR title = "PENDING - Unassigned" text = ( token.token_text.replace("&", "&") .replace("<", "<") .replace(">", ">") ) text = text.replace("\n", "
").replace(" ", " ") html_parts.append( f'{text}' ) return "".join(html_parts) def create_legend(head_layers: list[int]) -> str: """Create HTML legend for the color scheme.""" legend_items = [] for i, layer in enumerate(head_layers): color = HEAD_COLORS[i % len(HEAD_COLORS)] legend_items.append( f'Head {i} (Layer {layer})' ) legend_items.append( f'Full Model' ) return " ".join(legend_items) @dataclass class StatsPayload: generated_at: float speedup_text: str ee_time: str | None ee_tps: str | None ee_avg: str | None full_time: str | None full_tps: str | None full_avg: str | None show_ee: bool show_full: bool def build_stats_outputs( result_ee, result_full, use_early_exit: bool, compare_mode: bool, generated_at: float | None = None, ): speedup_text = "" if result_ee and result_full and result_full.tokens_per_second > 0: speedup = result_ee.tokens_per_second / result_full.tokens_per_second speedup_text = f"**Speedup:** {speedup:.2f}x" elif result_ee: speedup_text = "**Speedup:** N/A (full model not run)" elif result_full: speedup_text = "**Speedup:** N/A (early exit disabled)" if not speedup_text: speedup_text = "**Speedup:** N/A" ee_time = f"{result_ee.total_time:.2f}" if result_ee else None ee_tps = f"{result_ee.tokens_per_second:.2f}" if result_ee else None ee_avg = f"{result_ee.avg_exit_layer:.1f}" if result_ee else None full_time = f"{result_full.total_time:.2f}" if result_full else None full_tps = f"{result_full.tokens_per_second:.2f}" if result_full else None full_avg = f"{result_full.avg_exit_layer:.1f}" if result_full else None show_ee = compare_mode or use_early_exit show_full = compare_mode or not use_early_exit return StatsPayload( generated_at=generated_at if generated_at is not None else time.time(), speedup_text=speedup_text, ee_time=ee_time, ee_tps=ee_tps, ee_avg=ee_avg, full_time=full_time, full_tps=full_tps, full_avg=full_avg, show_ee=show_ee, show_full=show_full, ) def stats_payload_to_outputs(payload: StatsPayload): return ( payload.speedup_text, payload.ee_time, payload.ee_tps, payload.ee_avg, payload.full_time, payload.full_tps, payload.full_avg, gr.update(visible=payload.show_ee), gr.update(visible=payload.show_full), ) def generate( prompt: str, model_key: str, use_early_exit: bool, accuracy_level: float, max_tokens: int, compare_mode: bool, ): """Main generation function for Gradio interface with streaming.""" initial_stats_timestamp = time.time() try: decoder = get_decoder(model_key) except Exception as e: error_msg = f"

Error loading model: {e}

" status_msg = f"**Error loading model:** {e}" stats_payload = build_stats_outputs( None, None, use_early_exit, compare_mode, generated_at=initial_stats_timestamp, ) yield ( error_msg, "", status_msg, *stats_payload_to_outputs(stats_payload), "", ) return head_layers = decoder.model_config.head_layer_indices legend = create_legend(head_layers) # Get calibration accuracy levels if decoder.calibration: available_levels = decoder.calibration.accuracy_levels closest_level = min(available_levels, key=lambda x: abs(x - accuracy_level)) else: closest_level = accuracy_level if compare_mode: # Compare mode with streaming for early exit # First, stream the early exit generation final_ee_tokens = [] ee_streaming_result = None for event in decoder.generate_streaming( prompt=prompt, max_tokens=int(max_tokens), accuracy_level=closest_level, use_chat_template=True, ): # Handle "complete" event - extract result and break if event.event_type == "complete": ee_streaming_result = event.result final_ee_tokens = event.tokens break final_ee_tokens = event.tokens validated_html = "" if event.tokens: validated_html = tokens_to_html(event.tokens, head_layers) validated_html = validated_html.replace( '
', "", ).rstrip("
") drafted_html = "" if event.drafted_tokens: drafted_html = drafted_tokens_to_html(event.drafted_tokens, head_layers) combined_html = f"""
{validated_html}{drafted_html}
""" status = ( "**Early Exit:** {message} \n" "**Full Model:** Waiting..." ).format( message=event.message, ) stats_payload = build_stats_outputs( None, None, use_early_exit, compare_mode, generated_at=initial_stats_timestamp, ) yield ( combined_html, "

Waiting for early exit to complete...

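        # Note: compare mode runs the two passes sequentially rather than in
        # parallel; the full-model pass below only starts once early exit has
        # finished, which is why the comparison column shows the waiting
        # placeholder during the loop above.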
", status, *stats_payload_to_outputs(stats_payload), legend, ) # Now stream full model final_full_tokens = [] full_streaming_result = None for event in decoder.generate_full_model_streaming( prompt=prompt, max_tokens=int(max_tokens), use_chat_template=True, ): # Handle "complete" event - extract result and break if event.event_type == "complete": full_streaming_result = event.result final_full_tokens = event.tokens break final_full_tokens = event.tokens html_full = tokens_to_html(event.tokens, head_layers) status = ( "**Full Model:** {message}" ).format( message=event.message, ) stats_payload = build_stats_outputs( None, None, use_early_exit, compare_mode, generated_at=initial_stats_timestamp, ) yield ( tokens_to_html(final_ee_tokens, head_layers), html_full, status, *stats_payload_to_outputs(stats_payload), legend, ) # Final output with metrics from streaming results (no re-run needed) html_ee = tokens_to_html(final_ee_tokens, head_layers) html_full = tokens_to_html(final_full_tokens, head_layers) stats_payload = build_stats_outputs(ee_streaming_result, full_streaming_result, use_early_exit, compare_mode) yield ( html_ee, html_full, "", *stats_payload_to_outputs(stats_payload), legend, ) elif use_early_exit: # STREAMING mode for early exit - show draft/verify process streaming_result = None final_tokens = [] for event in decoder.generate_streaming( prompt=prompt, max_tokens=int(max_tokens), accuracy_level=closest_level, use_chat_template=True, ): # Handle "complete" event - extract result and break if event.event_type == "complete": streaming_result = event.result final_tokens = event.tokens break final_tokens = event.tokens # Build HTML showing validated + drafted tokens validated_html = "" if event.tokens: validated_html = tokens_to_html(event.tokens, head_layers) # Remove the outer div to combine with drafted validated_html = validated_html.replace( '
', "", ).rstrip("
") drafted_html = "" if event.drafted_tokens: drafted_html = drafted_tokens_to_html(event.drafted_tokens, head_layers) # Combine combined_html = f"""
{validated_html}{drafted_html}
""" # Status message status = ( "**Status:** {message}" ).format( message=event.message, ) stats_payload = build_stats_outputs( None, None, use_early_exit, compare_mode, generated_at=initial_stats_timestamp, ) yield ( combined_html, "", status, *stats_payload_to_outputs(stats_payload), legend, ) # Final output with metrics from streaming result (no re-run needed) html = tokens_to_html(final_tokens, head_layers) stats_payload = build_stats_outputs(streaming_result, None, use_early_exit, compare_mode) yield ( html, "", "", *stats_payload_to_outputs(stats_payload), legend, ) else: # Full model mode (streaming) streaming_result = None final_tokens = [] for event in decoder.generate_full_model_streaming( prompt=prompt, max_tokens=int(max_tokens), use_chat_template=True, ): # Handle "complete" event - extract result and break if event.event_type == "complete": streaming_result = event.result final_tokens = event.tokens break final_tokens = event.tokens html = tokens_to_html(event.tokens, head_layers) status = ( "**Full Model:** {message}" ).format( message=event.message, ) stats_payload = build_stats_outputs( None, None, use_early_exit, compare_mode, generated_at=initial_stats_timestamp, ) yield ( html, "", status, *stats_payload_to_outputs(stats_payload), legend, ) # Final output with metrics from streaming result (no re-run needed) html = tokens_to_html(final_tokens, head_layers) stats_payload = build_stats_outputs(None, streaming_result, use_early_exit, compare_mode) yield ( html, "", "", *stats_payload_to_outputs(stats_payload), legend, ) def build_demo(): """Build the Gradio demo interface.""" with gr.Blocks(title="DSSD Demo", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🚀 Dynamic Self-Speculative Decoding (DSSD) Demo This demo showcases **early exit inference** where tokens can be generated from intermediate layers when the model is confident, resulting in faster generation. **Colors indicate which layer generated each token** - earlier layers = faster! 
""") with gr.Row(): with gr.Column(scale=1): prompt = gr.Textbox( label="Prompt", placeholder="Enter your prompt here...", lines=3, value="What is machine learning in simple terms?", ) model_selector = gr.Dropdown( label="Model", choices=list(AVAILABLE_MODELS.keys()), value=list(AVAILABLE_MODELS.keys())[0], ) with gr.Row(): use_early_exit = gr.Checkbox(label="Enable Early Exit", value=True) compare_mode = gr.Checkbox(label="Compare Mode", value=False) accuracy_level = gr.Slider( label="Accuracy Level", minimum=0.6, maximum=0.99, step=0.05, value=0.75, info="Higher = more accurate but slower", ) max_tokens = gr.Slider( label="Max Tokens", minimum=10, maximum=200, step=10, value=50, ) generate_btn = gr.Button("Generate", variant="primary") # Legend (full width, above outputs) legend_html = gr.HTML() # Outputs section - dynamic based on compare mode with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Generated Output") output_ee = gr.HTML() with gr.Column(scale=1, visible=False) as compare_col: gr.Markdown("### Full Model (Comparison)") output_full = gr.HTML() status_html = gr.Markdown() with gr.Group(): gr.Markdown("### Speedup Recap") speedup_md = gr.Markdown() with gr.Row(): with gr.Column(visible=True) as ee_stats_col: gr.Markdown("#### Early Exit") ee_time = gr.Label(label="Time (s)") ee_tps = gr.Label(label="Tokens/sec") ee_avg = gr.Label(label="Avg Exit Layer") with gr.Column(visible=False) as full_stats_col: gr.Markdown("#### Full Model") full_time = gr.Label(label="Time (s)") full_tps = gr.Label(label="Tokens/sec") full_avg = gr.Label(label="Avg Exit Layer") def update_visibility(compare): return gr.update(visible=compare) compare_mode.change( fn=update_visibility, inputs=[compare_mode], outputs=[compare_col], ) generate_btn.click( fn=generate, inputs=[ prompt, model_selector, use_early_exit, accuracy_level, max_tokens, compare_mode, ], outputs=[ output_ee, output_full, status_html, speedup_md, ee_time, ee_tps, ee_avg, full_time, full_tps, full_avg, ee_stats_col, full_stats_col, legend_html, ], ) return demo if __name__ == "__main__": demo = build_demo() demo.launch(share=False, debug=True)