"""
Streamlit app for interactive Semantic and Temperature Scope visualizations.
"""

import base64
import gc
from html import escape as html_escape
import os
import sys

import numpy as np
import matplotlib as mpl
mpl.use('Agg')  # Non-interactive backend for Streamlit
import matplotlib.pyplot as plt
import streamlit as st
import torch
from matplotlib.colors import LogNorm as Log_Norm
from matplotlib.colors import Normalize as Norm
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# Add current directory to path for JCBScope_utils
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)
import JCBScope_utils
import JacobianScopes

DESIGN_DIR = os.path.join(_APP_DIR, "design")
if not os.path.exists(DESIGN_DIR):
    DESIGN_DIR = os.path.join(os.path.dirname(_APP_DIR), "design")

# Device configuration: use CPU to match notebook and avoid device_map complexity
device = torch.device("cpu")


@st.cache_data
def _load_svg(path: str) -> str | None:
    """Load SVG file content; returns None if not found."""
    if not os.path.exists(path):
        return None
    with open(path, encoding="utf-8") as f:
        return f.read()


def _render_svg_html(svg_content: str, max_width: int = 140) -> str:
    """Return HTML to render SVG via base64 (reliable in Streamlit)."""
    b64 = base64.b64encode(svg_content.encode("utf-8")).decode("utf-8")
    return f'<img src="data:image/svg+xml;base64,{b64}" style="max-width:{max_width}px;height:auto;"/>'
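
# For instance, _render_svg_html("<svg/>") yields an <img> tag whose src is
# "data:image/svg+xml;base64,PHN2Zy8+", capped at the default 140px max-width.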


@st.cache_resource
def load_model(model_name: str = "meta-llama/Llama-3.2-1B"):
    """Load and cache the tokenizer and model."""
    token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
    model = model.to(device)
    return tokenizer, model
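
# Note: gated checkpoints such as the Llama 3.2 models require a Hugging Face
# access token, which load_model reads from the HF_TOKEN environment variable.
# A typical launch (the filename here is illustrative):
#   HF_TOKEN=hf_xxx streamlit run streamlit_app.py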


def check_target_single_token(tokenizer, target_str: str) -> tuple[bool, list[int] | None]:
    """
    Check that the target encodes to exactly one token: returns (True, ids)
    if so, else (False, None). Uses target_str as-is (no strip) so e.g.
    " truthful" stays a single token.
    """
    ids = tokenizer(target_str, add_special_tokens=False)["input_ids"]
    if len(ids) != 1:
        return False, None
    return True, ids
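
# Illustrative calls (actual splits depend on the tokenizer; for Llama
# tokenizers, " truthful" with its leading space is typically one token):
#   check_target_single_token(tokenizer, " truthful")  # -> (True, [token_id])
#   check_target_single_token(tokenizer, "two words")  # -> (False, None)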


def _is_comma_delimited_numbers(s: str) -> bool:
    """Check if string is comma-delimited, two-digit integers."""
    try:
        parts = [x.strip() for x in s.split(",") if x.strip()]
        return len(parts) > 0 and all(p.lstrip("-").isdigit() for p in parts)
    except Exception:
        return False
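
# What the validator accepts and rejects, for reference:
#   _is_comma_delimited_numbers("80,68,57")   -> True
#   _is_comma_delimited_numbers("1, -2, 3,")  -> True   (whitespace, negatives, trailing comma OK)
#   _is_comma_delimited_numbers("1.5,2")      -> False  (floats are rejected)
#   _is_comma_delimited_numbers("")           -> False  (no numbers at all)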


def _sort_key_for_token(s: str):
    """Numeric tokens by value; others by lexicographic order. Total order."""
    try:
        return (0, float(s))
    except ValueError:
        return (1, s)
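
# Example: sorted(["10", "9.5", " a", "9"], key=_sort_key_for_token) returns
# ["9", "9.5", "10", " a"]: numeric strings come first, ordered by value, and
# everything else follows in lexicographic order.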


def compute_attribution(
    string: str,
    mode: str,
    tokenizer,
    model,
    target_str: str | None = None,
    front_pad: int = 2,
    input_type: str = "text",
):
    """
    Compute attribution using Temperature, Semantic, or Fisher Scope.

    input_type: "text" or "comma_delimited". For comma_delimited, attribution skips delimiter tokens.
    """
    if mode not in ["Temperature", "Semantic", "Fisher"]:
        raise ValueError(f"Invalid mode '{mode}'. Must be 'Temperature', 'Semantic', or 'Fisher'.")

    if mode == "Semantic" and (not target_str or not target_str.strip()):
        raise ValueError("Semantic Scope requires a target token.")
    if mode == "Semantic":
        ok, target_id = check_target_single_token(tokenizer, target_str)
        if not ok:
            raise ValueError("Target must be a single token.")

    if input_type == "comma_delimited" and not _is_comma_delimited_numbers(string):
        raise ValueError("Input is not valid comma-delimited numbers.")

    back_pad = 0

    bos_token_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id
    eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id

    input_ids_list = []
    if bos_token_id is not None:
        input_ids_list += [bos_token_id] * front_pad
    input_ids_list += tokenizer(string, add_special_tokens=False)["input_ids"]
    if eos_token_id is not None:
        input_ids_list += [eos_token_id] * back_pad

    embedding_layer = model.get_input_embeddings()
    target_device = embedding_layer.weight.device

    input_ids = torch.tensor([input_ids_list], dtype=torch.long).to(target_device)
    decoded_tokens = [
        tokenizer.decode(tok.item(), skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for tok in input_ids[0]
    ]

    attention_mask = torch.ones_like(input_ids)
    # assert input_ids.max() < model.config.vocab_size, "Token IDs exceed vocab size"
    # assert input_ids.min() >= 0, "Token IDs must be non-negative"

    if input_type == "comma_delimited":
        grad_idx = list(range(front_pad, len(decoded_tokens), 2))  # Skip delimiter tokens
    else:
        grad_idx = list(range(front_pad, len(decoded_tokens)))

    d_model = embedding_layer.embedding_dim
    residual = nn.Parameter(torch.zeros(len(grad_idx), d_model, device=target_device))
    presence = torch.ones(len(decoded_tokens), 1, device=target_device)

    forward_pass = JCBScope_utils.customize_forward_pass(
        model, residual, presence, input_ids, grad_idx, attention_mask
    )
    loss_position = len(decoded_tokens) - 1
    if mode == "Temperature":
        scores, logits = JacobianScopes.temperature_scope_scores(
            forward_pass, residual, loss_position
        )
    elif mode == "Semantic":
        scores, logits = JacobianScopes.semantic_scope_scores(
            forward_pass, residual, loss_position, target_id=target_id
        )
    elif mode == "Fisher":
        lm_head = JCBScope_utils.get_lm_head(model)
        scores, logits = JacobianScopes.fisher_scope_scores(
            forward_pass,
            residual,
            loss_position,
            lm_head,
            method="low_rank",
        )

    out = {
        "decoded_tokens": decoded_tokens,
        "grad_idx": grad_idx,
        "scores": scores,
        "grads": None,
        "loss_position": loss_position,
        "hidden_norm_as_loss": mode == "Temperature",
        "loss": None,
        "logits": logits,
        "input_type": input_type,
    }
    if mode == "Semantic" and target_str:
        out["target_str"] = target_str  # For visualization: append target in red
    if input_type == "comma_delimited":
        raw = [int(x.strip()) for x in string.split(",") if x.strip()]
        out["int_list"] = raw[: len(grad_idx)]  # align with grad_idx length
    return out
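
# Minimal usage sketch (assumes a loaded model; the prompt and target below
# are illustrative, not part of the app flow):
#   tokenizer, model = load_model()
#   result = compute_attribution(
#       "As an AI assistant, you are", "Semantic", tokenizer, model,
#       target_str=" truthful",
#   )
#   # result["scores"] aligns one-to-one with result["grad_idx"], which indexes
#   # into result["decoded_tokens"].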


def rgba_to_css(rgba):
    """Convert matplotlib RGBA to CSS rgba string."""
    return f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {rgba[3]:.2f})"


def get_text_color(bg_rgba):
    """Return white or black text based on background luminance."""
    luminance = 0.299 * bg_rgba[0] + 0.587 * bg_rgba[1] + 0.114 * bg_rgba[2]
    return "white" if luminance < 0.5 else "black"


def render_attribution_html(result, log_color: bool = False, cmap_name: str = "Blues"):
    """
    Render attribution as HTML with colored token boxes (from notebook routine).
    Semantic Scope: appends the target token in red. Temperature Scope: appends '<predicted distribution>' in red.
    """
    decoded_tokens = result["decoded_tokens"]
    grad_idx = result["grad_idx"]
    if result.get("scores") is not None:
        grad_magnitude = torch.tensor(result["scores"], dtype=torch.float32)
    else:
        grads = result["grads"]
        grad_magnitude = grads.norm(dim=-1).squeeze().detach().clone()
    loss_position = result["loss_position"]
    target_str = result.get("target_str")  # Semantic: append target in red; Temperature: append <target dist>
    hardset_target_grad = True
    exclude_target = False
    # Semantic: red box with target token. Temperature: red box with "<predicted distribution>"
    suffix_red = target_str if target_str is not None else "<predicted distribution>"

    cmap = plt.get_cmap(cmap_name)

    if exclude_target:
        optimized_tokens = [decoded_tokens[idx] for idx in grad_idx][:-1]
    else:
        optimized_tokens = [decoded_tokens[idx] for idx in grad_idx]

    tick_label_text = optimized_tokens.copy()
    append_suffix_in_red = True  # Semantic: target token; Temperature: "<predicted distribution>"

    if grad_magnitude.dim() > 1:
        grad_magnitude = grad_magnitude.squeeze()

    bar_idx = None
    if not exclude_target and hardset_target_grad and (loss_position + 1) in grad_idx:
        target_idx_in_grad = grad_idx.index(loss_position + 1)
        if target_idx_in_grad > 0:
            prev_max = grad_magnitude[:target_idx_in_grad].max().item()
            grad_magnitude[target_idx_in_grad] = max(prev_max, 1e-8)
        else:
            grad_magnitude[target_idx_in_grad] = 1e-8
        bar_idx = target_idx_in_grad

    grad_np = grad_magnitude.float().cpu().numpy()
    log_norm = Log_Norm(vmin=grad_np.min(), vmax=grad_np.max())
    norm = Norm(vmin=grad_np.min(), vmax=grad_np.max())

    if log_color:
        colors = cmap(log_norm(grad_np))
    else:
        colors = cmap(norm(grad_np))

    html_parts = []
    for i, (token, color) in enumerate(zip(tick_label_text, colors)):
        bg_color = rgba_to_css(color)
        text_color = get_text_color(color)

        if bar_idx is not None and i == bar_idx and hardset_target_grad:
            bg_color = "red"
            text_color = "white"

        display_token = token
        html_parts.append(
            f'<span style="'
            f"background-color: {bg_color}; "
            f"color: {text_color}; "
            f"padding: 0px 0px; "
            f"margin: 0px; "
            f"border-radius: 0px; "
            f"font-family: monospace; "
            f"font-size: 16px; "
            f"display: inline-block; "
            f"font-weight: bold; "
            f'white-space: pre;">{display_token}</span>'
        )

    if append_suffix_in_red:
        # Escape HTML so e.g. "<predicted distribution>" displays correctly (browsers parse < > as tags)
        suffix_safe = html_escape(suffix_red)
        html_parts.append(
            f'<span style="'
            f"background-color: red; "
            f"color: white; "
            f"padding: 0px 0px; "
            f"margin: 0px; "
            f"border-radius: 0px; "
            f"font-family: monospace; "
            f"font-size: 16px; "
            f"display: inline-block; "
            f"font-weight: bold; "
            f'white-space: pre;">{suffix_safe}</span>'
        )

    html_str = f'''
<div style="
    background: white;
    padding: 20px;
    border-radius: 8px;
    line-height: 2.2;
    width: 100%;
    max-width: 700px;
">
    {"".join(html_parts)}
</div>
'''
    # Color bar (from notebook): horizontal, matching the color mapping
    fig_bar, ax_bar = plt.subplots(figsize=(8, 0.2), dpi=100)
    # fig_bar.subplots_adjust(left=0.3, right=0.7, bottom=0.1, top=0.9)
    cbar = mpl.colorbar.ColorbarBase(
        ax_bar,
        cmap=cmap,
        norm=log_norm if log_color else norm,
        orientation="horizontal",
    )
    cbar.set_label("Influence")

    return html_str, fig_bar


def render_attribution_barplot(result, log_color: bool = False, cmap_name: str = "Blues"):
    """
    Bar plot with double axes for comma-delimited input: Influence (left) and Token value (right).
    """
    grad_idx = result["grad_idx"]
    if result.get("scores") is not None:
        grad_magnitude = np.array(result["scores"], dtype=np.float32).copy()
    else:
        grads = result["grads"]
        if len(grads.shape) == 2:
            grad_magnitude = grads.norm(dim=-1).squeeze().detach().clone().float().cpu().numpy()
        else:
            grad_magnitude = grads.detach().clone().float().cpu().numpy()
    loss_position = result["loss_position"]
    int_list = result["int_list"]
    front_pad = 2  # assumed to match compute_attribution's default

    hardset_target_grad = True
    target_bar_index = None
    if hardset_target_grad and (loss_position + 1) in grad_idx:
        target_bar_index = grad_idx.index(loss_position + 1)
        grad_magnitude[target_bar_index] = max(grad_magnitude)

    ax1_color = np.array([10, 110, 230]) / 256
    ax2_color = np.array([230, 20, 20]) / 256

    x_labels = [x - front_pad for x in grad_idx]

    fig, ax = plt.subplots(figsize=(10, 2.5), dpi=120)
    bars = ax.bar(
        range(grad_magnitude.shape[0]),
        grad_magnitude,
        tick_label=x_labels,
        color=ax1_color,
        linewidth=0.5,
        edgecolor="black",
        width=1.0,
        alpha=0.9,
    )
    if target_bar_index is not None:
        bars[target_bar_index].set_color("red")
        bars[target_bar_index].set_width(1.1)

    ax2 = ax.twinx()
    ax2.scatter(range(len(int_list)), int_list, color=ax2_color, marker="o", s=13, alpha=0.9)
    ax2.plot(range(len(int_list)), int_list, color=ax2_color, linewidth=1.5, alpha=0.5)

    ax2.tick_params(axis="y", colors=ax2_color, labelsize=10)
    ax.tick_params(axis="y", colors=ax1_color, labelsize=10)

    # At most 5 x-axis labels
    n_bars = grad_magnitude.shape[0]
    n_labels = min(5, n_bars)
    if n_labels > 0:
        tick_indices = np.linspace(0, n_bars - 1, n_labels, dtype=int)
        ax.set_xticks(tick_indices)
        ax.set_xticklabels([x_labels[i] for i in tick_indices], fontsize=10)

    ax.set_xlabel("Token position index", fontsize=10, fontweight="bold")
    ax.set_ylabel("Influence", labelpad=2, color=ax1_color, fontsize=10, fontweight="bold")
    ax2.set_ylabel("Token value", labelpad=2, color=ax2_color, fontsize=10, fontweight="bold")

    ax.set_axisbelow(True)
    ax.xaxis.grid(True, which="both", linestyle="--", linewidth=0.3, alpha=0.7)
    ax.yaxis.grid(True, which="both", linestyle="--", linewidth=0.3, alpha=0.7)
    if log_color:
        ax.set_yscale("log")
        ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True, prune='lower'))
    plt.tight_layout()
    return fig


def main():
    st.set_page_config(page_title="Jacobian Scopes Demo", page_icon="🔍", layout="centered")
    st.title("🔍 Jacobian Scopes Demo")
    st.markdown(
        'Interactive demonstrations for <a href="https://arxiv.org/abs/2601.16407"><b>Jacobian Scopes: token-level causal attributions in LLMs</b></a>. <br>'
        'GitHub repo: <a href="https://github.com/AntonioLiu97/JacobianScopes">https://github.com/AntonioLiu97/JacobianScopes</a>',
        unsafe_allow_html=True,
    )
    # Keep scope columns on one line (Streamlit stacks them below ~640px by default)
    st.markdown(
        '<style>div[data-testid="stHorizontalBlock"]{flex-wrap:nowrap!important}'
        '[data-testid="column"]{min-width:120px!important}</style>',
        unsafe_allow_html=True,
    )
    scope_col1, scope_div1, scope_col2, scope_div2, scope_col3 = st.columns([1, 0.02, 1, 0.02, 1])
    semantic_svg = _load_svg(os.path.join(DESIGN_DIR, "semantic_scope_button.svg"))
    temp_svg = _load_svg(os.path.join(DESIGN_DIR, "temperature_scope_button.svg"))
    fisher_svg = _load_svg(os.path.join(DESIGN_DIR, "fisher_scope_button.svg"))
    with scope_col1:
        if semantic_svg:
            st.markdown(_render_svg_html(semantic_svg), unsafe_allow_html=True)
        st.markdown(
            "**Semantic Scope** — explains the predicted logit for a specific target token. "
            "Enter your input passage along with a target token."
        )
    with scope_div1:
        st.markdown(
            '<div style="border-left: 5px solid #888; min-height: 200px; margin: 0;"></div>',
            # '<div style="border-left: 5px solid steelblue; min-height: 160px; margin: 0;"></div>',
            unsafe_allow_html=True,
        )
    with scope_col2:
        if temp_svg:
            st.markdown(_render_svg_html(temp_svg), unsafe_allow_html=True)
        st.markdown(
            "**Temperature Scope** — explains the confidence (effective inverse temperature) of the predictive distribution. "
            "Particularly effective for attributing time-series predictions. "
            "Target token not required."
        )
    with scope_div2:
        st.markdown(
            '<div style="border-left: 5px solid #888; min-height: 200px; margin: 0;"></div>',
            unsafe_allow_html=True,
        )
    with scope_col3:
        if fisher_svg:
            st.markdown(_render_svg_html(fisher_svg), unsafe_allow_html=True)
        st.markdown(
            "**Fisher Scope** — explains the overall predictive distribution using low-rank appxroximation of the Fisher information matrix. "
            "Best suited for textual data. "
            "Target token not required."
        )

    model_choice = st.selectbox(
        "Model",
        options=["LLaMA 3.2 1B", "LLaMA 3.2 3B", "SmolLM3-3B-Base"],
        index=0,
        key="model_choice",
        help="Choose model.",
    )
    MODEL_MAP = {
        "LLaMA 3.2 1B": "meta-llama/Llama-3.2-1B",
        "LLaMA 3.2 3B": "meta-llama/Llama-3.2-3B",
        "SmolLM3-3B-Base": "HuggingFaceTB/SmolLM3-3B-Base",
    }
    model_name = MODEL_MAP[model_choice]

    attribution_type = st.radio(
        "Scope type",
        options=["Semantic Scope", "Temperature Scope", "Fisher Scope"],
        index=0,
        horizontal=True,
        key="attribution_type",
        # help="Semantic Scope: attribute toward a target token. Temperature Scope: use hidden-state norm.",
    )
    mode = "Semantic" if attribution_type == "Semantic Scope" else "Temperature" if attribution_type == "Temperature Scope" else "Fisher"

    if mode == "Semantic":
        input_type = "text"
        is_comma_delimited = False
    else:
        if mode == "Temperature":
            input_type_default = "comma_delimited" 
        else:
            input_type_default ="text"
        input_type = st.radio(
            "Input type",
            options=["text", "comma-delimited numbers"],
            index=0 if input_type_default == "text" else 1,
            horizontal=True,
            key=f"input_type_{mode}",
            help="Text: natural language. Comma-delimited numbers: time-series style. Delimiters are skipped for attribution.",
        )
        is_comma_delimited = input_type == "comma-delimited numbers"

    if is_comma_delimited:
        default_text = (
            "80,68,57,52,50,49,48,46,42,35,23,14,24,40,49,54,57,60,66,74,79,74,64,58,55,55,57,61,68,77,80,71,60,54,52,51,52,53,55,61,70,83,83,66,53,47,44,41,36,28,22,23,32,40,44,44,43,40,33,24,19,26,37,44,47,47,47,45,40,32,21,16,28,42,49,52,55,58,63,71,80,79,67,58,53,51,51,51,52,55,59,69,82,84,69,54,47,43,40,35,28,22,24,32,39,43,43,41,37,30,22,22,31,39,44,45,44,41,36,27,19,22,34,43,47,49,49,48,47,45,40,31,18,15,31,46,53,57,60,65,72,77,75,67,60,57,57,59,64,71,78,77,68,60,56,55,56,60,66,75,81,75,63,56,53,52,52,54,57,62,73,"
        )
    elif mode == "Semantic":
        default_text = (
            "As a state-of-the-art AI assistant, you never argue or deceive, because you are"
        )
    else:
        default_text = (
            # "Italiano: Ma quando tu sarai nel dolce mondo, priegoti ch'a la mente altrui mi rechi: English: But when you have returned to the sweet world, I pray you"
            "French: Cet article porte sur l'attribution causale, que nous appelons lentille jacobienne. English: This is a paper on causal attribution, and we call it Jacobian"
        )

    text_placeholder = "Input text" if mode == "Semantic" else "Input text or comma-delimited numbers"
    text_help = "Natural language input." if mode == "Semantic" else "Text or comma-separated numbers. Delimiters are skipped for comma-delimited."
    text_input = st.text_area(
        "Input text",
        value=default_text,
        height=120,
        key=f"text_input_{mode}_{input_type}",
        placeholder=text_placeholder,
        help=text_help,
    )
    st.caption(f"Characters: {len(text_input)}")

    target_str = None
    if mode == "Semantic":
        target_str = st.text_input(
            "Target token (tip: most tokenized words start with a space character)",
            value=" truthful",
            placeholder='e.g., " truthful" or " nice"',
            help="Must be representable as a single token. Most tokenized words lead with a space character (e.g. ' truthful' for Llama).",
        )
        st.caption(f"Characters: {len(target_str or '')}")

    compute_clicked = st.button("Compute Attribution!", type="primary", use_container_width=True)

    input_type_param = "comma_delimited" if is_comma_delimited else "text"

    if compute_clicked:
        if not text_input.strip():
            st.error("Please enter some text.")
        elif mode == "Semantic" and (not target_str or not target_str.strip()):
            st.error("Please enter a target token for Semantic Scope.")
        elif is_comma_delimited and not _is_comma_delimited_numbers(text_input.strip()):
            st.error("Input is not valid comma-delimited numbers.")
        else:
            # Progress bar for model loading and attribution
            progress_text = st.empty()
            progress_bar = st.progress(0)

            try:
                progress_text.write("Step 1/3: Preparing environment...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                gc.collect()
                progress_bar.progress(25)

                progress_text.write("Step 2/3: Loading model...")
                tokenizer, model = load_model(model_name=model_name)
                progress_bar.progress(60)

                progress_text.write(f"Step 3/3: Computing {mode} Scope...")
                result = compute_attribution(
                    text_input,
                    mode,
                    tokenizer,
                    model,
                    target_str=target_str,
                    input_type=input_type_param,
                )
                progress_bar.progress(100)

                st.session_state["attribution_result"] = result
                st.session_state["tokenizer"] = tokenizer

                st.success("Attribution successful!")
            except ValueError as e:
                if "Target not in token dictionary" in str(e):
                    st.error("Target not in token dictionary.")
                else:
                    st.error(str(e))
            except Exception as e:
                st.error(f"Error: {e}")
                raise

    # Visualization (uses cached result; log_color and cmap are post-compute only)
    if "attribution_result" in st.session_state:
        result = st.session_state["attribution_result"]
        tokenizer = st.session_state["tokenizer"]

        st.subheader("Attribution Visualization")

        # Adjustable after compute — does not trigger recompute
        viz_col1, viz_col2 = st.columns([1, 1])
        with viz_col1:
            log_color = st.checkbox(
                "Log-scale",
                value=False,
                key="log_color",
                help="Use log scale for influence values.",
            )
        with viz_col2:
            cmap_choice = st.selectbox(
                "Color map",
                options=["Blues", "Greens", "viridis"],
                index=0,
                key="cmap_choice",
                help="Colormap for attribution visualization.",
            )

        if result.get("input_type") == "comma_delimited":
            fig_barplot = render_attribution_barplot(
                result, log_color=log_color, cmap_name=cmap_choice
            )
            st.pyplot(fig_barplot)
            plt.close(fig_barplot)
        else:
            html_output, fig_colorbar = render_attribution_html(
                result, log_color=log_color, cmap_name=cmap_choice
            )
            st.markdown(html_output, unsafe_allow_html=True)
            st.pyplot(fig_colorbar)
            plt.close(fig_colorbar)

        st.subheader("Top-15 predicted next tokens")
        k = 15
        logit_vector = result["logits"][result["loss_position"]].detach()
        probs = torch.softmax(logit_vector, dim=-1)
        top_probs, top_indices = torch.topk(probs, k)
        top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
        if result.get("input_type") == "comma_delimited":
            # Temperature Scope comma-delimited: order by string value (numbers increasing, else lex)
            paired = list(zip(top_tokens, top_indices.tolist(), top_probs.tolist()))
            paired.sort(key=lambda x: _sort_key_for_token(x[0]))
            top_tokens = [p[0] for p in paired]
            top_probs = torch.tensor([p[2] for p in paired], dtype=top_probs.dtype)
        prob_np = top_probs.float().cpu().numpy()
        fig_pred, ax_pred = plt.subplots(figsize=(8, 3), dpi=100)
        x_pos = range(k)
        bars = ax_pred.bar(x_pos, prob_np, color="red", edgecolor="darkred", linewidth=0.5)
        ax_pred.set_xticks(x_pos)
        ax_pred.set_xticklabels([repr(t) for t in top_tokens], rotation=45, ha="right")
        ax_pred.set_ylabel("Probability")
        ax_pred.set_ylim(0, max(prob_np) * 1.1 if prob_np.max() > 0 else 1)
        plt.tight_layout()
        st.pyplot(fig_pred)
        plt.close(fig_pred)

    st.divider()
    with st.expander("Citation Information", expanded=True):
        st.markdown("**Jacobian Scopes Demo © 2026 Toni Jianbang Liu.**")
        st.markdown("If you use this demo in your work, please cite:")
        st.markdown(
            "Liu, T. J., Zadeoğlu, B., Boullé, N., Sarfati, R., & Earls, C. J. (2026). "
            "*Jacobian Scopes: token-level causal attributions in LLMs.* arXiv preprint arXiv:2601.16407."
        )
        st.markdown("**BibTeX:**")
        st.code(
            """@misc{liu2026jacobianscopestokenlevelcausal,
      title={Jacobian Scopes: token-level causal attributions in LLMs}, 
      author={Toni J. B. Liu and Baran Zadeoğlu and Nicolas Boullé and Raphaël Sarfati and Christopher J. Earls},
      year={2026},
      eprint={2601.16407},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.16407}, }""",
            language=None,
        )


if __name__ == "__main__":
    main()