Spaces:
Running
Running
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import torch | |
| from transformers import AutoModel | |
| from captum.attr import LayerIntegratedGradients | |
| from smirk import SmirkTokenizerFast | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw, AllChem | |
| from rdkit.Chem.Draw import rdMolDraw2D | |
| from matplotlib import cm | |
| from matplotlib.colors import Normalize | |
| from io import BytesIO | |
| from PIL import Image | |
# Page-wide Streamlit configuration; this is the first st.* call in the script.
st.set_page_config(layout="wide", page_title="Token Attribution")

# Custom CSS classes injected once when the script runs.
_PAGE_STYLE = """<style>
.main-header {font-size: 2.5rem; font-weight: bold; color: #1f77b4; text-align: center; margin-bottom: 2rem;}
.section-header {font-size: 1.5rem; font-weight: bold; color: #2c3e50; margin-top: 1.5rem;}
</style>"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
@st.cache_resource(show_spinner=False)
def load_model(model_name: str):
    """Load a model and the SMIRK tokenizer, with the model moved to the best device.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached per ``model_name`` with ``st.cache_resource``. Without the
    cache, the model would be re-instantiated on every rerun. The built-in
    spinner is disabled because the caller already wraps this call in
    ``st.spinner``.

    Args:
        model_name: Hugging Face Hub repository id of the model.

    Returns:
        ``(model, tokenizer, device)``: the model in eval mode on ``device``
        (CUDA when available, else CPU), and a ``SmirkTokenizerFast``.
    """
    tokenizer = SmirkTokenizerFast()
    # trust_remote_code allows the checkpoint's custom modeling code to be loaded.
    # NOTE(review): recent transformers releases deprecate use_auth_token in favour
    # of token=; kept as-is for compatibility with the pinned version.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, use_auth_token=True)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device), tokenizer, device
def get_channels(model):
    """Return ``model.config.channels``, or None when it is absent or falsy."""
    channels = getattr(model.config, "channels", None)
    return channels if channels else None
def forward_fn(input_ids, attention_mask, model):
    """Call ``model`` with keyword inputs and unwrap its primary output.

    Returns ``.logits`` when the output has that attribute, otherwise the first
    element of a tuple output, otherwise the output unchanged.
    """
    result = model(input_ids=input_ids, attention_mask=attention_mask)
    if hasattr(result, "logits"):
        return result.logits
    return result[0] if isinstance(result, tuple) else result
def get_token_embeddings(model, input_ids):
    """Embed ``input_ids`` with the model's word-embedding module.

    Uses ``model.encoder.embeddings.word_embeddings`` when the model has an
    ``encoder`` with an ``embeddings`` attribute, else ``model.get_input_embeddings()``.
    """
    if not (hasattr(model, "encoder") and hasattr(model.encoder, "embeddings")):
        return model.get_input_embeddings()(input_ids)
    embedding_table = model.encoder.embeddings.word_embeddings
    return embedding_table(input_ids)
def get_embedding_layer(model):
    """Return the word-embedding module that attributions are computed against.

    Prefers ``model.encoder.embeddings.word_embeddings``; falls back to
    ``model.get_input_embeddings()`` when the model has no ``encoder.embeddings``.
    """
    uses_encoder = hasattr(model, "encoder") and hasattr(model.encoder, "embeddings")
    return model.encoder.embeddings.word_embeddings if uses_encoder else model.get_input_embeddings()
def compute_attributions(
    model, input_ids, attention_mask, n_steps=50, tokenizer=None, target_idx=None
):
    """Compute per-token Layer Integrated Gradients attributions.

    Attributions are taken at the module returned by ``get_embedding_layer``.
    The baseline is a sequence made entirely of the pad token id, taken in this
    order: ``model.config.pad_token_id``, then ``tokenizer.pad_token_id``, and
    finally 0.

    Args:
        model: model whose output ``forward_fn`` unwraps to a tensor.
        input_ids: token id tensor; moved to the model's device.
        attention_mask: mask tensor; moved to the model's device and passed to
            the model as an extra forward argument.
        n_steps: number of integration steps given to captum.
        tokenizer: optional fallback source of the pad token id.
        target_idx: output index to attribute. It is omitted from the captum
            call when None.

    Returns:
        ``(token_scores, delta)``. ``token_scores`` is the attributions summed
        over the last dimension and multiplied by ``attention_mask``. ``delta``
        is captum's convergence delta.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

    pad_id = getattr(model.config, "pad_token_id", None)
    if pad_id is None:
        pad_id = tokenizer.pad_token_id if tokenizer is not None else None
    if pad_id is None:
        pad_id = 0
    baseline_ids = torch.full_like(input_ids, pad_id)

    def _wrapped_forward(ids, mask):
        return forward_fn(ids, mask, model)

    integrator = LayerIntegratedGradients(_wrapped_forward, get_embedding_layer(model))
    extra = {} if target_idx is None else {"target": target_idx}
    attributions, delta = integrator.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(attention_mask,),
        return_convergence_delta=True,
        n_steps=n_steps,
        **extra,
    )
    # One score per token: collapse the last (embedding) axis, then apply the mask.
    token_scores = attributions.sum(dim=-1) * attention_mask
    return token_scores, delta
