from io import BytesIO

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
import torch
from captum.attr import LayerIntegratedGradients
from matplotlib import cm
from matplotlib.colors import Normalize
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from smirk import SmirkTokenizerFast
from transformers import AutoModel

st.set_page_config(page_title="Token Attribution", layout="wide")
st.markdown(
    """""",
    unsafe_allow_html=True,
)


@st.cache_resource
def load_model(model_name: str):
    """Load ``model_name`` and a SMIRK tokenizer, cached across Streamlit reruns.

    Returns:
        ``(model, tokenizer, device)``: the model is in eval mode and has been
        moved to CUDA when available, otherwise to the CPU.
    """
    tokenizer = SmirkTokenizerFast()
    model = AutoModel.from_pretrained(
        model_name, trust_remote_code=True, use_auth_token=True
    )
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device), tokenizer, device


def get_channels(model):
    """Return ``model.config.channels`` when present and non-empty, else ``None``."""
    channels = getattr(model.config, "channels", None)
    return channels if channels else None


def forward_fn(input_ids, attention_mask, model):
    """Run ``model`` and return its prediction tensor.

    Unwraps ``output.logits`` when present, or the first element of a tuple
    output; otherwise returns the output unchanged.
    """
    output = model(input_ids=input_ids, attention_mask=attention_mask)
    if hasattr(output, "logits"):
        return output.logits
    if isinstance(output, tuple):
        return output[0]
    return output


def get_embedding_layer(model):
    """Return the token-embedding module that attributions are taken at.

    Prefers ``model.encoder.embeddings.word_embeddings`` and falls back to
    ``model.get_input_embeddings()``.
    """
    if hasattr(model, "encoder") and hasattr(model.encoder, "embeddings"):
        return model.encoder.embeddings.word_embeddings
    return model.get_input_embeddings()


@torch.no_grad()
def get_token_embeddings(model, input_ids):
    """Embed ``input_ids`` with the model's token-embedding layer (no gradients)."""
    return get_embedding_layer(model)(input_ids)


def _to_numpy(values):
    """Return ``values`` as a NumPy array; torch tensors are detached and moved to CPU."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def compute_attributions(
    model, input_ids, attention_mask, n_steps=50, tokenizer=None, target_idx=None
):
    """Compute per-token Integrated Gradients attributions at the embedding layer.

    The baseline replaces every position with a padding id. That id is taken
    from ``model.config.pad_token_id``, then ``tokenizer.pad_token_id``, then 0.

    Args:
        model: model whose forward accepts ``input_ids`` and ``attention_mask``.
        input_ids: token-id tensor; moved to the model's device.
        attention_mask: mask tensor matching ``input_ids``; moved to the model's device.
        n_steps: number of integration steps passed to captum.
        tokenizer: optional fallback source for the padding id.
        target_idx: optional output index passed to captum as ``target``.

    Returns:
        ``(token_scores, delta)``:
        - ``token_scores`` is the attribution summed over the last
          (embedding) dimension. It is multiplied by ``attention_mask``, so
          positions where the mask is 0 score 0.
        - ``delta`` is captum's convergence delta.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    pad_id = getattr(model.config, "pad_token_id", None)
    if pad_id is None and tokenizer is not None:
        pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0
    baseline_ids = torch.full_like(input_ids, pad_id)

    lig = LayerIntegratedGradients(
        lambda ids, am: forward_fn(ids, am, model),
        get_embedding_layer(model),
    )
    attr_kwargs = {
        "inputs": input_ids,
        "baselines": baseline_ids,
        "additional_forward_args": (attention_mask,),
        "return_convergence_delta": True,
        "n_steps": n_steps,
    }
    if target_idx is not None:
        attr_kwargs["target"] = target_idx
    attributions, delta = lig.attribute(**attr_kwargs)
    token_scores = attributions.sum(dim=-1) * attention_mask
    return token_scores, delta


def get_color_mapper(scores):
    """Return ``(norm, cmap)`` that colours attribution scores with RdYlGn.

    The normalization is symmetric about zero, so a score's sign decides its
    colour: negative -> red, zero -> yellow (midpoint), positive -> green. This
    matches the legend text shown in the app.

    The range is ``[-max|s|, max|s|]`` with ``clip=True``, so summed values
    that exceed it (e.g. a bracket atom spanning several tokens) saturate at
    the ends. An all-zero, empty or non-finite input falls back to ``[-1, 1]``.
    """
    scores_np = _to_numpy(scores)
    limit = float(np.abs(scores_np).max()) if scores_np.size else 0.0
    if not np.isfinite(limit) or limit == 0.0:
        limit = 1.0
    norm = Normalize(vmin=-limit, vmax=limit, clip=True)
    cmap = cm.RdYlGn
    return norm, cmap


