Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from transformers import BertTokenizerFast, BertForMaskedLM | |
# Hugging Face checkpoint used for every visualisation in this app.
MODEL_NAME = "bert-base-uncased"

# Loaded once at import time so each request reuses the same objects.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
# ``eval()`` returns the module itself, so loading and switching to
# inference mode (dropout disabled) can be chained.
model = BertForMaskedLM.from_pretrained(MODEL_NAME).eval()

# Number of encoder blocks (12 for bert-base-uncased).
NUM_LAYERS = model.config.num_hidden_layers
def analyze(text: str, layer_idx: int):
    """Run BERT on *text* and build the visualisations for one encoder layer.

    Args:
        text: User input. Masked-token predictions are only produced when it
            contains ``[MASK]``; if several are present, only the first one
            is analysed. Input longer than the model's position-embedding
            limit is truncated.
        layer_idx: 1-based encoder layer to visualise. Values outside
            ``1..NUM_LAYERS`` are clamped into that range.

    Returns:
        A 5-tuple in the order expected by the Gradio outputs:
        ``(tokens_html, pred_md, fig_prob, att_fig, info)`` where

        * ``tokens_html`` is an HTML string with one chip per token, shaded
          by the relative norm of its hidden state at the selected layer;
        * ``pred_md`` is a Markdown table of the top-10 ``[MASK]``
          predictions at the selected layer (or a hint when no mask exists);
        * ``fig_prob`` is a Plotly line chart of the top-prediction
          probability per layer, or ``None`` without a ``[MASK]``;
        * ``att_fig`` is a Plotly heatmap of the selected layer's first
          attention head (min-max rescaled);
        * ``info`` is a Markdown summary.

        For blank input the two figures are ``None`` and the strings are
        placeholder messages.
    """
    # Function-scope imports: neither name is imported at the top of the file.
    from html import escape

    import plotly.graph_objs as go

    if not text or not text.strip():
        return (
            "<span style='color:#888'>Type some text above…</span>",
            "No [MASK] token, so I can’t show predictions.",
            None,
            None,
            "Please type some text containing the [MASK] token.",
        )

    # Clamp so an out-of-range value can never silently pair the embedding
    # output (index 0) with the last layer's attention (index -1).
    layer = min(max(int(layer_idx), 1), NUM_LAYERS)

    # Truncate: sequences beyond the position-embedding table would make the
    # forward pass fail.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )
    input_ids = inputs["input_ids"]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Position of the first [MASK] token, if any.
    mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    mask_idx = int(mask_positions[0].item()) if len(mask_positions) > 0 else None

    layer_probs = []        # probability of the best token, one entry per layer
    layer_best_tokens = []  # best token string, one entry per layer
    mask_probs = []         # full vocabulary distribution at [MASK], per layer

    # Inference only: without no_grad the outputs carry autograd history and
    # ``.numpy()`` below raises "Can't call numpy() on Tensor that requires grad".
    with torch.no_grad():
        outputs = model.bert(
            **inputs,
            output_hidden_states=True,
            output_attentions=True,
            return_dict=True,
        )
        hidden_states = outputs.hidden_states  # (embeddings, layer 1, ..., layer N)
        attentions = outputs.attentions        # (layer 1, ..., layer N)

        if mask_idx is not None:
            for layer_hidden in hidden_states[1:]:
                # The MLM head is applied to the [MASK] vector only; the
                # other positions' logits are never used.
                logits = model.cls(layer_hidden[:, mask_idx, :])[0]
                probs = torch.softmax(logits, dim=-1)
                mask_probs.append(probs)
                best_prob, best_id = probs.max(dim=-1)
                layer_probs.append(float(best_prob))
                layer_best_tokens.append(
                    tokenizer.convert_ids_to_tokens(int(best_id))
                )

        # ---- Data for the selected layer ----
        norms = torch.norm(hidden_states[layer][0], dim=-1).cpu().numpy()
        att = attentions[layer - 1][0, 0].cpu().numpy()  # first head

    # Token "confidence" = hidden-vector norm relative to the largest norm.
    peak = norms.max()
    conf = norms / peak if peak > 0 else norms
    # Min-max rescale so the heatmap uses the full colour range.
    att = (att - att.min()) / (att.max() - att.min() + 1e-9)

    # ---- 1) Token visualisation (HTML with confidence-based background) ----
    token_spans = []
    for i, (tok, c) in enumerate(zip(tokens, conf)):
        alpha = 0.15 + 0.7 * float(c)
        border = "#22c55e" if i == mask_idx else "rgba(148,163,184,0.4)"
        # Tokens come from user text, so escape them before embedding in HTML.
        token_spans.append(
            f"<span style='padding:2px 4px; margin:1px; border-radius:4px; "
            f"border:1px solid {border}; background:rgba(34,197,94,{alpha:.3f}); "
            f"font-size:12px; display:inline-block;'>{escape(tok)}</span>"
        )
    tokens_html = "<div style='line-height:1.8;'>" + " ".join(token_spans) + "</div>"

    # ---- 2) Top-k predictions for [MASK] at this layer ----
    if mask_idx is not None:
        topk = torch.topk(mask_probs[layer - 1], k=10)
        top_tokens = tokenizer.convert_ids_to_tokens(topk.indices.tolist())
        rows = ["| Rank | Token | Prob |", "|------|-------|------|"]
        for rank, (tok, p) in enumerate(zip(top_tokens, topk.values.tolist()), start=1):
            # A literal pipe would otherwise end the table cell early.
            cell = tok.replace("|", "\\|")
            rows.append(f"| {rank} | `{cell}` | {p:.3f} |")
        pred_md = "\n".join(rows)
    else:
        pred_md = (
            "There is **no `[MASK]` token** in your input.\n\n"
            "To see layer-wise predictions, include `[MASK]` somewhere in the text.\n"
            "Example: `The capital of France is [MASK].`"
        )

    # ---- 3) Probability curve across layers ----
    if mask_idx is not None:
        fig_prob = go.Figure()
        fig_prob.add_trace(go.Scatter(
            x=list(range(1, NUM_LAYERS + 1)),
            y=layer_probs,
            mode="lines+markers",
            name="P(top token at [MASK])",
        ))
        fig_prob.update_layout(
            xaxis_title="Layer",
            yaxis_title="Probability of best prediction",
            template="plotly_dark",
            height=320,
            margin=dict(l=40, r=20, t=40, b=40),
        )
    else:
        fig_prob = None

    # ---- 4) Attention heatmap for selected layer ----
    # Prefix each label with its position: identical strings on a categorical
    # axis share one slot, so repeated tokens would overwrite each other.
    axis_labels = [f"{i}: {tok}" for i, tok in enumerate(tokens)]
    att_fig = go.Figure(
        data=go.Heatmap(
            z=att,
            x=axis_labels,
            y=axis_labels,
            colorbar=dict(title="Attention"),
        )
    )
    att_fig.update_layout(
        xaxis_title="Key tokens",
        yaxis_title="Query tokens",
        template="plotly_dark",
        height=420,
        margin=dict(l=80, r=60, t=40, b=120),
    )

    # ---- 5) Info text ----
    info = (
        f"### Layer {layer} summary\n"
        f"- Hidden-state norms are used as a proxy for **token confidence** (bright = higher norm).\n"
        f"- The heatmap shows **self-attention weights** for layer {layer}, head 1.\n"
    )
    if mask_idx is not None:
        best_current = layer_best_tokens[layer - 1]
        info += (
            f"- At this layer, the top prediction for `[MASK]` is `{best_current}`.\n"
            f"- The line chart shows how the model’s confidence in its *current* best prediction "
            f"evolves across layers.\n"
        )
    else:
        info += (
            "- No `[MASK]` token detected, so layer-wise predictions are disabled. "
            "Add `[MASK]` to explore how different layers refine the guess.\n"
        )

    return tokens_html, pred_md, fig_prob, att_fig, info