def get_color_mapper(scores):
    """Build a zero-centred colour normalisation for attribution scores.

    The app explains that positive attributions are shown green and negative
    ones red. The RdYlGn colormap is therefore normalised symmetrically around
    zero over ``[-m, m]``, with ``m = max(|scores|)``. A zero score always lands
    on the neutral midpoint, and the sign of a score decides its hue. A plain
    min/max range would colour the smallest score red even when every score is
    positive.

    Args:
        scores: torch tensor or array-like of attribution scores (any shape).

    Returns:
        ``(norm, cmap)``: a ``matplotlib.colors.Normalize`` and the RdYlGn colormap.
    """
    if torch.is_tensor(scores):
        scores_np = scores.detach().cpu().numpy()
    else:
        scores_np = np.asarray(scores, dtype=float)
    bound = float(np.abs(scores_np).max()) if scores_np.size else 0.0
    if not np.isfinite(bound) or bound == 0.0:
        # Empty or all-zero input: use a non-degenerate range so zeros stay at the
        # neutral midpoint. With vmin == vmax, Normalize maps every value to 0 (red).
        bound = 1.0
    norm = Normalize(vmin=-bound, vmax=bound)
    return norm, cm.RdYlGn
def plot_attributions(tokens, scores, target_name=None):
    """Build a Plotly bar chart of per-token attribution scores.

    Each bar is one token position. Bars are labelled with the token text,
    coloured through ``get_color_mapper`` (the same mapping the structure drawing
    uses), and show the score on hover.

    Args:
        tokens: token strings, one per bar.
        scores: torch tensor or array of scores aligned with ``tokens``.
        target_name: optional target label appended to the title.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    values = scores.cpu().numpy() if torch.is_tensor(scores) else scores
    norm, cmap = get_color_mapper(scores)

    def to_css_rgba(rgba):
        # Matplotlib gives floats in [0, 1]; CSS wants 0-255 integers for RGB.
        red, green, blue = (int(channel * 255) for channel in rgba[:3])
        return f"rgba({red},{green},{blue},{rgba[3]})"

    bar_colors = [to_css_rgba(cmap(norm(value))) for value in values]

    title = "Token Attributions"
    if target_name:
        title = f"Token Attributions - {target_name}"

    bar = go.Bar(
        x=list(range(len(tokens))),
        y=values,
        text=tokens,
        textposition="outside",
        marker_color=bar_colors,
        hovertemplate="<b>%{text}</b><br>%{y:.4f}<extra></extra>",
    )
    figure = go.Figure(bar)
    figure.update_layout(
        title=title,
        xaxis_title="Position",
        yaxis_title="Attribution",
        height=500,
        showlegend=False,
        margin=dict(t=100, b=50, l=50, r=50),
    )
    return figure
def kekulize_smiles(smiles):
    """Rewrite ``smiles`` in Kekulé form via RDKit.

    Returns ``smiles`` unchanged when RDKit cannot parse it. Otherwise returns
    ``Chem.MolToSmiles(..., kekuleSmiles=True)`` of the kekulized molecule.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return smiles
    Chem.Kekulize(molecule)
    return Chem.MolToSmiles(molecule, kekuleSmiles=True)
def draw_molecule(smiles):
    """Render ``smiles`` as a 400x400 image, or return None if RDKit cannot parse it."""
    molecule = Chem.MolFromSmiles(smiles)
    if not molecule:
        return None
    AllChem.Compute2DCoords(molecule)
    return Draw.MolToImage(molecule, size=(400, 400))
