|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
from captum.attr import IntegratedGradients |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from rdkit import Chem |
|
|
|
|
|
|
|
|
# Load the pretrained SMILES tokenizer and sequence-classification checkpoint
# once at import time (downloaded from the Hugging Face Hub on first run).
# NOTE(review): checkpoint name suggests a delta-opioid-receptor (D3R?) binding
# model — confirm against the model card.
tokenizer = AutoTokenizer.from_pretrained("fabikru/molencoder-D3R-simple")


model = AutoModelForSequenceClassification.from_pretrained("fabikru/molencoder-D3R-simple")
|
|
|
|
|
class ModelWrapper(nn.Module):
    """Adapter that lets Captum attribution methods drive the wrapped model.

    Captum's IntegratedGradients needs a forward function whose first argument
    is the tensor being attributed (here: ``inputs_embeds``) and whose output
    is a plain tensor rather than a HF ``ModelOutput`` — this wrapper provides
    both.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Kept for compatibility with callers that read/toggle this flag.
        self.token_dropout = False

    def forward(self, inputs_embeds):
        """Run the wrapped model on precomputed embeddings and return logits.

        Args:
            inputs_embeds: float tensor of shape (batch, seq_len, hidden).

        Returns:
            The wrapped model's ``logits`` tensor.
        """
        # Attend to every position: all-ones mask of shape (batch, seq_len).
        input_shape = inputs_embeds.size()[:-1]
        attention_mask = torch.ones(input_shape, device=inputs_embeds.device)
        outputs = self.model.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        return outputs.logits
|
|
|
|
|
|
|
|
|
|
|
# One module-level wrapper instance, reused by every attribution request.
wrapper = ModelWrapper(model)
|
|
|
|
|
def analyse_inhibitory_constant_for_delta_opioid_receptor(smiles: str) -> str:
    """
    Analyses the inhibitory constant for a given molecule to delta opioid receptor and gives the contribution of each smiles symbol.

    Args:
        smiles (str): The smiles string to analyze

    Returns:
        str: The inhibitory constant and a markdown table with the smiles symbols and their attribution values for the inhibitory constant prediction
    """
    # Reject chemically invalid input before spending time on the model.
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
    except Exception as e:
        return f"Error parsing SMILES string: {e}"

    # Embed the tokenized SMILES so gradients can be taken w.r.t. the
    # embedding vectors (discrete token ids are not differentiable).
    embedding_layer = model.model.embeddings.tok_embeddings
    ids = tokenizer.batch_encode_plus([smiles], add_special_tokens=True, is_split_into_words=False)
    input_ids = torch.tensor([ids['input_ids'][0]])
    inputs_embeds = embedding_layer(input_ids)
    inputs_embeds.requires_grad_(True)

    # Baseline for IntegratedGradients: same *tokenized* length as the input
    # (tokens and characters do not map 1:1, e.g. "Cl"/"Br" are single tokens,
    # so building the baseline from len(smiles) would shape-mismatch). Keep
    # CLS/SEP in place and replace every real token with PAD.
    baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
    baseline_ids[0, 0] = tokenizer.cls_token_id
    baseline_ids[0, -1] = tokenizer.sep_token_id
    baseline_embeds = embedding_layer(baseline_ids)

    # Point prediction; no gradients needed here.
    with torch.no_grad():
        logits = model(input_ids).logits
    prediction_log = logits.squeeze().item()

    # The scalar output is inverted as 10**(-x) and reported in nM, i.e. the
    # model presumably regresses -log10(Ki in nM) — confirm against model card.
    binding_affinity_nM = 10 ** (-prediction_log)

    method = IntegratedGradients(wrapper)
    attributions = method.attribute(inputs=inputs_embeds, baselines=baseline_embeds)

    # Collapse the embedding dimension to one score per token, then drop the
    # CLS/SEP positions so scores line up with the real tokens.
    attributions_np = attributions.squeeze(0).cpu().detach().numpy()
    attributions_aggregated = np.mean(attributions_np, axis=1)
    attribution_values = attributions_aggregated[1:-1]

    # Label rows with tokenizer tokens, not smiles[i]: indexing the raw string
    # by token position mislabels multi-character tokens and can overrun.
    tokens = tokenizer.convert_ids_to_tokens(ids['input_ids'][0])[1:-1]

    output = f"Inhibitory Constant: {binding_affinity_nM:.2f} nM\n\n"
    output += "Attribution Values:\n\n"
    output += "| Smiles Symbol | Attribution Value |\n|----------------|------------------|\n"
    for token, value in zip(tokens, attribution_values):
        output += f"| {token} | {value:.4f} |\n"
    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI wiring: a single SMILES text box in, the markdown report out.
_smiles_input = gr.Textbox(placeholder="Enter smiles to analyze...")
_report_output = gr.Textbox(lines=10)

demo = gr.Interface(
    fn=analyse_inhibitory_constant_for_delta_opioid_receptor,
    inputs=_smiles_input,
    outputs=_report_output,
    title="Explainable Inhibitory Constant Prediction for Delta Opioid Receptor",
    description="Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol.",
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":


    # mcp_server=True serves the function as an MCP tool alongside the web UI.
    demo.launch(mcp_server=True)