"""
Streamlit app for interactive Semantic, Temperature and Fisher Scope visualizations.
"""

import base64
import gc
from html import escape as html_escape
import os
import sys

import numpy as np
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for Streamlit
import matplotlib as mpl
import matplotlib.pyplot as plt
import streamlit as st
import torch
from matplotlib.colors import LogNorm as Log_Norm
from matplotlib.colors import Normalize as Norm
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# Add current directory to path for JCBScope_utils
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)
import JCBScope_utils
import JacobianScopes

DESIGN_DIR = os.path.join(_APP_DIR, "design")
if not os.path.exists(DESIGN_DIR):
    DESIGN_DIR = os.path.join(os.path.dirname(_APP_DIR), "design")

# Device configuration: use CPU to match notebook and avoid device_map complexity
device = torch.device("cpu")

# Shared inline CSS for one token box in the HTML attribution view.
# "pre-wrap" keeps the leading space that most tokens carry visible.
_TOKEN_BOX_STYLE = (
    "display: inline-block; padding: 2px 4px; margin: 1px; "
    "border-radius: 3px; white-space: pre-wrap;"
)


@st.cache_data
def _load_svg(path: str) -> str | None:
    """Load SVG file content; returns None if not found."""
    if not os.path.exists(path):
        return None
    with open(path, encoding="utf-8") as f:
        return f.read()