def plot_attributions(tokens, scores, target_name=None):
    """Build a Plotly bar chart of one attribution score per token position.

    Bars are coloured with ``get_color_mapper``. Each bar's label and hover
    text show the token string and its score.
    """
    scores_np = _to_numpy(scores)
    norm, cmap = get_color_mapper(scores_np)

    def to_css(score):
        r, g, b, a = cmap(norm(score))
        return f"rgba({int(r * 255)},{int(g * 255)},{int(b * 255)},{a})"

    colors = [to_css(s) for s in scores_np]

    fig = go.Figure(
        go.Bar(
            x=list(range(len(tokens))),
            y=scores_np,
            text=tokens,
            textposition="outside",
            marker_color=colors,
            # Plotly hover templates use HTML <br> for line breaks.
            hovertemplate="%{text}<br>%{y:.4f}",
        )
    )
    title = (
        f"Token Attributions - {target_name}" if target_name else "Token Attributions"
    )
    fig.update_layout(
        title=title,
        xaxis_title="Position",
        yaxis_title="Attribution",
        height=500,
        showlegend=False,
        margin=dict(t=100, b=50, l=50, r=50),
    )
    return fig


def kekulize_smiles(smiles):
    """Return the Kekulé SMILES written by RDKit, or ``smiles`` unchanged if it does not parse."""
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        Chem.Kekulize(mol)
        return Chem.MolToSmiles(mol, kekuleSmiles=True)
    return smiles


def draw_molecule(smiles):
    """Render ``smiles`` as a 400x400 PIL image, or return ``None`` if it does not parse."""
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        AllChem.Compute2DCoords(mol)
        return Draw.MolToImage(mol, size=(400, 400))
    return None


def map_tokens_to_structure(mol, tokens):
    """Map atom and bond indices of ``mol`` to token indices by walking the SMILES tokens.

    Assumptions:
        - ``mol`` was parsed from the same SMILES string that produced
          ``tokens``, so RDKit atom order follows token order.
        - Bracket atoms are tokenized as ``'['`` ... ``']'``.
        - ``%nn`` ring labels are tokenized as ``'%'`` followed by digit tokens.

    Args:
        mol: RDKit molecule parsed from the tokenized SMILES.
        tokens: token strings, indexed like the attribution scores.

    Returns:
        ``(atom_map, bond_map)``: dicts from RDKit atom / bond index to a list
        of token indices. Mapping rules:
        - A bond written with an explicit bond symbol maps to that symbol.
        - An implicit chain bond maps to the tokens of the atom that follows it.
        - An implicit ring bond maps to the opening ring label's tokens.
    """
    ALIPHATIC_ORGANIC = ["B", "C", "N", "O", "S", "P", "F", "Cl", "Br", "I"]
    AROMATIC_ORGANIC = ["b", "c", "n", "o", "s", "p"]
    ELEMENT_SYMBOLS = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al",
        "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe",
        "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr",
        "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
        "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm",
        "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W",
        "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn",
        "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf",
        "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
        "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
    ]
    BOND_SYMBOLS = {"-": 1, "=": 2, "#": 3, ":": 1.5, "/": 1, "\\": 1, ".": 0}
    SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[PAD]", "", "", "", ""]
    atom_symbols = set(ALIPHATIC_ORGANIC + AROMATIC_ORGANIC + ELEMENT_SYMBOLS)

    n_atoms = mol.GetNumAtoms()
    atom_map = {}
    bond_map = {}
    atom_count = 0
    prev_atom = None
    pending_bond_token = None  # explicit bond symbol waiting for its second atom
    branch_stack = []  # atoms to resume from when ')' closes a branch
    ring_closures = {}  # open ring label -> (atom_idx, bond token indices)
    bracket_span = None  # token indices of the bracket atom being read, if any
    ext_ring_span = None  # token indices of a '%nn' ring label being read, if any
    ext_ring_digits = ""  # digits collected so far for that '%nn' label

    def add_atom(token_indices):
        # Register the next atom and the bond linking it to the previous atom.
        nonlocal atom_count, prev_atom, pending_bond_token
        if atom_count >= n_atoms:
            return
        atom_map[atom_count] = list(token_indices)
        if prev_atom is not None:
            bond = mol.GetBondBetweenAtoms(prev_atom, atom_count)
            if bond is not None:
                if pending_bond_token is not None:
                    bond_map[bond.GetIdx()] = [pending_bond_token]
                else:
                    bond_map[bond.GetIdx()] = list(token_indices)
        pending_bond_token = None
        prev_atom = atom_count
        atom_count += 1

    def add_ring_label(ring_num, token_indices):
        # Open ring bond `ring_num` at the current atom, or close it if already open.
        nonlocal pending_bond_token
        if prev_atom is None:
            return
        has_explicit_bond = pending_bond_token is not None
        bond_tokens = [pending_bond_token] if has_explicit_bond else list(token_indices)
        pending_bond_token = None
        if ring_num not in ring_closures:
            ring_closures[ring_num] = (prev_atom, bond_tokens)
            return
        first_atom, first_bond_tokens = ring_closures.pop(ring_num)
        bond = mol.GetBondBetweenAtoms(first_atom, prev_atom)
        if bond is not None:
            # An explicit bond symbol at the closing label wins; otherwise use
            # the opening label's tokens (its explicit symbol or its digits).
            bond_map[bond.GetIdx()] = (
                bond_tokens if has_explicit_bond else first_bond_tokens
            )

    def flush_extended_ring():
        # Finish a pending '%nn' ring label, if any.
        nonlocal ext_ring_span, ext_ring_digits
        if ext_ring_span is not None:
            add_ring_label("%" + ext_ring_digits, ext_ring_span)
        ext_ring_span = None
        ext_ring_digits = ""

    for i, token in enumerate(tokens):
        if token in SPECIAL_TOKENS:
            continue

        # A '%' label takes exactly two digits. Any other token ends the label
        # and is then processed normally below (it must not be swallowed).
        if ext_ring_span is not None:
            if token.isdigit() and len(ext_ring_digits) + len(token) <= 2:
                ext_ring_span.append(i)
                ext_ring_digits += token
                continue
            flush_extended_ring()

        # Bracket atoms (e.g. [NH+] -> '[', 'N', 'H', '+', ']') map to their whole span.
        if bracket_span is not None:
            bracket_span.append(i)
            if token == "]":
                add_atom(bracket_span)
                bracket_span = None
            continue

        if token == "[":
            bracket_span = [i]
        elif token == "%":
            ext_ring_span = [i]
        elif token in atom_symbols:
            add_atom([i])
        elif token in BOND_SYMBOLS:
            pending_bond_token = i
        elif token.isdigit():
            add_ring_label(token, [i])
        elif token == "(":
            if prev_atom is not None:
                branch_stack.append(prev_atom)
        elif token == ")":
            if branch_stack:
                prev_atom = branch_stack.pop()
            pending_bond_token = None

    # A '%nn' label may be the last structural token.
    flush_extended_ring()
    return atom_map, bond_map


