"""
Selection Entropy for LLMs — Interactive Demo
=============================================
HuggingFace Space · Gradio 4.x

Visualises token-level uncertainty using Selection Entropy, a novel
divergence metric adapted from computational neuroscience.

Original metric:
    Selection entropy: The information hidden within neuronal patterns
    Erik D. Fagerholm, Zalina Dezhina, Rosalyn J. Moran, Karl J. Friston,
    Federico Turkheimer, Robert Leech.
    Department of Neuroimaging, King’s College London, London, United Kingdom
    Wellcome Centre for Human Neuroimaging, London, United Kingdom

Adaptation: token probability distributions as discrete neural states
"""

import sys
import os

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

from model_inference import generate_with_logits, get_available_models
from selection_entropy import normalise_entropies

# ── Colour palette ──────────────────────────────────────────────────────────
PALETTE = {
    "bg": "#0d0f14",
    "surface": "#161923",
    "border": "#2a2d3a",
    "accent": "#7c6af7",   # purple — SE
    "accent2": "#3ec6c6",  # teal — Shannon
    "low": "#1a3a2a",
    "mid": "#4a3a10",
    "high": "#5a1a1a",
    "text": "#e8e9f0",
    "muted": "#8a8d9e",
}

PLOTLY_TEMPLATE = dict(
    layout=dict(
        paper_bgcolor=PALETTE["bg"],
        plot_bgcolor=PALETTE["surface"],
        font=dict(color=PALETTE["text"], family="'IBM Plex Mono', monospace"),
        xaxis=dict(gridcolor=PALETTE["border"], linecolor=PALETTE["border"]),
        yaxis=dict(gridcolor=PALETTE["border"], linecolor=PALETTE["border"]),
        margin=dict(l=40, r=20, t=40, b=40),
    )
)

# Theme keys applied to every figure: PLOTLY_TEMPLATE's layout minus the
# axis styling, which each builder configures for itself.
_BASE_LAYOUT = {
    key: value
    for key, value in PLOTLY_TEMPLATE["layout"].items()
    if key not in ("xaxis", "yaxis")
}

# Upper bound of the "Max new tokens" slider. The token-index slider is sized
# from it so that every generated position can be inspected.
MAX_NEW_TOKENS_LIMIT = 80

PROMPT_PLACEHOLDER_MD = "*(enter a prompt above)*"


def _apply_base_layout(fig, **layout):
    """Apply the shared dark theme plus figure-specific layout options to *fig*."""
    fig.update_layout(**_BASE_LAYOUT, **layout)
    return fig