def map_tokens_to_structure(mol, tokens):
    """Map atom and bond indices of ``mol`` to token indices by replaying SMILES tokens.

    Every atom token, or bracketed atom ``[...]``, receives the next atom index
    in order of appearance. This matches the order RDKit assigns atom indices
    when parsing the same SMILES. Bonds are attributed to the token(s) that
    create them:

    * an explicit bond symbol (``-``, ``=``, ``#``, ...) before the atom or ring label;
    * otherwise, for a chain bond, the token span of the atom that closes it;
    * otherwise, for a ring bond, the opening ring label's token(s).

    Ring labels are single digits (``1``) or ``%`` followed by exactly two
    digits. The latter may be one token (``%10``) or split (``['%', '1', '0']``).

    Args:
        mol: RDKit molecule built from the same SMILES. Only ``GetNumAtoms`` and
            ``GetBondBetweenAtoms`` are used.
        tokens: token strings. Special tokens such as ``[CLS]`` are skipped but
            still occupy their positions.

    Returns:
        ``(atom_map, bond_map)``: dicts from atom index and from bond index to the
        list of token indices attributed to it.
    """
    ALIPHATIC_ORGANIC = ["B", "C", "N", "O", "S", "P", "F", "Cl", "Br", "I"]
    AROMATIC_ORGANIC = ["b", "c", "n", "o", "s", "p"]
    ELEMENT_SYMBOLS = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al",
        "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn",
        "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb",
        "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In",
        "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm",
        "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta",
        "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At",
        "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
        "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt",
        "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
    ]
    BOND_SYMBOLS = {"-": 1, "=": 2, "#": 3, ":": 1.5, "/": 1, "\\": 1, ".": 0}
    SPECIAL_TOKENS = ["[CLS]", "[SEP]", "[PAD]", "<s>", "</s>", "<pad>", "<unk>"]
    # "*" is the SMILES wildcard atom; it occupies an atom index like any other atom.
    atom_symbols = set(ALIPHATIC_ORGANIC + AROMATIC_ORGANIC + ELEMENT_SYMBOLS) | {"*"}
    n_atoms = mol.GetNumAtoms()

    atom_map = {}
    bond_map = {}
    atom_count = 0
    prev_atom = None  # atom that the next bond / ring label attaches to
    pending_bond_token = None  # explicit bond token waiting for its atom or ring label
    branch_stack = []  # atoms to return to when a "(" branch closes
    ring_closures = {}  # open ring label -> (atom_idx, token indices for its bond)
    in_bracket = False
    bracket_tokens = []  # token indices of the bracket atom being collected
    extended_label = None  # token indices of a split "%nn" label being collected

    def add_atom(atom_tokens):
        """Assign the next atom index to ``atom_tokens`` and map its chain bond."""
        nonlocal atom_count, prev_atom, pending_bond_token
        if atom_count >= n_atoms:
            return
        atom_map[atom_count] = list(atom_tokens)
        if prev_atom is not None:
            bond = mol.GetBondBetweenAtoms(prev_atom, atom_count)
            if bond is not None:
                # Explicit bond symbol if one was seen, else the atom's own tokens.
                bond_map[bond.GetIdx()] = (
                    [pending_bond_token]
                    if pending_bond_token is not None
                    else list(atom_tokens)
                )
        pending_bond_token = None
        prev_atom = atom_count
        atom_count += 1

    def add_ring_label(label, label_tokens):
        """Open, or close, the ring bond named ``label`` at ``prev_atom``."""
        nonlocal pending_bond_token
        if prev_atom is None:
            return
        has_explicit = pending_bond_token is not None
        bond_tokens = [pending_bond_token] if has_explicit else list(label_tokens)
        pending_bond_token = None
        if label not in ring_closures:
            ring_closures[label] = (prev_atom, bond_tokens)
            return
        first_atom, first_tokens = ring_closures.pop(label)
        bond = mol.GetBondBetweenAtoms(first_atom, prev_atom)
        if bond is not None:
            # An explicit symbol at the closing label wins. Otherwise use the opening
            # side's tokens (its explicit symbol, or its label).
            bond_map[bond.GetIdx()] = bond_tokens if has_explicit else first_tokens

    def flush_extended_label():
        """Finish the split "%nn" label collected so far and register it."""
        nonlocal extended_label
        label_tokens, extended_label = extended_label, None
        if len(label_tokens) > 1:  # a bare "%" with no digits is ignored
            label = "%" + "".join(tokens[j] for j in label_tokens[1:])
            add_ring_label(label, label_tokens)

    for i, token in enumerate(tokens):
        if token in SPECIAL_TOKENS:
            continue

        # Bracket atoms, e.g. [NH+] tokenized as ['[', 'N', 'H', '+', ']'].
        if in_bracket:
            bracket_tokens.append(i)
            if token == "]":
                in_bracket = False
                add_atom(bracket_tokens)
            continue

        # Split "%nn" ring labels take exactly two digits after "%".
        if extended_label is not None:
            if token.isdigit():
                extended_label.append(i)
                if len(extended_label) == 3:
                    flush_extended_label()
                continue
            # Label ended early (malformed): register it, then process this token
            # normally instead of letting the ring label swallow it.
            flush_extended_label()

        if token == "[":
            in_bracket = True
            bracket_tokens = [i]
        elif token == "%":
            extended_label = [i]
        elif token.isdigit() or (token.startswith("%") and token[1:].isdigit()):
            add_ring_label(token, [i])
        elif token in atom_symbols:
            add_atom([i])
        elif token in BOND_SYMBOLS:
            pending_bond_token = i
        elif token == "(":
            if prev_atom is not None:
                branch_stack.append(prev_atom)
        elif token == ")":
            if branch_stack:
                prev_atom = branch_stack.pop()
            pending_bond_token = None

    # A split "%nn" label may be the last thing in the token stream.
    if extended_label is not None:
        flush_extended_label()
    return atom_map, bond_map