# ------------- Gradio UI ------------- #

# Markdown rendered at the top of the page.
DESCRIPTION = """
# 🔍 Transformer Layer Playground (BERT)
Explore how a real transformer (**bert-base-uncased**) processes text *layer by layer*.
- Type some text and choose a **layer** (1–12).
- If you include `[MASK]`, you’ll see **layer-wise predictions** at that position.
- Visualisations:
- Token chips, where brightness ≈ **hidden state norm** (a rough proxy for confidence/activation).
- A **line chart** of how the probability of the top prediction at `[MASK]` changes across layers.
- A full **attention heatmap** for the selected layer and head 1.
"""

# One-click sample prompts; each contains exactly one [MASK].
EXAMPLE_TEXTS = [
    "The capital of France is [MASK].",
    "Transformers are very [MASK] models.",
    "I love eating [MASK] with tomato sauce.",
    "The [MASK] barked loudly at the stranger.",
]
# Page layout: controls in the left column, visualisations in the right.
with gr.Blocks() as demo:
    # Optional styling (safe even on older Gradio versions)
    gr.HTML("""
<style>
#tokens-html {
font-family: "JetBrains Mono", monospace;
}
</style>
""")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # ---- Inputs ----
        with gr.Column(scale=3):
            text_in = gr.Textbox(
                label="Input text (use [MASK] to see predictions)",
                value="The capital of France is [MASK].",
                lines=3,
                placeholder="Type a sentence; include [MASK] somewhere.",
            )
            layer_slider = gr.Slider(
                minimum=1,
                maximum=NUM_LAYERS,
                value=4,
                step=1,
                label=f"Layer to visualise (1–{NUM_LAYERS})",
            )
            gr.Examples(
                examples=EXAMPLE_TEXTS,
                inputs=text_in,
                label="Example prompts",
            )
            run_btn = gr.Button("Run", variant="primary")

        # ---- Outputs ----
        with gr.Column(scale=5):
            tokens_html = gr.HTML(label="Token representations", elem_id="tokens-html")
            with gr.Row():
                pred_out = gr.Markdown(label="Layer-wise predictions at [MASK]")
                prob_plot = gr.Plot(label="Probability across layers")
            att_plot = gr.Plot(label="Self-attention heatmap (selected layer, head 1)")
            info_box = gr.Markdown(label="Explanation")

    # Every trigger runs the same callback with the same wiring: the Run
    # button, plus text and slider changes so the view refreshes without
    # clicking Run.
    analyze_inputs = [text_in, layer_slider]
    analyze_outputs = [tokens_html, pred_out, prob_plot, att_plot, info_box]
    for register in (run_btn.click, text_in.change, layer_slider.change):
        register(
            analyze,
            inputs=analyze_inputs,
            outputs=analyze_outputs,
        )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()