# ── Visualisation builders ────────────────────────────────────────────────
def build_heatmap_figure(token_results, metric="selection"):
    """Horizontal token heatmap coloured by entropy intensity.

    Tokens are wrapped into rows of 12 cells. The first row is drawn at the
    top, so the grid reads like text (left-to-right, top-to-bottom).

    Args:
        token_results: Sequence of per-token result objects exposing
            ``token``, ``selection_entropy`` and ``shannon_entropy``.
        metric: ``"selection"`` colours by Selection Entropy; any other value
            colours by Shannon entropy.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    row_size = 12
    tokens = [t.token for t in token_results]
    # list() so slicing + padding works whether the helper returns a list or
    # a NumPy array (ndarray slices have no .append / list concatenation).
    norm_vals = list(normalise_entropies(token_results, metric=metric))
    raw_vals = (
        [t.selection_entropy for t in token_results]
        if metric == "selection"
        else [t.shannon_entropy for t in token_results]
    )

    rows, row_vals, row_raw = [], [], []
    for start in range(0, len(tokens), row_size):
        stop = start + row_size
        # Pad the last row so the heatmap grid stays rectangular.
        pad = row_size - len(tokens[start:stop])
        rows.append(tokens[start:stop] + [""] * pad)
        row_vals.append(norm_vals[start:stop] + [None] * pad)
        row_raw.append(raw_vals[start:stop] + [None] * pad)

    fig = go.Figure(
        go.Heatmap(
            z=row_vals,
            text=rows,
            customdata=row_raw,
            texttemplate="%{text}",
            colorscale=[
                [0.0, "#1a2a4a"],
                [0.3, "#2a4a6a"],
                [0.5, "#7c6af7"],
                [0.75, "#c44a6a"],
                [1.0, "#ff4444"],
            ],
            showscale=True,
            colorbar=dict(
                title=dict(text="Entropy (norm)", side="right"),
                tickfont=dict(color=PALETTE["muted"], size=10),
                thickness=12,
            ),
            hovertemplate=(
                "Token: '%{text}'<br>"
                "Entropy: %{customdata:.4f} nats<br>"
                "Normalised: %{z:.3f}"
            ),
            # Padding cells hold None; don't show meaningless hover labels.
            hoverongaps=False,
            xgap=3,
            ygap=3,
        )
    )
    metric_label = "Selection Entropy" if metric == "selection" else "Shannon Entropy"
    _apply_base_layout(
        fig,
        height=max(120, len(rows) * 72),
        xaxis=dict(showticklabels=False, showgrid=False),
        # Plotly puts row 0 at the bottom by default; reverse the axis so the
        # first generated tokens appear in the top row.
        yaxis=dict(showticklabels=False, showgrid=False, autorange="reversed"),
        title=dict(
            text=f"{metric_label} — Token Heatmap",
            font=dict(size=13, color=PALETTE["muted"]),
            x=0.0,
        ),
    )
    return fig


def build_comparison_figure(token_results):
    """Line chart: Selection Entropy vs Shannon entropy per token position."""
    se_vals = [t.selection_entropy for t in token_results]
    sh_vals = [t.shannon_entropy for t in token_results]
    tokens = [repr(t.token) for t in token_results]
    positions = list(range(len(token_results)))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=positions,
        y=se_vals,
        mode="lines+markers",
        name="Selection Entropy",
        line=dict(color=PALETTE["accent"], width=2.5),
        marker=dict(size=6, color=PALETTE["accent"]),
        text=tokens,
        hovertemplate="%{text}<br>SE: %{y:.4f}",
    ))
    fig.add_trace(go.Scatter(
        x=positions,
        y=sh_vals,
        mode="lines+markers",
        name="Shannon Entropy",
        line=dict(color=PALETTE["accent2"], width=2.5, dash="dot"),
        marker=dict(size=6, color=PALETTE["accent2"]),
        text=tokens,
        hovertemplate="%{text}<br>H: %{y:.4f}",
    ))
    _apply_base_layout(
        fig,
        height=280,
        xaxis=dict(
            title="Token position",
            tickvals=positions,
            ticktext=tokens,
            tickangle=-45,
            tickfont=dict(size=9),
        ),
        yaxis=dict(title="Entropy (nats)"),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            bgcolor="rgba(0,0,0,0)",
            font=dict(color=PALETTE["text"], size=11),
        ),
        title=dict(
            text="Selection Entropy vs Shannon Entropy per Token",
            font=dict(size=13, color=PALETTE["muted"]),
            x=0.0,
        ),
    )
    return fig


def build_top_k_figure(token_results, token_idx: int):
    """Bar chart of the recorded top-k alternatives for one generated token.

    Args:
        token_results: Sequence of per-token result objects. Each exposes
            ``top_k_alternatives`` as a list of ``{"token", "prob"}`` dicts.
        token_idx: Position in *token_results* to inspect.

    Returns:
        A ``plotly.graph_objects.Figure``. The sampled token's bar uses the
        accent colour.
    """
    te = token_results[token_idx]
    alts = te.top_k_alternatives
    labels = [a["token"] for a in alts]
    probs = [a["prob"] for a in alts]
    colors = [PALETTE["accent"] if lbl == te.token else PALETTE["accent2"] for lbl in labels]
    # Headroom above the tallest bar for the "outside" text labels. Fall back
    # to a unit axis when no alternatives were recorded, since max([]) raises.
    y_max = max(probs, default=1.0) * 1.25

    fig = go.Figure(go.Bar(
        x=[repr(lbl) for lbl in labels],
        y=probs,
        marker_color=colors,
        hovertemplate="%{x}<br>P = %{y:.4f}",
        text=[f"{p:.3f}" for p in probs],
        textposition="outside",
        textfont=dict(color=PALETTE["muted"], size=10),
    ))
    _apply_base_layout(
        fig,
        height=240,
        xaxis=dict(title="Token", tickfont=dict(size=11)),
        yaxis=dict(title="Probability", range=[0, y_max]),
        title=dict(
            text=f"Top-{len(alts)} at position {token_idx}: chosen = {repr(te.token)} "
                 f"| SE={te.selection_entropy:.4f} H={te.shannon_entropy:.4f}",
            font=dict(size=12, color=PALETTE["muted"]),
            x=0.0,
        ),
    )
    return fig


def build_stats_table(token_results):
    """Return a Markdown summary of SE / Shannon statistics for the sequence.

    ΔE is defined as ``selection_entropy - shannon_entropy`` per token. The
    token with the largest |ΔE| is reported.
    """
    if not token_results:
        return "*(no tokens were generated)*"

    se = [t.selection_entropy for t in token_results]
    sh = [t.shannon_entropy for t in token_results]

    most_uncertain_se = token_results[int(np.argmax(se))]
    most_certain_se = token_results[int(np.argmin(se))]

    delta = [s - h for s, h in zip(se, sh)]
    max_delta_idx = int(np.argmax(np.abs(delta)))
    max_delta_tok = token_results[max_delta_idx]

    lines = [
        "### 📊 Summary Statistics\n",
        "| Metric | Mean | Max | Min |",
        "|--------|------|-----|-----|",
        f"| **Selection Entropy** | {np.mean(se):.4f} | {np.max(se):.4f} | {np.min(se):.4f} |",
        f"| **Shannon Entropy** | {np.mean(sh):.4f} | {np.max(sh):.4f} | {np.min(sh):.4f} |",
        "",
        f"🔴 **Most uncertain token (SE):** `{repr(most_uncertain_se.token)}` "
        f"— SE={most_uncertain_se.selection_entropy:.4f}",
        f"🟢 **Most certain token (SE):** `{repr(most_certain_se.token)}` "
        f"— SE={most_certain_se.selection_entropy:.4f}",
        "",
        f"⚡ **Largest SE–Shannon divergence:** `{repr(max_delta_tok.token)}` "
        f"— ΔE={delta[max_delta_idx]:+.4f}",
        "",
        "> *Positive ΔE means Selection Entropy detects MORE competition "
        "than Shannon entropy sees.<br>This is the core insight of the metric.*",
    ]
    return "\n".join(lines)


# ── Main inference + render ───────────────────────────────────────────────
def _placeholder_outputs(message):
    """Return three blank figures plus *message* for the stats panel."""
    return go.Figure(), go.Figure(), go.Figure(), message


def _render_results(token_results, metric_choice, token_idx):
    """Build (heatmap, comparison, top-k, stats) for already-generated tokens.

    *token_idx* is clamped into range, so a slider position beyond the end of
    a short generation shows the last token.
    """
    if not token_results:
        return _placeholder_outputs("*(no tokens were generated)*")
    metric = "selection" if "Selection" in (metric_choice or "") else "shannon"
    tok_idx = min(max(int(token_idx or 0), 0), len(token_results) - 1)
    return (
        build_heatmap_figure(token_results, metric=metric),
        build_comparison_figure(token_results),
        build_top_k_figure(token_results, tok_idx),
        build_stats_table(token_results),
    )


def _analyse(prompt, model_name, max_tokens, temperature, alpha, metric_choice, token_idx):
    """Generate and render.

    Returns ``(outputs, token_results)``, where *outputs* is the 6-tuple
    produced by ``run_analysis``. *token_results* is ``[]`` when nothing was
    generated (empty prompt or an error).
    """
    if not (prompt or "").strip():
        return (*_placeholder_outputs(PROMPT_PLACEHOLDER_MD), "", 0), []

    try:
        generated_text, token_results = generate_with_logits(
            prompt=prompt,
            model_name=model_name,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            alpha=float(alpha),
        )
    except Exception as e:
        # UI boundary: surface any model/inference failure in the app rather
        # than crashing the event handler.
        err_fig = go.Figure()
        err_fig.add_annotation(text=f"Error: {e}", showarrow=False,
                               font=dict(color="red", size=14))
        return (err_fig, err_fig, err_fig, str(e), "", 0), []

    views = _render_results(token_results, metric_choice, token_idx)
    return (*views, generated_text, len(token_results)), token_results


def run_analysis(
    prompt: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    alpha: float,
    metric_choice: str,
    token_idx: int,
):
    """Generate a continuation for *prompt* and build every view.

    Returns:
        ``(heatmap_fig, compare_fig, topk_fig, stats_md, generated_text,
        n_tokens)``. For an empty prompt, the figures are blank and the stats
        panel shows a hint. On inference errors, the figures show the error
        message.
    """
    outputs, _ = _analyse(
        prompt, model_name, max_tokens, temperature, alpha, metric_choice, token_idx
    )
    return outputs


def _run_and_cache(prompt, model_name, max_tokens, temperature, alpha, metric_choice, token_idx):
    """Event handler for the Analyse button: ``run_analysis`` outputs + cached tokens."""
    outputs, token_results = _analyse(
        prompt, model_name, max_tokens, temperature, alpha, metric_choice, token_idx
    )
    return (*outputs, token_results)


def _rerender(token_results, metric_choice, token_idx):
    """Re-render views from cached tokens without re-running the model.

    This keeps the heatmap / top-k views consistent with the displayed text
    (sampling would otherwise produce a new sequence on every slider move).
    """
    if not token_results:
        return _placeholder_outputs(PROMPT_PLACEHOLDER_MD)
    return _render_results(token_results, metric_choice, token_idx)


# ── CSS ───────────────────────────────────────────────────────────────────
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=Space+Grotesk:wght@300;400;600&display=swap');

body, .gradio-container {
    background: #0d0f14 !important;
    font-family: 'Space Grotesk', sans-serif;
}

/* Only colour text inside dark markdown/prose areas, not inputs */
.gradio-container .markdown,
.gradio-container .prose,
gradio-app .markdown-body {
    color: #e8e9f0 !important;
    background: transparent !important;
}

gradio-app .markdown-body p,
gradio-app .markdown-body li,
gradio-app .markdown-body strong,
gradio-app .markdown-body em,
gradio-app .markdown-body h1,
gradio-app .markdown-body h2,
gradio-app .markdown-body h3 {
    color: #e8e9f0 !important;
}

gradio-app .markdown-body blockquote {
    border-left: 3px solid #7c6af7 !important;
    padding-left: 12px;
    color: #8a8d9e !important;
}

gradio-app .markdown-body code {
    background: #1e2130 !important;
    color: #3ec6c6 !important;
    padding: 2px 6px;
    border-radius: 4px;
}

.header-block {
    border-left: 3px solid #7c6af7;
    padding: 18px 24px;
    background: #161923;
    border-radius: 0 8px 8px 0;
    margin-bottom: 8px;
}

.header-block h1 {
    font-family: 'IBM Plex Mono', monospace;
    font-size: 1.5rem;
    color: #7c6af7;
    margin: 0 0 4px 0;
    letter-spacing: -0.02em;
}

.header-block p {
    color: #8a8d9e;
    font-size: 0.85rem;
    margin: 0;
    line-height: 1.5;
}

.section-label {
    font-family: 'IBM Plex Mono', monospace;
    font-size: 0.7rem;
    color: #7c6af7;
    letter-spacing: 0.12em;
    text-transform: uppercase;
    margin-bottom: 6px;
}

.metric-note {
    font-size: 0.78rem;
    color: #8a8d9e !important;
    font-style: italic;
    line-height: 1.5;
    padding: 10px 14px;
    border-left: 2px solid #2a2d3a;
    margin-top: 4px;
}

gradio-app button.primary {
    background: #7c6af7 !important;
    border: none !important;
    font-family: 'IBM Plex Mono', monospace;
    font-size: 0.88rem;
    letter-spacing: 0.04em;
}

gradio-app button.primary:hover {
    background: #9a8aff !important;
}
"""