def draw_molecule_with_attributions(smiles, tokens, attribution_scores):
    """Draw ``smiles`` with atoms and bonds coloured by their token attributions.

    The SMILES is parsed without sanitization. Atoms and bonds are linked to
    token positions by ``map_tokens_to_structure``. Each one is coloured by the
    sum of its tokens' scores, mapped through ``get_color_mapper``.

    Args:
        smiles: SMILES string the ``tokens`` were produced from.
        tokens: token strings aligned with ``attribution_scores``.
        attribution_scores: torch tensor or array of per-token scores.

    Returns:
        A 600x600 PIL image, or None if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    if not mol:
        return None
    AllChem.Compute2DCoords(mol)

    if torch.is_tensor(attribution_scores):
        score_values = attribution_scores.cpu().numpy()
    else:
        score_values = attribution_scores
    norm, cmap = get_color_mapper(attribution_scores)
    n_scores = len(score_values)

    def colour_by_tokens(index_to_tokens):
        # Sum the scores of in-range tokens; entries with none in range stay uncoloured.
        colours = {}
        for struct_idx, token_indices in index_to_tokens.items():
            in_range = [t for t in token_indices if t < n_scores]
            if in_range:
                total = sum(score_values[t] for t in in_range)
                colours[struct_idx] = cmap(norm(total))[:3]
        return colours

    atom_to_token, bond_to_token = map_tokens_to_structure(mol, tokens)
    atom_colors = colour_by_tokens(atom_to_token)
    bond_colors = colour_by_tokens(bond_to_token)

    drawer = rdMolDraw2D.MolDraw2DCairo(600, 600)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=list(atom_colors.keys()),
        highlightBonds=list(bond_colors.keys()),
        highlightAtomColors=atom_colors,
        highlightBondColors=bond_colors,
    )
    drawer.FinishDrawing()
    png_bytes = drawer.GetDrawingText()
    return Image.open(BytesIO(png_bytes))
def main():
    """Streamlit entry point: predict a molecular property and explain it per token.

    Sidebar widgets choose:

    * the model/property;
    * an example or custom SMILES;
    * the number of integration steps;
    * the output channel to attribute, when the model declares channels.

    The page then shows:

    * the tokens and the model prediction;
    * Layer Integrated Gradients attributions, as a bar chart and as a coloured
      structure;
    * summary statistics and the ten tokens with the largest absolute score.
    """
    # Absolute convergence-delta thresholds for the "Quality" badge. They are
    # also quoted in the help text so the advice and the badge always agree.
    good_delta = 0.05
    fair_delta = 0.1

    st.markdown("# Prediction and Attribution with MIST")
    st.sidebar.header("Configuration")
    models_info = {
        "QM8": "mist-models/mist-28M-gzwqzpcr-qm8",
        "QM9": "mist-models/mist-26.9M-kkgx0omx-qm9",
    }
    selected_property = st.sidebar.selectbox("Property", list(models_info.keys()))
    model_name = models_info[selected_property]
    st.sidebar.markdown("---")
    examples = {
        "Benzene": "c1ccccc1",
        "Ethanol": "CCO",
        "Aspirin": "CC(=O)Oc1ccccc1C(=O)O",
        "Caffeine": "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
        "Propylene Carbonate": "CC1COC(=O)O1",
        "Custom": "",
    }
    selected = st.sidebar.selectbox("Example", list(examples.keys()))
    smiles = st.sidebar.text_input(
        "SMILES", value=examples[selected], placeholder="Enter SMILES"
    )
    st.sidebar.markdown("---")
    n_steps = st.sidebar.slider("Steps", 10, 200, 50, 10)
    if not smiles:
        st.info("Enter a SMILES string")
        return
    if Chem.MolFromSmiles(smiles) is None:
        # kekulize_smiles() returns unparsable input unchanged, so it would reach the
        # tokenizer and model as-is; make that visible instead of failing silently.
        st.warning(
            "RDKit could not parse this SMILES; the prediction and attributions may be unreliable."
        )

    with st.spinner("Loading model..."):
        model, tokenizer, _ = load_model(model_name)

    channels = get_channels(model)
    target_idx = None
    selected_channel = None
    if channels:
        st.sidebar.markdown("---")
        st.sidebar.header("Target")
        channel_labels = [
            f"{ch['name']} ({ch.get('description', '')})" for ch in channels
        ]
        target_idx = st.sidebar.selectbox(
            "Channel", range(len(channels)), format_func=lambda i: channel_labels[i]
        )
        selected_channel = channels[target_idx]

    kekule_smiles = kekulize_smiles(smiles)
    with st.spinner("Tokenizing..."):
        encoded = tokenizer([kekule_smiles])
        tokens = tokenizer.tokenize(kekule_smiles)
        input_ids = torch.tensor(encoded["input_ids"])
        attention_mask = torch.tensor(encoded["attention_mask"])

    st.markdown("### Molecule")
    st.code(smiles)
    with st.expander("View Tokens"):
        token_df = pd.DataFrame({"Position": range(len(tokens)), "Token": tokens})
        st.dataframe(token_df, use_container_width=True)

    st.markdown("### Property Prediction")
    with torch.no_grad():
        predictions = model.predict([kekule_smiles])
    st.write("Predicted Value", predictions)

    st.markdown("### Attributions")
    st.markdown(
        f"""