def _score_colors(index_to_tokens, scores_np, norm, cmap):
    """Map each atom/bond index to the RGB colour of the summed score of its tokens."""
    colors = {}
    for idx, token_indices in index_to_tokens.items():
        valid_indices = [t for t in token_indices if t < len(scores_np)]
        if valid_indices:
            aggregated_score = sum(scores_np[t] for t in valid_indices)
            colors[idx] = cmap(norm(aggregated_score))[:3]
    return colors


def draw_molecule_with_attributions(smiles, tokens, attribution_scores):
    """Draw ``smiles`` with atoms and bonds highlighted by token attribution.

    The SMILES is parsed with ``sanitize=False`` so that RDKit atom order
    follows the token order used by ``map_tokens_to_structure``.

    Returns:
        A 600x600 PIL image, or ``None`` if the SMILES does not parse.
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if not mol:
        return None
    AllChem.Compute2DCoords(mol)

    scores_np = _to_numpy(attribution_scores)
    norm, cmap = get_color_mapper(scores_np)

    atom_to_token, bond_to_token = map_tokens_to_structure(mol, tokens)
    atom_colors = _score_colors(atom_to_token, scores_np, norm, cmap)
    bond_colors = _score_colors(bond_to_token, scores_np, norm, cmap)

    drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightBonds=list(bond_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBondColors=bond_colors,
    )
    drawer.FinishDrawing()
    return Image.open(BytesIO(drawer.GetDrawingText()))


def get_aligned_tokens(tokenizer, smiles, input_ids):
    """Return token strings aligned with the positions of ``input_ids[0]``.

    Attribution scores are indexed by ``input_ids`` positions, and
    ``tokenizer(...)`` may add special tokens that ``tokenizer.tokenize(...)``
    omits. So tokens are recovered from the ids when the tokenizer provides
    ``convert_ids_to_tokens`` and the result has the right length. Otherwise
    this falls back to ``tokenizer.tokenize(smiles)``.
    """
    ids = input_ids[0].tolist()
    convert = getattr(tokenizer, "convert_ids_to_tokens", None)
    if callable(convert):
        tokens = list(convert(ids))
        if len(tokens) == len(ids):
            return tokens
    return list(tokenizer.tokenize(smiles))


def main():
    """Streamlit entry point: predict a property for a SMILES and show token attributions."""
    st.markdown("# Prediction and Attribution with MIST")
    st.sidebar.header("Configuration")

    models_info = {
        "QM8": "mist-models/mist-28M-gzwqzpcr-qm8",
        "QM9": "mist-models/mist-26.9M-kkgx0omx-qm9",
    }
    selected_property = st.sidebar.selectbox("Property", list(models_info.keys()))
    model_name = models_info[selected_property]
    st.sidebar.markdown("---")

    examples = {
        "Benzene": "c1ccccc1",
        "Ethanol": "CCO",
        "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
        "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "Propylene Carbonate": "CC1COC(=O)O1",
        "Custom": "",
    }
    selected = st.sidebar.selectbox("Example", list(examples.keys()))
    smiles = st.sidebar.text_input(
        "SMILES", value=examples[selected], placeholder="Enter SMILES"
    )
    st.sidebar.markdown("---")
    n_steps = st.sidebar.slider("Steps", 10, 200, 50, 10)

    if not smiles:
        st.info("Enter a SMILES string")
        return

    with st.spinner("Loading model..."):
        model, tokenizer, device = load_model(model_name)

    channels = get_channels(model)
    target_idx = None
    selected_channel = None
    if channels:
        st.sidebar.markdown("---")
        st.sidebar.header("Target")
        channel_labels = [
            f"{ch['name']} ({ch.get('description', '')})" for ch in channels
        ]
        selected_idx = st.sidebar.selectbox(
            "Channel", range(len(channels)), format_func=lambda i: channel_labels[i]
        )
        target_idx = selected_idx
        selected_channel = channels[selected_idx]

    kekule_smiles = kekulize_smiles(smiles)
    with st.spinner("Tokenizing..."):
        encoded = tokenizer([kekule_smiles])
        input_ids = torch.tensor(encoded["input_ids"])
        attention_mask = torch.tensor(encoded["attention_mask"])
        # Tokens must line up with input_ids positions, which index the scores.
        tokens = get_aligned_tokens(tokenizer, kekule_smiles, input_ids)

    st.markdown("### Molecule")
    st.code(smiles)
    with st.expander("View Tokens"):
        token_df = pd.DataFrame({"Position": range(len(tokens)), "Token": tokens})
        st.dataframe(token_df, use_container_width=True)

    st.markdown("### Property Prediction")
    with torch.no_grad():
        predictions = model.predict([kekule_smiles])
    st.write("Predicted Value", predictions)

    st.markdown("### Attributions")
    st.markdown(
        """