def _render_svg_html(svg_content: str, max_width: int = 140) -> str:
    """Return an ``<img>`` tag that embeds the SVG as a base64 data URI.

    Args:
        svg_content: Raw SVG markup.
        max_width: Maximum rendered width in CSS pixels.

    Returns:
        An HTML snippet meant for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    b64 = base64.b64encode(svg_content.encode("utf-8")).decode("utf-8")
    return (
        f'<img src="data:image/svg+xml;base64,{b64}" alt="" '
        f'style="max-width: {max_width}px; width: 100%; height: auto;">'
    )


@st.cache_resource
def load_model(model_name: str = "meta-llama/Llama-3.2-1B"):
    """Load and cache the tokenizer and model.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable (``None`` when unset). The model is moved to the module-level
    ``device``.

    Returns:
        ``(tokenizer, model)``.
    """
    token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
    model = model.to(device)
    return tokenizer, model


def check_target_single_token(tokenizer, target_str: str) -> tuple[bool, list[int] | None]:
    """
    Check that target is exactly one token. Returns (ok, ids) or (False, None).
    Uses target_str as-is (no strip) so e.g. " truthful" stays one token.
    """
    ids = tokenizer(target_str, add_special_tokens=False)["input_ids"]
    if len(ids) != 1:
        return False, None
    return True, ids


def _is_comma_delimited_numbers(s: str) -> bool:
    """Check that ``s`` is a non-empty comma-separated list of integers.

    Empty fields (for example a trailing comma) are ignored. Each remaining
    field must be ASCII digits with at most one leading minus sign, so that
    everything accepted here can later be parsed with ``int()``.
    """
    if not isinstance(s, str):
        return False
    parts = [piece.strip() for piece in s.split(",") if piece.strip()]
    if not parts:
        return False
    # isascii() rejects characters such as "²" for which str.isdigit() is
    # true but int() raises; removeprefix() strips one minus only, so "--5"
    # is rejected as well.
    return all(p.isascii() and p.removeprefix("-").isdigit() for p in parts)


def _sort_key_for_token(s: str):
    """Numeric tokens by value; others by lexicographic order. Total order."""
    try:
        value = float(s)
    except ValueError:
        return (1, s)
    if value != value:
        # "nan" parses as a float but compares false with everything, which
        # would break the ordering; sort it with the non-numeric tokens.
        return (1, s)
    return (0, value)


def compute_attribution(
    string: str,
    mode: str,
    tokenizer,
    model,
    target_str: str | None = None,
    front_pad: int = 2,
    input_type: str = "text",
):
    """
    Compute attribution using Temperature, Semantic, or Fisher Scope.

    Args:
        string: Input passage, or comma-delimited integers.
        mode: "Temperature", "Semantic" or "Fisher".
        tokenizer: Tokenizer returned by ``load_model``.
        model: Causal LM returned by ``load_model``.
        target_str: Target token text; required for "Semantic" and must
            tokenize to exactly one id.
        front_pad: Number of BOS (or CLS) tokens prepended to the input.
            No padding is added when the tokenizer defines neither.
        input_type: "text" or "comma_delimited". For comma_delimited,
            attribution skips delimiter tokens (every second position).

    Returns:
        A dict with the decoded tokens, the attributed positions
        (``grad_idx``), the scores and logits returned by the scope routine,
        the position the prediction is read from (``loss_position``), the
        number of padding tokens actually prepended (``front_pad``), and,
        depending on mode/input type, ``target_str`` and ``int_list``.

    Raises:
        ValueError: On an unknown mode, a missing or multi-token target,
            invalid comma-delimited input, or input that yields no tokens.
    """
    if mode not in ("Temperature", "Semantic", "Fisher"):
        raise ValueError(f"Invalid mode '{mode}'. Must be 'Temperature', 'Semantic', or 'Fisher'.")

    target_id = None
    if mode == "Semantic":
        if not target_str or not target_str.strip():
            raise ValueError("Semantic Scope requires a target token.")
        # NOTE: target_id is the one-element id list from the tokenizer and is
        # forwarded unchanged to JacobianScopes.semantic_scope_scores.
        ok, target_id = check_target_single_token(tokenizer, target_str)
        if not ok:
            raise ValueError("Target must be a single token.")

    if input_type == "comma_delimited" and not _is_comma_delimited_numbers(string):
        raise ValueError("Input is not valid comma-delimited numbers.")

    back_pad = 0
    bos_token_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id
    eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id

    # Padding can only be added when a BOS-like token exists. Track how many
    # pad tokens were really prepended so the attributed positions below
    # start at the first real token in either case.
    n_front = front_pad if bos_token_id is not None else 0

    input_ids_list = []
    if bos_token_id is not None:
        input_ids_list += [bos_token_id] * n_front
    input_ids_list += tokenizer(string, add_special_tokens=False)["input_ids"]
    if eos_token_id is not None:
        input_ids_list += [eos_token_id] * back_pad

    if len(input_ids_list) <= n_front:
        raise ValueError("Input produced no tokens to attribute.")

    embedding_layer = model.get_input_embeddings()
    target_device = embedding_layer.weight.device
    input_ids = torch.tensor([input_ids_list], dtype=torch.long).to(target_device)

    decoded_tokens = [
        tokenizer.decode(tok.item(), skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for tok in input_ids[0]
    ]
    attention_mask = torch.ones_like(input_ids)

    if input_type == "comma_delimited":
        # Skip delimiter tokens; relies on number and comma tokens alternating.
        grad_idx = list(range(n_front, len(decoded_tokens), 2))
    else:
        grad_idx = list(range(n_front, len(decoded_tokens)))

    d_model = embedding_layer.embedding_dim
    # One zero-initialised, differentiable row per attributed position.
    residual = nn.Parameter(torch.zeros(len(grad_idx), d_model, device=target_device))
    presence = torch.ones(len(decoded_tokens), 1, device=target_device)

    forward_pass = JCBScope_utils.customize_forward_pass(
        model, residual, presence, input_ids, grad_idx, attention_mask
    )

    # The prediction is read at the last input position.
    loss_position = len(decoded_tokens) - 1

    if mode == "Temperature":
        scores, logits = JacobianScopes.temperature_scope_scores(
            forward_pass, residual, loss_position
        )
    elif mode == "Semantic":
        scores, logits = JacobianScopes.semantic_scope_scores(
            forward_pass, residual, loss_position, target_id=target_id
        )
    else:  # mode == "Fisher"
        lm_head = JCBScope_utils.get_lm_head(model)
        scores, logits = JacobianScopes.fisher_scope_scores(
            forward_pass,
            residual,
            loss_position,
            lm_head,
            method="low_rank",
        )

    out = {
        "decoded_tokens": decoded_tokens,
        "grad_idx": grad_idx,
        "scores": scores,
        "grads": None,
        "loss_position": loss_position,
        "hidden_norm_as_loss": mode == "Temperature",
        "loss": None,
        "logits": logits,
        "input_type": input_type,
        "front_pad": n_front,
    }
    if mode == "Semantic" and target_str:
        out["target_str"] = target_str  # For visualization: append target in red
    if input_type == "comma_delimited":
        raw = [int(x.strip()) for x in string.split(",") if x.strip()]
        out["int_list"] = raw[: len(grad_idx)]  # align with grad_idx length
    return out


def rgba_to_css(rgba):
    """Convert matplotlib RGBA to CSS rgba string."""
    return f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {rgba[3]:.2f})"


def get_text_color(bg_rgba):
    """Return white or black text based on background luminance."""
    luminance = 0.299 * bg_rgba[0] + 0.587 * bg_rgba[1] + 0.114 * bg_rgba[2]
    return "white" if luminance < 0.5 else "black"


def _build_color_norm(values: np.ndarray, log_scale: bool):
    """Return a matplotlib norm covering ``values``.

    A log norm is only valid for positive limits, so it is built from the
    positive finite entries (smaller values are clipped to the lower limit).
    When log scale is requested but no entry is positive, a linear norm is
    returned instead.
    """
    finite = values[np.isfinite(values)]
    if log_scale:
        positive = finite[finite > 0]
        if positive.size:
            return Log_Norm(vmin=positive.min(), vmax=positive.max(), clip=True)
    if finite.size:
        return Norm(vmin=finite.min(), vmax=finite.max())
    return Norm(vmin=0.0, vmax=1.0)


def render_attribution_html(result, log_color: bool = False, cmap_name: str = "Blues"):
    """
    Render attribution as HTML with colored token boxes (from notebook routine).

    Each attributed token becomes a box whose background encodes its score.
    A red box is appended at the end: the target token when the result
    carries ``target_str`` (Semantic Scope), otherwise an empty label.

    Args:
        result: Dict produced by ``compute_attribution``.
        log_color: Map scores to colors on a log scale.
        cmap_name: Name of the matplotlib colormap.

    Returns:
        ``(html_str, fig_bar)``: the HTML snippet and a matplotlib figure
        holding the matching horizontal color bar.
    """
    decoded_tokens = result["decoded_tokens"]
    grad_idx = result["grad_idx"]
    if result.get("scores") is not None:
        grad_magnitude = torch.as_tensor(result["scores"], dtype=torch.float32).detach().clone()
    else:
        grad_magnitude = result["grads"].norm(dim=-1).detach().clone()
    # Flatten so that a single attributed token still yields a 1-D sequence.
    grad_magnitude = grad_magnitude.reshape(-1)

    loss_position = result["loss_position"]
    target_str = result.get("target_str")
    # NOTE(review): without a target the appended red box has an empty label;
    # confirm whether a placeholder text is intended here.
    suffix_red = target_str if target_str is not None else ""

    cmap = plt.get_cmap(cmap_name)
    shown_tokens = [decoded_tokens[idx] for idx in grad_idx]

    # If the position right after the loss position is itself attributed,
    # pin its magnitude (so it does not stretch the color scale) and paint
    # that box red below.
    bar_idx = None
    if (loss_position + 1) in grad_idx:
        bar_idx = grad_idx.index(loss_position + 1)
        if bar_idx > 0:
            prev_max = grad_magnitude[:bar_idx].max().item()
            grad_magnitude[bar_idx] = max(prev_max, 1e-8)
        else:
            grad_magnitude[bar_idx] = 1e-8

    grad_np = grad_magnitude.float().cpu().numpy()
    color_norm = _build_color_norm(grad_np, log_color)
    colors = cmap(color_norm(grad_np))

    html_parts = []
    for i, (token, color) in enumerate(zip(shown_tokens, colors)):
        bg_color = rgba_to_css(color)
        text_color = get_text_color(color)
        if bar_idx is not None and i == bar_idx:
            bg_color = "red"
            text_color = "white"
        # Tokens come from user input and are rendered with
        # unsafe_allow_html=True, so they must be escaped. Newlines are
        # written as character references so a token containing blank lines
        # cannot terminate the surrounding HTML block in the markdown parser.
        display_token = html_escape(token).replace("\n", "&#10;")
        html_parts.append(
            f'<span style="{_TOKEN_BOX_STYLE} background-color: {bg_color}; '
            f'color: {text_color};">{display_token}</span>'
        )

    # Escape HTML so a label containing < > displays correctly (browsers
    # parse them as tags).
    suffix_safe = html_escape(suffix_red).replace("\n", "&#10;")
    html_parts.append(
        f'<span style="{_TOKEN_BOX_STYLE} background-color: red; '
        f'color: white;">{suffix_safe}</span>'
    )

    html_str = f'''<div style="line-height: 2;">
{"".join(html_parts)}
</div>
'''

    # Color bar (from notebook): horizontal, matching the color mapping
    fig_bar, ax_bar = plt.subplots(figsize=(8, 0.2), dpi=100)
    cbar = mpl.colorbar.ColorbarBase(
        ax_bar,
        cmap=cmap,
        norm=color_norm,
        orientation="horizontal",
    )
    cbar.set_label("Influence")
    return html_str, fig_bar


def render_attribution_barplot(result, log_color: bool = False, cmap_name: str = "Blues"):
    """
    Bar plot with double axes for comma-delimited input: Influence (left) and Token value (right).

    Args:
        result: Dict produced by ``compute_attribution`` with
            ``input_type == "comma_delimited"`` (needs ``int_list``).
        log_color: Use a logarithmic y-axis for the influence bars.
        cmap_name: Accepted for signature parity with
            ``render_attribution_html``; not used by this plot.

    Returns:
        The matplotlib figure.
    """
    grad_idx = result["grad_idx"]
    if result.get("scores") is not None:
        grad_magnitude = np.array(result["scores"], dtype=np.float32).reshape(-1)
    else:
        grads = result["grads"]
        if grads.dim() == 2:
            grad_magnitude = grads.norm(dim=-1).detach().clone().float().cpu().numpy().reshape(-1)
        else:
            grad_magnitude = grads.detach().clone().float().cpu().numpy().reshape(-1)

    loss_position = result["loss_position"]
    int_list = result["int_list"]
    # Number of padding tokens in front of the input; results that do not
    # record it fall back to the default used by compute_attribution.
    front_pad = result.get("front_pad", 2)

    target_bar_index = None
    if (loss_position + 1) in grad_idx:
        target_bar_index = grad_idx.index(loss_position + 1)
        grad_magnitude[target_bar_index] = grad_magnitude.max()

    ax1_color = np.array([10, 110, 230]) / 256
    ax2_color = np.array([230, 20, 20]) / 256
    x_labels = [x - front_pad for x in grad_idx]

    fig, ax = plt.subplots(figsize=(10, 2.5), dpi=120)
    bars = ax.bar(
        range(grad_magnitude.shape[0]),
        grad_magnitude,
        tick_label=x_labels,
        color=ax1_color,
        linewidth=0.5,
        edgecolor="black",
        width=1.0,
        alpha=0.9,
    )
    if target_bar_index is not None:
        bars[target_bar_index].set_color("red")
        bars[target_bar_index].set_width(1.1)

    # Second y-axis: the raw integer values of the series.
    ax2 = ax.twinx()
    ax2.scatter(range(len(int_list)), int_list, color=ax2_color, marker="o", s=13, alpha=0.9)
    ax2.plot(range(len(int_list)), int_list, color=ax2_color, linewidth=1.5, alpha=0.5)
    ax2.tick_params(axis="y", colors=ax2_color, labelsize=10)
    ax.tick_params(axis="y", colors=ax1_color, labelsize=10)

    # At most 5 x-axis labels
    n_bars = grad_magnitude.shape[0]
    n_labels = min(5, n_bars)
    if n_labels > 0:
        tick_indices = np.linspace(0, n_bars - 1, n_labels, dtype=int)
        ax.set_xticks(tick_indices)
        ax.set_xticklabels([x_labels[i] for i in tick_indices], fontsize=10)

    ax.set_xlabel("Token position index", fontsize=10, fontweight="bold")
    ax.set_ylabel("Influence", labelpad=2, color=ax1_color, fontsize=10, fontweight="bold")
    ax2.set_ylabel("Token value", labelpad=2, color=ax2_color, fontsize=10, fontweight="bold")
    ax.set_axisbelow(True)
    ax.xaxis.grid(True, which="both", linestyle="--", linewidth=0.3, alpha=0.7)
    ax.yaxis.grid(True, which="both", linestyle="--", linewidth=0.3, alpha=0.7)

    if log_color:
        ax.set_yscale("log")
    # NOTE(review): this locator is installed after the scale is chosen, so on
    # a log axis it replaces the default log tick locator; confirm intended.
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True, prune='lower'))
    plt.tight_layout()
    return fig


def main():
    """Build the Streamlit page.

    Lays out the scope descriptions and input widgets, runs
    ``compute_attribution`` when the button is clicked, stores the result in
    ``st.session_state`` and renders the attribution view plus the top
    predicted next tokens from the stored result.
    """
    st.set_page_config(page_title="Jacobian Scopes Demo", page_icon="🔍", layout="centered")
    st.title("🔍 Jacobian Scopes Demo")
    st.markdown(
        'Interactive demonstrations for Jacobian Scopes: token-level causal attributions in LLMs.\n'
        'Github Repo: https://github.com/AntonioLiu97/JacobianScopes',
        unsafe_allow_html=True,
    )

    # Keep scope columns on one line (Streamlit stacks them below ~640px by default)
    # NOTE(review): the string below is empty, so this call currently renders
    # nothing; confirm whether a CSS override belongs here.
    st.markdown(
        '',
        unsafe_allow_html=True,
    )

    scope_col1, scope_div1, scope_col2, scope_div2, scope_col3 = st.columns([1, 0.02, 1, 0.02, 1])
    semantic_svg = _load_svg(os.path.join(DESIGN_DIR, "semantic_scope_button.svg"))
    temp_svg = _load_svg(os.path.join(DESIGN_DIR, "temperature_scope_button.svg"))
    fisher_svg = _load_svg(os.path.join(DESIGN_DIR, "fisher_scope_button.svg"))

    with scope_col1:
        if semantic_svg:
            st.markdown(_render_svg_html(semantic_svg), unsafe_allow_html=True)
        st.markdown(
            "**Semantic Scope** — explains the predicted logit for a specific target token. "
            "Enter your input passage along with a target token."
        )
    with scope_div1:
        # NOTE(review): divider column; the markup string carries no visible
        # content as written.
        st.markdown(
            '\n',
            unsafe_allow_html=True,
        )
    with scope_col2:
        if temp_svg:
            st.markdown(_render_svg_html(temp_svg), unsafe_allow_html=True)
        st.markdown(
            "**Temperature Scope** — explains the confidence (effective inverse temperature) of the predictive distribution. "
            "Particularly effective for attributing time-series predictions. "
            "Target token not required."
        )
    with scope_div2:
        st.markdown(
            '\n',
            unsafe_allow_html=True,
        )
    with scope_col3:
        if fisher_svg:
            st.markdown(_render_svg_html(fisher_svg), unsafe_allow_html=True)
        st.markdown(
            "**Fisher Scope** — explains the overall predictive distribution using low-rank approximation of the Fisher information matrix. "
            "Best suited for textual data. "
            "Target token not required."
        )

    model_choice = st.selectbox(
        "Model",
        options=["LLaMA 3.2 1B", "LLaMA 3.2 3B", "SmolLM3-3B-Base"],
        index=0,
        key="model_choice",
        help="Choose model.",
    )
    MODEL_MAP = {
        "LLaMA 3.2 1B": "meta-llama/Llama-3.2-1B",
        "LLaMA 3.2 3B": "meta-llama/Llama-3.2-3B",
        "SmolLM3-3B-Base": "HuggingFaceTB/SmolLM3-3B-Base",
    }
    model_name = MODEL_MAP[model_choice]

    attribution_type = st.radio(
        "Scope type",
        options=["Semantic Scope", "Temperature Scope", "Fisher Scope"],
        index=0,
        horizontal=True,
        key="attribution_type",
    )
    # Anything that is not Semantic or Temperature is treated as Fisher.
    mode = {
        "Semantic Scope": "Semantic",
        "Temperature Scope": "Temperature",
    }.get(attribution_type, "Fisher")

    if mode == "Semantic":
        input_type = "text"
        is_comma_delimited = False
    else:
        input_type_default = "comma_delimited" if mode == "Temperature" else "text"
        input_type = st.radio(
            "Input type",
            options=["text", "comma-delimited numbers"],
            index=0 if input_type_default == "text" else 1,
            horizontal=True,
            key=f"input_type_{mode}",
            help="Text: natural language. Comma-delimited numbers: time-series style. Delimiters are skipped for attribution.",
        )
        is_comma_delimited = input_type == "comma-delimited numbers"

    if is_comma_delimited:
        default_text = (
            "80,68,57,52,50,49,48,46,42,35,23,14,24,40,49,54,57,60,66,74,79,74,64,58,55,55,57,61,68,77,80,71,60,54,52,51,52,53,55,61,70,83,83,66,53,47,44,41,36,28,22,23,32,40,44,44,43,40,33,24,19,26,37,44,47,47,47,45,40,32,21,16,28,42,49,52,55,58,63,71,80,79,67,58,53,51,51,51,52,55,59,69,82,84,69,54,47,43,40,35,28,22,24,32,39,43,43,41,37,30,22,22,31,39,44,45,44,41,36,27,19,22,34,43,47,49,49,48,47,45,40,31,18,15,31,46,53,57,60,65,72,77,75,67,60,57,57,59,64,71,78,77,68,60,56,55,56,60,66,75,81,75,63,56,53,52,52,54,57,62,73,"
        )
    elif mode == "Semantic":
        default_text = (
            "As a state-of-the-art AI assistant, you never argue or deceive, because you are"
        )
    else:
        default_text = (
            "French: Cet article porte sur l'attribution causale, que nous appelons lentille jacobienne. English: This is a paper on causal attribution, and we call it Jacobian"
        )

    text_placeholder = "Input text" if mode == "Semantic" else "Input text or comma-delimited numbers"
    text_help = "Natural language input." if mode == "Semantic" else "Text or comma-separated numbers. Delimiters are skipped for comma-delimited."
    # The widget key depends on mode and input type so that each combination
    # keeps its own text (and shows its own default).
    text_input = st.text_area(
        "Input text",
        value=default_text,
        height=120,
        key=f"text_input_{mode}_{input_type}",
        placeholder=text_placeholder,
        help=text_help,
    )
    st.caption(f"Characters: {len(text_input)}")

    target_str = None
    if mode == "Semantic":
        target_str = st.text_input(
            "Target token (tip: most tokenized words start with a space character)",
            value=" truthful",
            placeholder='e.g., " truthful" or " nice"',
            help="Must be representable as a single token. Most tokenized words lead with a space character (e.g. ' truthful' for Llama).",
        )
        st.caption(f"Characters: {len(target_str or '')}")

    compute_clicked = st.button("Compute Attribution!", type="primary", use_container_width=True)
    input_type_param = "comma_delimited" if is_comma_delimited else "text"

    if compute_clicked:
        if not text_input.strip():
            st.error("Please enter some text.")
        elif mode == "Semantic" and (not target_str or not target_str.strip()):
            st.error("Please enter a target token for Semantic Scope.")
        elif is_comma_delimited and not _is_comma_delimited_numbers(text_input.strip()):
            st.error("Input is not valid comma-delimited numbers.")
        else:
            # Progress bar for model loading and attribution
            progress_text = st.empty()
            progress_bar = st.progress(0)
            try:
                progress_text.write("Step 1/3: Preparing environment...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                gc.collect()
                progress_bar.progress(25)

                progress_text.write("Step 2/3: Loading model...")
                tokenizer, model = load_model(model_name=model_name)
                progress_bar.progress(60)

                progress_text.write(f"Step 3/3: Computing {mode} Scope...")
                result = compute_attribution(
                    text_input,
                    mode,
                    tokenizer,
                    model,
                    target_str=target_str,
                    input_type=input_type_param,
                )
                progress_bar.progress(100)
                st.session_state["attribution_result"] = result
                st.session_state["tokenizer"] = tokenizer
                st.success("Attribution successful!")
            except ValueError as e:
                if "Target not in token dictionary" in str(e):
                    st.error("Target not in token dictionary.")
                else:
                    st.error(str(e))
            except Exception as e:
                # Show the error in the page, then re-raise so the traceback
                # is not lost.
                st.error(f"Error: {e}")
                raise

    # Visualization (uses cached result; log_color and cmap are post-compute only)
    if "attribution_result" in st.session_state:
        result = st.session_state["attribution_result"]
        tokenizer = st.session_state["tokenizer"]

        st.subheader("Attribution Visualization")
        # Adjustable after compute — does not trigger recompute
        viz_col1, viz_col2 = st.columns([1, 1])
        with viz_col1:
            log_color = st.checkbox(
                "Log-scale",
                value=False,
                key="log_color",
                help="Use log scale for influence values.",
            )
        with viz_col2:
            cmap_choice = st.selectbox(
                "Color map",
                options=["Blues", "Greens", "viridis"],
                index=0,
                key="cmap_choice",
                help="Colormap for attribution visualization.",
            )

        if result.get("input_type") == "comma_delimited":
            fig_barplot = render_attribution_barplot(
                result, log_color=log_color, cmap_name=cmap_choice
            )
            st.pyplot(fig_barplot)
            plt.close(fig_barplot)
        else:
            html_output, fig_colorbar = render_attribution_html(
                result, log_color=log_color, cmap_name=cmap_choice
            )
            st.markdown(html_output, unsafe_allow_html=True)
            st.pyplot(fig_colorbar)
            plt.close(fig_colorbar)

        st.subheader("Top-15 predicted next tokens")
        logit_vector = result["logits"][result["loss_position"]].detach()
        probs = torch.softmax(logit_vector, dim=-1)
        # topk raises when k exceeds the number of entries.
        k = min(15, probs.shape[-1])
        top_probs, top_indices = torch.topk(probs, k)
        top_tokens = [tokenizer.decode([idx]) for idx in top_indices.tolist()]

        if result.get("input_type") == "comma_delimited":
            # Comma-delimited input: order by string value (numbers increasing, else lex)
            paired = sorted(
                zip(top_tokens, top_probs.tolist()),
                key=lambda pair: _sort_key_for_token(pair[0]),
            )
            top_tokens = [tok for tok, _ in paired]
            top_probs = torch.tensor([prob for _, prob in paired], dtype=top_probs.dtype)

        prob_np = top_probs.float().cpu().numpy()
        fig_pred, ax_pred = plt.subplots(figsize=(8, 3), dpi=100)
        x_pos = range(len(top_tokens))
        ax_pred.bar(x_pos, prob_np, color="red", edgecolor="darkred", linewidth=0.5)
        ax_pred.set_xticks(x_pos)
        # repr() makes leading spaces and escapes in tokens visible.
        ax_pred.set_xticklabels([repr(t) for t in top_tokens], rotation=45, ha="right")
        ax_pred.set_ylabel("Probability")
        ax_pred.set_ylim(0, prob_np.max() * 1.1 if prob_np.max() > 0 else 1)
        plt.tight_layout()
        st.pyplot(fig_pred)
        plt.close(fig_pred)

    st.divider()
    with st.expander("Citation Information", expanded=True):
        st.markdown("**Jacobian Scopes Demo © 2026 Toni Jianbang Liu.**")
        st.markdown("If you use this demo in your work, please cite:")
        st.markdown(
            "Liu, T. J., Zadeoğlu, B., Boullé, N., Sarfati, R., & Earls, C. J. (2026). "
            "*Jacobian Scopes: token-level causal attributions in LLMs.* arXiv preprint arXiv:2601.16407."
        )
        st.markdown("**BibTeX:**")
        st.code(
            """@misc{liu2026jacobianscopestokenlevelcausal,
      title={Jacobian Scopes: token-level causal attributions in LLMs},
      author={Toni J. B. Liu and Baran Zadeoğlu and Nicolas Boullé and Raphaël Sarfati and Christopher J. Earls},
      year={2026},
      eprint={2601.16407},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.16407},
}""",
            language=None,
        )


if __name__ == "__main__":
    main()