# ── Layout ────────────────────────────────────────────────────────────────
DESCRIPTION_MD = """
Selection Entropy is a divergence-based metric inspired by computational
neuroscience, adapted here to characterise token-level uncertainty in LLMs.

Unlike Shannon entropy, which mainly reflects the overall spread of a
probability distribution, Selection Entropy is designed to be more sensitive
to local competition among likely alternatives. When two or more tokens have
similar probabilities, it marks that step as more uncertain.

In practice, this makes it useful for visualising points where the model’s
choice was not simply diffuse, but actively contested by nearby alternatives.

> 🔴 **Hot tokens** = model is uncertain, many close alternatives
>
> 🔵 **Cool tokens** = model is confident, one dominant choice
"""

with gr.Blocks(css=CUSTOM_CSS, title="Selection Entropy · LLM Uncertainty") as demo:

    gr.HTML("""
<div class="header-block">
  <h1>⟨ Selection Entropy · LLM Uncertainty ⟩</h1>
  <p>Token-level uncertainty visualisation<br>
  SE detects near-competition in token distributions — beyond what Shannon entropy reveals</p>
</div>
""")

    gr.Markdown(DESCRIPTION_MD)

    # Per-session cache of the last generation's token results, so changing
    # the metric or token index re-renders instead of re-sampling the model.
    results_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="The theory of general relativity states that...",
                lines=3,
                value="The key difference between Bayesian and frequentist statistics is",
            )
            run_btn = gr.Button("▶ Analyse", variant="primary")

            with gr.Accordion("⚙ Generation settings", open=False):
                model_dd = gr.Dropdown(
                    choices=get_available_models(),
                    value="gpt2",
                    label="Model",
                )
                max_tokens_sl = gr.Slider(
                    10, MAX_NEW_TOKENS_LIMIT, value=30, step=5, label="Max new tokens"
                )
                temperature_sl = gr.Slider(0.1, 2.0, value=0.9, step=0.1, label="Temperature")
                alpha_sl = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05,
                    label="α — SE competition sensitivity",
                )
                gr.Markdown(
                    "α = 0 → SE ≈ Shannon entropy · α = 1 → maximum competition weighting",
                    elem_classes=["metric-note"],
                )

        with gr.Column(scale=1):
            metric_radio = gr.Radio(
                choices=["Selection Entropy", "Shannon Entropy"],
                value="Selection Entropy",
                label="Heatmap colouring",
            )
            # Range covers every position the "Max new tokens" slider allows.
            token_idx_sl = gr.Slider(
                0, MAX_NEW_TOKENS_LIMIT - 1, value=0, step=1,
                label="Token index for Top-5 detail view",
            )
            n_tokens_out = gr.Number(label="Tokens generated", value=0, interactive=False)
            gr.Markdown("""
**How to read the heatmap**

Each cell = one generated token.
Colour intensity = normalised entropy.
High SE at a token → model "hesitated" there.
Hover over cells for exact values.
""", elem_classes=["metric-note"])

    gr.HTML('<div class="section-label">Generated text</div>')
    generated_out = gr.Textbox(label="", interactive=False, elem_classes=["generated-box"])

    gr.HTML('<div class="section-label">Token entropy heatmap</div>')
    heatmap_plot = gr.Plot(label="")

    with gr.Row():
        with gr.Column():
            gr.HTML('<div class="section-label">SE vs Shannon — per token</div>')
            compare_plot = gr.Plot(label="")
        with gr.Column():
            gr.HTML('<div class="section-label">Top-5 alternatives at selected token</div>')
            topk_plot = gr.Plot(label="")

    gr.HTML('<div class="section-label">Summary</div>')
    stats_out = gr.Markdown("")

    # ── Example prompts ────────────────────────────────────────────────
    gr.Examples(
        examples=[
            ["The key difference between Bayesian and frequentist statistics is"],
            ["Once upon a time in a kingdom far away, there lived a"],
            ["The capital of France is Paris. The capital of Germany is"],
            ["To train a neural network, you need to"],
            ["I think therefore I"],
        ],
        inputs=prompt_input,
        label="Example prompts",
    )

    # ── Event wiring ───────────────────────────────────────────────────
    inputs = [
        prompt_input, model_dd, max_tokens_sl, temperature_sl,
        alpha_sl, metric_radio, token_idx_sl,
    ]
    outputs = [
        heatmap_plot, compare_plot, topk_plot,
        stats_out, generated_out, n_tokens_out,
    ]
    view_inputs = [results_state, metric_radio, token_idx_sl]
    view_outputs = [heatmap_plot, compare_plot, topk_plot, stats_out]

    # Only the button runs the model; view controls re-render the cached run.
    run_btn.click(fn=_run_and_cache, inputs=inputs, outputs=outputs + [results_state])
    metric_radio.change(fn=_rerender, inputs=view_inputs, outputs=view_outputs)
    token_idx_sl.change(fn=_rerender, inputs=view_inputs, outputs=view_outputs)


if __name__ == "__main__":
    demo.launch()