# Fabian — improved description (commit dae9496)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients
import torch
import torch.nn as nn
import numpy as np
from rdkit import Chem
# Initialize model and tokenizer as global variables.
# NOTE: from_pretrained downloads the checkpoint on first run (import-time
# network side effect); both objects are shared by every request handler.
tokenizer = AutoTokenizer.from_pretrained("fabikru/molencoder-D3R-simple")
model = AutoModelForSequenceClassification.from_pretrained("fabikru/molencoder-D3R-simple")
class ModelWrapper(nn.Module):
    """Adapter exposing the classifier as a pure embeddings -> logits callable.

    Captum's IntegratedGradients needs a module whose forward takes the
    differentiable ``inputs_embeds`` tensor directly. This wrapper fabricates
    an all-ones attention mask (inputs here are never padded) and unwraps the
    raw logits from the Hugging Face output object.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.token_dropout = False

    def forward(self, inputs_embeds):
        mask_shape = inputs_embeds.size()[:-1]
        # Unpacking implicitly asserts a (batch, seq, dim) input tensor.
        n_batch, n_tokens = mask_shape
        # Attend to every position — no padding is used by the callers.
        attention_mask = torch.ones(mask_shape, device=inputs_embeds.device)
        result = self.model.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        return result.logits
# Initialize the wrapped model as a global variable; IntegratedGradients
# attributes through this wrapper rather than the raw HF model.
wrapper = ModelWrapper(model)
def analyse_inhibitory_constant_for_delta_opioid_receptor(smiles: str) -> str:
    """
    Analyse the inhibitory constant (Ki) of a molecule for the delta opioid
    receptor and attribute the prediction to individual SMILES tokens.

    Args:
        smiles (str): The SMILES string to analyze.

    Returns:
        str: The predicted inhibitory constant in nM followed by a markdown
            table mapping each SMILES token to its Integrated Gradients
            attribution value, or an error message for invalid SMILES.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
    except Exception as e:
        return f"Error parsing SMILES string: {e}"

    # Embed the tokenized SMILES so Integrated Gradients can attribute in
    # embedding space (discrete token ids are not differentiable).
    embedding_layer = model.model.embeddings.tok_embeddings
    ids = tokenizer.batch_encode_plus([smiles], add_special_tokens=True, is_split_into_words=False)
    token_ids = ids['input_ids'][0]
    input_ids = torch.tensor([token_ids])
    inputs_embeds = embedding_layer(input_ids)
    inputs_embeds.requires_grad_(True)

    # BUG FIX: the baseline must have the same sequence length as the input.
    # The original used len(smiles) (character count), which diverges from the
    # token count whenever the tokenizer maps multi-character symbols such as
    # "Cl", "Br" or "%10" to a single id, breaking the attribution call.
    num_inner_tokens = len(token_ids) - 2  # exclude [CLS] and [SEP]
    baseline_ids = [tokenizer.cls_token_id] + [tokenizer.pad_token_id] * num_inner_tokens + [tokenizer.sep_token_id]
    baseline_embeds = embedding_layer(torch.tensor([baseline_ids]))

    # Point prediction (no gradients needed here).
    with torch.no_grad():
        logits = model(input_ids).logits
    prediction_log = logits.squeeze().item()
    # The model predicts -log10(Ki); invert the transform to report nM.
    binding_affinity_nM = 10 ** (-prediction_log)

    method = IntegratedGradients(wrapper)
    attributions = method.attribute(inputs=inputs_embeds, baselines=baseline_embeds)
    # Mean over the embedding dimension -> one attribution value per token,
    # then drop the special tokens ([CLS]/[SEP]) at both ends.
    attributions_np = attributions.squeeze().cpu().detach().numpy()
    attributions_aggregated = np.mean(attributions_np, axis=1)
    attribution_values = attributions_aggregated[1:-1]

    # BUG FIX: label each row with the actual tokenizer token rather than
    # smiles[i]; character indexing misaligns symbols and attribution values
    # (or raises IndexError) as soon as a token spans more than one character.
    tokens = tokenizer.convert_ids_to_tokens(token_ids)[1:-1]

    # Format output with prediction and attribution table.
    output = f"Inhibitory Constant: {binding_affinity_nM:.2f} nM\n\n"
    output += "Attribution Values:\n\n"
    output += "| Smiles Symbol | Attribution Value |\n|----------------|------------------|\n"
    for token, value in zip(tokens, attribution_values):
        output += f"| {token} | {value:.4f} |\n"
    return output
# Create the Gradio interface: a single text input (SMILES) mapped to a
# multi-line text output containing the prediction and attribution table.
demo = gr.Interface(
    fn=analyse_inhibitory_constant_for_delta_opioid_receptor,
    inputs=gr.Textbox(placeholder="Enter smiles to analyze..."),
    # Plain multi-line textbox keeps the markdown attribution table readable.
    outputs=gr.Textbox(lines=10),
    title="Explainable Inhibitory Constant Prediction for Delta Opioid Receptor",
    description="Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol."
)
# Launch the web UI and additionally expose the function as an MCP tool
# server so agents can call it programmatically.
if __name__ == "__main__":
    demo.launch(mcp_server=True)