Token attributions quantify how much each token in the SMILES string contributes
to the model's prediction as compared to a baseline. Positive scores (green)
indicate tokens that increase the predicted value, while negative scores (red)
indicate tokens that decrease it. Attributions are computed using the integrated
gradients described in
[Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365)
as implemented by ``captum``'s ``LayerIntegratedGradients`` class. A padding
token ``[PAD]`` is used as the baseline. If the convergence Δ is > 0.3, increase
the number of integration steps.
"""
    )

    if selected_channel:
        st.info(
            f"Computing attributions for: **{selected_channel['name']}** ({selected_channel.get('description', '')}) - {selected_channel.get('unit', '')}"
        )

    with st.spinner("Computing attributions..."):
        scores, delta = compute_attributions(
            model, input_ids, attention_mask, n_steps, tokenizer, target_idx
        )
    attribution_scores = scores.flatten()
    s = _to_numpy(attribution_scores)
    if len(tokens) != len(s):
        st.warning(
            f"Token count ({len(tokens)}) does not match score count ({len(s)}); "
            "token labels may be misaligned."
        )

    delta_value = delta.item()
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Convergence Δ", f"{delta_value:.6f}")
    with col2:
        abs_delta = abs(delta_value)
        quality = "Good" if abs_delta < 0.05 else "Fair" if abs_delta < 0.1 else "Poor"
        st.metric("Quality", quality)

    col1, col2 = st.columns(2)
    with col1:
        target_name = selected_channel["name"] if selected_channel else None
        st.plotly_chart(
            plot_attributions(tokens, attribution_scores, target_name),
            use_container_width=True,
        )
    with col2:
        attributed_img = draw_molecule_with_attributions(
            kekule_smiles, tokens, attribution_scores
        )
        if attributed_img:
            st.image(attributed_img, width="content")
        else:
            st.warning("Unable to generate structure visualization")

    st.markdown("Statistics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Max", f"{s.max():.4f}")
    with col2:
        st.metric("Mean", f"{s.mean():.4f}")
    with col3:
        st.metric("Min", f"{s.min():.4f}")
    with col4:
        st.metric("Std", f"{s.std():.4f}")

    # Ten largest |score| positions; skip any without a matching token label.
    top_idx = [int(i) for i in np.argsort(np.abs(s))[::-1][:10] if i < len(tokens)]
    df = pd.DataFrame(
        [{"Pos": i, "Token": tokens[i], "Score": f"{s[i]:.6f}"} for i in top_idx]
    )
    st.dataframe(df, use_container_width=True)


if __name__ == "__main__":
    main()