File size: 3,899 Bytes
f203922 dae9496 f203922 dae9496 f203922 dae9496 f203922 dae9496 f203922 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients
import torch
import torch.nn as nn
import numpy as np
from rdkit import Chem
# Initialize model and tokenizer as global variables.
# Loaded once at import time so every Gradio request reuses the same weights
# instead of re-downloading/re-initializing per call.
tokenizer = AutoTokenizer.from_pretrained("fabikru/molencoder-D3R-simple")
model = AutoModelForSequenceClassification.from_pretrained("fabikru/molencoder-D3R-simple")
class ModelWrapper(nn.Module):
    """Adapter exposing the classifier as a function of input embeddings.

    Captum's IntegratedGradients needs a forward pass whose first argument is
    the (differentiable) embedding tensor rather than token ids, and which
    returns a plain tensor instead of a HF ``ModelOutput``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Kept for interface compatibility; not read inside this class.
        self.token_dropout = False

    def forward(self, inputs_embeds):
        """Run the wrapped model on pre-computed embeddings.

        Args:
            inputs_embeds: Float tensor of shape (batch, seq_len, hidden).

        Returns:
            The wrapped model's ``logits`` tensor.
        """
        # Every position is attended to — no padding is assumed here.
        input_shape = inputs_embeds.size()[:-1]  # (batch, seq_len)
        attention_mask = torch.ones(input_shape, device=inputs_embeds.device)
        outputs = self.model.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        return outputs.logits
# Initialize the wrapped model as a global variable.
# Consumed by IntegratedGradients inside the analysis function below.
wrapper = ModelWrapper(model)
def analyse_inhibitory_constant_for_delta_opioid_receptor(smiles: str) -> str:
    """
    Analyses the inhibitory constant for a given molecule to the delta opioid
    receptor and gives the contribution of each SMILES symbol.

    Args:
        smiles (str): The SMILES string to analyze.

    Returns:
        str: The predicted inhibitory constant plus a markdown table mapping
            each input token to its Integrated Gradients attribution value,
            or an error message when the SMILES string cannot be parsed.
    """
    # Validate the SMILES with RDKit before spending time on the model.
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
    except Exception as e:
        return f"Error parsing SMILES string: {e}"

    # Build differentiable input embeddings for the tokenized SMILES.
    embedding_layer = model.model.embeddings.tok_embeddings
    ids = tokenizer.batch_encode_plus([smiles], add_special_tokens=True, is_split_into_words=False)
    input_ids = torch.tensor([ids['input_ids'][0]])
    inputs_embeds = embedding_layer(input_ids)
    inputs_embeds.requires_grad_(True)

    # The baseline must have the SAME sequence length as the real input for
    # IntegratedGradients. Use the tokenized length rather than len(smiles):
    # the tokenizer may emit multi-character tokens (e.g. "Cl", "Br"), in
    # which case len(smiles) would produce a mismatched baseline shape.
    seq_len = input_ids.shape[1]
    baseline_ids = [tokenizer.cls_token_id] + [tokenizer.pad_token_id] * (seq_len - 2) + [tokenizer.sep_token_id]
    baseline_ids = torch.tensor([baseline_ids])
    baseline_embeds = embedding_layer(baseline_ids)

    # Get the model prediction (no gradients needed for the point estimate).
    with torch.no_grad():
        logits = model(input_ids).logits
    prediction_log = logits.squeeze().item()

    # Convert from -log to nM (reverse of the training-target transformation).
    binding_affinity_nM = 10 ** (-prediction_log)

    # Attribute the prediction to each input embedding position.
    method = IntegratedGradients(wrapper)
    attributions = method.attribute(inputs=inputs_embeds, baselines=baseline_embeds)

    # Mean over the embedding dimension to get one attribution value per input
    # token, then drop the leading/trailing special tokens.
    attributions_np = attributions.squeeze().cpu().detach().numpy()
    attributions_aggregated = np.mean(attributions_np, axis=1)
    attribution_values = attributions_aggregated[1:-1]

    # Report the tokenizer's own tokens instead of indexing smiles[i]:
    # with multi-character tokens the i-th token is not the i-th character.
    tokens = tokenizer.convert_ids_to_tokens(ids['input_ids'][0])[1:-1]

    # Format output with prediction and attribution table.
    output = f"Inhibitory Constant: {binding_affinity_nM:.2f} nM\n\n"
    output += "Attribution Values:\n\n"
    output += "| Smiles Symbol | Attribution Value |\n|----------------|------------------|\n"
    for token, value in zip(tokens, attribution_values):
        output += f"| {token} | {value:.4f} |\n"
    return output
# Create the Gradio interface.
# Single textbox in, single textbox out: the function returns a markdown
# string (prediction + attribution table), so a plain Textbox renders it.
demo = gr.Interface(
    fn=analyse_inhibitory_constant_for_delta_opioid_receptor,
    inputs=gr.Textbox(placeholder="Enter smiles to analyze..."),
    outputs=gr.Textbox(lines=10),  # Textbox (not JSON): output is a formatted string
    title="Explainable Inhibitory Constant Prediction for Delta Opioid Receptor",
    description="Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol."
)
# Launch the interface and MCP server.
if __name__ == "__main__":
    # mcp_server=True exposes the analysis function as an MCP tool in
    # addition to the browser UI. (Removed a stray trailing "|" artifact
    # that made this line a syntax error.)
    demo.launch(mcp_server=True)