Token attributions quantify how much each token in the SMILES string contributes to the model's prediction as compared to a baseline.
Positive scores (green) indicate tokens that increase the predicted value, while negative scores (red) indicate
tokens that decrease it.
Attributions are computed using the integrated gradients described in [Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365)
as implemented by ``captum``'s ``LayerIntegratedGradients`` class.
A padding token ``[PAD]`` is used as the baseline.
If the quality below is reported as "Poor" (absolute convergence Δ of {fair_delta} or more), increase the number of integration steps.
"""
    )
    if selected_channel:
        st.info(
            f"Computing attributions for: **{selected_channel['name']}** ({selected_channel.get('description', '')}) - {selected_channel.get('unit', '')}"
        )
    with st.spinner("Computing attributions..."):
        scores, delta = compute_attributions(
            model, input_ids, attention_mask, n_steps, tokenizer, target_idx
        )
    attribution_scores = scores.flatten()

    # A single SMILES is attributed, so delta holds exactly one value.
    delta_value = delta.item()
    abs_delta = abs(delta_value)
    if abs_delta < good_delta:
        quality = "Good"
    elif abs_delta < fair_delta:
        quality = "Fair"
    else:
        quality = "Poor"
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Convergence Δ", f"{delta_value:.6f}")
    with col2:
        st.metric("Quality", quality)

    col1, col2 = st.columns(2)
    with col1:
        target_name = selected_channel["name"] if selected_channel else None
        st.plotly_chart(
            plot_attributions(tokens, attribution_scores, target_name),
            use_container_width=True,
        )
    with col2:
        attributed_img = draw_molecule_with_attributions(
            kekule_smiles, tokens, attribution_scores
        )
        if attributed_img:
            st.image(attributed_img, width="content")
        else:
            st.warning("Unable to generate structure visualization")

    st.markdown("### Statistics")
    s = attribution_scores.cpu().numpy()
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Max", f"{s.max():.4f}")
    with col2:
        st.metric("Mean", f"{s.mean():.4f}")
    with col3:
        st.metric("Min", f"{s.min():.4f}")
    with col4:
        st.metric("Std", f"{s.std():.4f}")
    # Ten positions with the largest absolute attribution, strongest first.
    top_idx = np.argsort(np.abs(s))[::-1][:10]
    df = pd.DataFrame(
        [{"Pos": int(i), "Token": tokens[i], "Score": f"{s[i]:.6f}"} for i in top_idx]
    )
    st.dataframe(df, use_container_width=True)
if __name__ == "__main__":
    # Build the page when this file is executed as the app script.
    main()