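"""Streamlit app for predicting molecular properties with MIST models and
visualizing per-token attributions (layer integrated gradients via captum)."""
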
import streamlit as st
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from transformers import AutoModel
from captum.attr import LayerIntegratedGradients
from smirk import SmirkTokenizerFast

from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from matplotlib import cm
from matplotlib.colors import Normalize
from io import BytesIO
from PIL import Image


st.set_page_config(page_title="Token Attribution", layout="wide")

st.markdown(
    """<style>
.main-header {font-size: 2.5rem; font-weight: bold; color: #1f77b4; text-align: center; margin-bottom: 2rem;}
.section-header {font-size: 1.5rem; font-weight: bold; color: #2c3e50; margin-top: 1.5rem;}
</style>""",
    unsafe_allow_html=True,
)


@st.cache_resource
def load_model(model_name: str):
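    """Load the SMILES tokenizer and pretrained MIST model (cached across reruns)."""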
    tokenizer = SmirkTokenizerFast()
    # `use_auth_token` is deprecated in recent transformers; `token=True` uses the
    # cached Hugging Face credentials for gated/private checkpoints.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, token=True)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device), tokenizer, device


def get_channels(model):
    if hasattr(model.config, "channels") and model.config.channels:
        return model.config.channels
    return None


def forward_fn(input_ids, attention_mask, model):
    output = model(input_ids=input_ids, attention_mask=attention_mask)

    if hasattr(output, "logits"):
        return output.logits

    if isinstance(output, tuple):
        return output[0]

    return output


@torch.no_grad()
def get_token_embeddings(model, input_ids):
    if hasattr(model, "encoder") and hasattr(model.encoder, "embeddings"):
        return model.encoder.embeddings.word_embeddings(input_ids)
    return model.get_input_embeddings()(input_ids)


def get_embedding_layer(model):
    if hasattr(model, "encoder") and hasattr(model.encoder, "embeddings"):
        return model.encoder.embeddings.word_embeddings
    return model.get_input_embeddings()


def compute_attributions(
    model, input_ids, attention_mask, n_steps=50, tokenizer=None, target_idx=None
):
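    """Per-token attributions via layer integrated gradients against a [PAD] baseline.

    Embedding-level attributions are summed over the hidden dimension and masked
    by `attention_mask`; returns (token_scores, convergence_delta).
    """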
    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    pad_id = getattr(model.config, "pad_token_id", None)
    if pad_id is None and tokenizer is not None:
        pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0

    baseline_ids = torch.full_like(input_ids, pad_id)
    lig = LayerIntegratedGradients(
        lambda ids, am: forward_fn(ids, am, model),
        get_embedding_layer(model),
    )

    attr_kwargs = {
        "inputs": input_ids,
        "baselines": baseline_ids,
        "additional_forward_args": (attention_mask,),
        "return_convergence_delta": True,
        "n_steps": n_steps,
    }
    if target_idx is not None:
        attr_kwargs["target"] = target_idx

    attributions, delta = lig.attribute(**attr_kwargs)

    token_scores = attributions.sum(dim=-1) * attention_mask
    return token_scores, delta
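
# Illustrative usage (assuming a model and tokenizer from load_model above):
#   enc = tokenizer(["CCO"])
#   scores, delta = compute_attributions(
#       model,
#       torch.tensor(enc["input_ids"]),
#       torch.tensor(enc["attention_mask"]),
#       n_steps=50,
#   )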


def get_color_mapper(scores):
    scores_np = scores.cpu().numpy() if torch.is_tensor(scores) else scores
    # Center the colormap at zero so negative scores are always red and positive
    # scores green, matching the explanation shown in the app.
    vmax = max(abs(float(scores_np.min())), abs(float(scores_np.max())), 1e-12)
    norm = Normalize(vmin=-vmax, vmax=vmax)
    cmap = cm.RdYlGn
    return norm, cmap


def plot_attributions(tokens, scores, target_name=None):
    scores_np = scores.cpu().numpy() if torch.is_tensor(scores) else scores
    norm, cmap = get_color_mapper(scores)

    colors = []
    for s in scores_np:
        rgba = cmap(norm(s))
        colors.append(
            f"rgba({int(rgba[0] * 255)},{int(rgba[1] * 255)},{int(rgba[2] * 255)},{rgba[3]})"
        )

    fig = go.Figure(
        go.Bar(
            x=list(range(len(tokens))),
            y=scores_np,
            text=tokens,
            textposition="outside",
            marker_color=colors,
            hovertemplate="<b>%{text}</b><br>%{y:.4f}<extra></extra>",
        )
    )

    title = (
        f"Token Attributions - {target_name}" if target_name else "Token Attributions"
    )

    fig.update_layout(
        title=title,
        xaxis_title="Position",
        yaxis_title="Attribution",
        height=500,
        showlegend=False,
        margin=dict(t=100, b=50, l=50, r=50),
    )
    return fig


def kekulize_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        Chem.Kekulize(mol)
        return Chem.MolToSmiles(mol, kekuleSmiles=True)
    return smiles


def draw_molecule(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        AllChem.Compute2DCoords(mol)
        return Draw.MolToImage(mol, size=(400, 400))
    return None


def map_tokens_to_structure(mol, tokens):
    """Map both atom and bond indices to token indices by parsing SMILES."""
    ALIPHATIC_ORGANIC = ["B", "C", "N", "O", "S", "P", "F", "Cl", "Br", "I"]
    AROMATIC_ORGANIC = ["b", "c", "n", "o", "s", "p"]
    ELEMENT_SYMBOLS = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al",
        "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", 
        "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", 
        "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", 
        "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", 
        "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", 
        "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", 
        "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", 
        "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", 
        "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
    ]
    BOND_SYMBOLS = {"-": 1, "=": 2, "#": 3, ":": 1.5, "/": 1, "\\": 1, ".": 0}
    SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[PAD]", "<s>", "</s>", "<pad>", "<unk>"]

    atom_symbols = set(ALIPHATIC_ORGANIC + AROMATIC_ORGANIC + ELEMENT_SYMBOLS)

    atom_map = {}
    bond_map = {}
    atom_count = 0
    branch_stack = []
    prev_atom = None
    pending_bond_token = None  # Track bond token waiting for next atom
    ring_closures = {}  # Open ring closures: {ring_num: (atom_idx, token_indices)}
    in_bracket = False
    bracket_token_span = []  # Track all token indices in current bracket
    in_extended_ring = False  # Track '%' + digits rings
    extended_ring_tokens = []  # Track tokens for extended ring closure like ['%', '1', '0']

    def handle_ring_closure(ring_num, ring_token_span):
        """Open or close a ring bond for a ring-closure token ('1', ..., '%nn')."""
        nonlocal pending_bond_token
        if prev_atom is None:
            return
        has_explicit_bond = pending_bond_token is not None
        # An explicit bond symbol (e.g. the '=' in 'C=1') takes precedence over
        # the ring-closure digit tokens themselves.
        bond_token_indices = (
            [pending_bond_token] if has_explicit_bond else list(ring_token_span)
        )
        pending_bond_token = None
        if ring_num in ring_closures:
            # Second occurrence: close the ring
            first_atom, first_indices = ring_closures.pop(ring_num)
            bond = mol.GetBondBetweenAtoms(first_atom, prev_atom)
            if bond is not None:
                bond_map[bond.GetIdx()] = (
                    bond_token_indices if has_explicit_bond else first_indices
                )
        else:
            # First occurrence: remember the anchor atom and its token span
            ring_closures[ring_num] = (prev_atom, bond_token_indices)

    for i, token in enumerate(tokens):
        if token in SPECIAL_TOKENS:
            continue

        # Handle bracketed atoms (e.g., [NH+] tokenized as ['[', 'N', 'H', '+', ']'])
        if token == "[":
            in_bracket = True
            bracket_token_span = [i]  # Start tracking bracket span
            continue
        elif token == "]" and in_bracket:
            in_bracket = False
            bracket_token_span.append(i)  # Include closing bracket
            # Complete the bracketed atom - map to all tokens in the bracket
            if atom_count < mol.GetNumAtoms():
                atom_map[atom_count] = bracket_token_span.copy()

                # Check for bond to previous atom
                if prev_atom is not None:
                    bond = mol.GetBondBetweenAtoms(prev_atom, atom_count)
                    if bond is not None:
                        # If there's an explicit bond token, use it; otherwise use bracket tokens for implicit bond
                        if pending_bond_token is not None:
                            bond_map[bond.GetIdx()] = [pending_bond_token]
                        else:
                            # Implicit bond - map to the bracket token span
                            bond_map[bond.GetIdx()] = bracket_token_span.copy()

                # Always clear pending_bond_token after processing an atom
                pending_bond_token = None
                prev_atom = atom_count
                atom_count += 1
            bracket_token_span = []
            continue
        elif in_bracket:
            # Track tokens inside brackets
            bracket_token_span.append(i)
            continue

        # Handle extended ring closures: %10 tokenized as ['%', '1', '0']
        if in_extended_ring:
            if token.isdigit():
                extended_ring_tokens.append(i)
                continue
            # A non-digit token ends the extended closure; process it immediately
            # so it is not lost when the current token is an atom or bond symbol.
            ring_num = "%" + "".join(tokens[idx] for idx in extended_ring_tokens[1:])
            handle_ring_closure(ring_num, extended_ring_tokens)
            in_extended_ring = False
            extended_ring_tokens = []
        if token == "%":
            in_extended_ring = True
            extended_ring_tokens = [i]  # Start with the '%' token
            continue

        is_atom = token in atom_symbols
        is_bond = token in BOND_SYMBOLS

        if is_atom and atom_count < mol.GetNumAtoms():
            atom_map[atom_count] = [i]  # Use list for consistency with bracketed atoms

            # Check for bond to previous atom
            if prev_atom is not None:
                bond = mol.GetBondBetweenAtoms(prev_atom, atom_count)
                if bond is not None:
                    # If there's an explicit bond token, use it; otherwise use current atom token for implicit bond
                    if pending_bond_token is not None:
                        bond_map[bond.GetIdx()] = [pending_bond_token]
                    else:
                        # Implicit bond - map to the current atom token
                        bond_map[bond.GetIdx()] = [i]

            # Always clear pending_bond_token after processing an atom
            pending_bond_token = None
            prev_atom = atom_count
            atom_count += 1
        elif is_bond:
            # Store the bond token to map when we see the next atom
            pending_bond_token = i
        elif token.isdigit():
            # Single-character ring closure (e.g. the '1's in 'c1ccccc1')
            handle_ring_closure(token, [i])
        elif token == "(":
            # Push current atom onto stack for branch
            if prev_atom is not None:
                branch_stack.append(prev_atom)
        elif token == ")":
            # Pop from stack to return to main chain
            if branch_stack:
                prev_atom = branch_stack.pop()
                pending_bond_token = None

    # Handle an extended ring closure that ends the token stream
    if in_extended_ring and extended_ring_tokens:
        ring_num = "%" + "".join(tokens[idx] for idx in extended_ring_tokens[1:])
        handle_ring_closure(ring_num, extended_ring_tokens)

    return atom_map, bond_map
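
# Worked example (illustrative; real token strings come from SmirkTokenizerFast):
#   mol = Chem.MolFromSmiles("C=CO")
#   atom_map, bond_map = map_tokens_to_structure(mol, ["C", "=", "C", "O"])
#   atom_map == {0: [0], 1: [2], 2: [3]}   # each atom -> its token index(es)
#   bond_map == {0: [1], 1: [3]}           # C=C -> '=' token; implicit C-O -> 'O' token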


def draw_molecule_with_attributions(smiles, tokens, attribution_scores):
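    """Draw the molecule with atoms and bonds colored by their token attributions."""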
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if not mol:
        return None

    AllChem.Compute2DCoords(mol)
    scores_np = (
        attribution_scores.cpu().numpy()
        if torch.is_tensor(attribution_scores)
        else attribution_scores
    )
    norm, cmap = get_color_mapper(attribution_scores)

    # Map atoms and bonds to their corresponding token indices
    atom_to_token, bond_to_token = map_tokens_to_structure(mol, tokens)

    atom_colors = {}
    for atom_idx, token_indices in atom_to_token.items():
        # Aggregate scores across all tokens for this atom (sum)
        valid_indices = [idx for idx in token_indices if idx < len(scores_np)]
        if valid_indices:
            aggregated_score = sum(scores_np[idx] for idx in valid_indices)
            color_val = cmap(norm(aggregated_score))
            atom_colors[atom_idx] = color_val[:3]

    bond_colors = {}
    for bond_idx, token_indices in bond_to_token.items():
        # Aggregate scores across all tokens for this bond (sum)
        valid_indices = [idx for idx in token_indices if idx < len(scores_np)]
        if valid_indices:
            aggregated_score = sum(scores_np[idx] for idx in valid_indices)
            color_val = cmap(norm(aggregated_score))
            bond_colors[bond_idx] = color_val[:3]

    drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightBonds=list(bond_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBondColors=bond_colors,
    )
    drawer.FinishDrawing()
    img_bytes = drawer.GetDrawingText()
    return Image.open(BytesIO(img_bytes))


def main():
    st.markdown("# Prediction and Attribution with MIST")
    st.sidebar.header("Configuration")

    models_info = {
        "QM8": "mist-models/mist-28M-gzwqzpcr-qm8",
        "QM9": "mist-models/mist-26.9M-kkgx0omx-qm9",
    }
    selected_property = st.sidebar.selectbox("Property", list(models_info.keys()))
    model_name = models_info[selected_property]

    st.sidebar.markdown("---")
    examples = {
        "Benzene": "c1ccccc1",
        "Ethanol": "CCO",
        "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
        "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "Propylene Carbonate": "CC1COC(=O)O1",
        "Custom": "",
    }

    selected = st.sidebar.selectbox("Example", list(examples.keys()))
    smiles = st.sidebar.text_input(
        "SMILES", value=examples[selected], placeholder="Enter SMILES"
    )

    st.sidebar.markdown("---")
    n_steps = st.sidebar.slider("Steps", 10, 200, 50, 10)

    if not smiles:
        st.info("Enter a SMILES string")
        return

    with st.spinner("Loading model..."):
        model, tokenizer, device = load_model(model_name)

    channels = get_channels(model)
    target_idx = None
    selected_channel = None

    if channels:
        st.sidebar.markdown("---")
        st.sidebar.header("Target")
        channel_labels = [
            f"{ch['name']} ({ch.get('description', '')})" for ch in channels
        ]
        selected_idx = st.sidebar.selectbox(
            "Channel", range(len(channels)), format_func=lambda i: channel_labels[i]
        )
        target_idx = selected_idx
        selected_channel = channels[selected_idx]

    kekule_smiles = kekulize_smiles(smiles)

    with st.spinner("Tokenizing..."):
        encoded = tokenizer(
            [
                kekule_smiles,
            ]
        )
        tokens = tokenizer.tokenize(kekule_smiles)
        input_ids = torch.tensor(encoded["input_ids"])
        attention_mask = torch.tensor(encoded["attention_mask"])

    st.markdown("### Molecule")

    st.code(smiles)
    with st.expander("View Tokens"):
        token_df = pd.DataFrame({"Position": range(len(tokens)), "Token": tokens})
        st.dataframe(token_df, use_container_width=True)

    st.markdown("### Property Prediction")

    with torch.no_grad():
        predictions = model.predict([kekule_smiles])
        st.write("Predicted Value", predictions)

    st.markdown("### Attributions")

    st.markdown(
        """
        Token attributions quantify how much each token in the SMILES string contributes
        to the model's prediction relative to a baseline. Positive scores (green) indicate
        tokens that increase the predicted value, while negative scores (red) indicate
        tokens that decrease it.
        Attributions are computed with the integrated gradients method described in
        [Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365),
        as implemented by ``captum``'s ``LayerIntegratedGradients`` class.
        A padding token ``[PAD]`` serves as the baseline.
        If the convergence Δ is large (e.g. |Δ| > 0.1), increase the number of
        integration steps.
        """
    )
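
    # A sketch of the quantity visualized below: integrated gradients along the
    # straight-line path from the [PAD] baseline x' to the input x
    # (Sundararajan et al., 2017).
    st.latex(
        r"\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1"
        r"\frac{\partial F\bigl(x' + \alpha\,(x - x')\bigr)}{\partial x_i}\,d\alpha"
    )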
    
    if selected_channel:
        st.info(
            f"Computing attributions for: **{selected_channel['name']}** ({selected_channel.get('description', '')}) - {selected_channel.get('unit', '')}"
        )

    with st.spinner("Computing attributions..."):
        scores, delta = compute_attributions(
            model, input_ids, attention_mask, n_steps, tokenizer, target_idx
        )

    attribution_scores = scores.flatten()

    col1, col2 = st.columns(2)
    with col1:
        st.metric("Convergence Δ", f"{delta.item():.6f}")
    with col2:
        quality = (
            "Good"
            if abs(delta.item()) < 0.05
            else "Fair"
            if abs(delta.item()) < 0.1
            else "Poor"
        )
        st.metric("Quality", quality)

    col1, col2 = st.columns(2)

    with col1:
        target_name = selected_channel["name"] if selected_channel else None
        st.plotly_chart(
            plot_attributions(tokens, attribution_scores, target_name),
            use_container_width=True,
        )

    with col2:
        attributed_img = draw_molecule_with_attributions(
            kekule_smiles, tokens, attribution_scores
        )
        if attributed_img:
            st.image(attributed_img, width="content")
        else:
            st.warning("Unable to generate structure visualization")

    st.markdown("Statistics")

    s = attribution_scores.cpu().numpy()
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Max", f"{s.max():.4f}")
    with col2:
        st.metric("Mean", f"{s.mean():.4f}")
    with col3:
        st.metric("Min", f"{s.min():.4f}")
    with col4:
        st.metric("Std", f"{s.std():.4f}")

    top_idx = np.argsort(np.abs(s))[::-1][:10]
    df = pd.DataFrame(
        [{"Pos": int(i), "Token": tokens[i], "Score": f"{s[i]:.6f}"} for i in top_idx]
    )
    st.dataframe(df, use_container_width=True)


if __name__ == "__main__":
    main()