File size: 3,899 Bytes
f203922
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae9496
f203922
dae9496
f203922
 
 
 
 
dae9496
f203922
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae9496
f203922
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients
import torch
import torch.nn as nn
import numpy as np
from rdkit import Chem

# Hugging Face checkpoint id, shared by the tokenizer and the model so the two
# cannot accidentally be loaded from different checkpoints.
_MODEL_ID = "fabikru/molencoder-D3R-simple"

# Initialize model and tokenizer once, at import time, as global variables.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_ID)
# from_pretrained already returns the model in eval mode; made explicit because
# both the prediction and the attributions must run without dropout.
model.eval()

class ModelWrapper(nn.Module):
    """Expose a sequence-classification model as a function of input embeddings.

    Integrated gradients interpolates between a baseline and the real input,
    which requires a continuous input space. Token ids are discrete, so the
    attribution is computed on embedding vectors instead. This wrapper feeds
    those embeddings straight into the wrapped model and returns its logits.
    """

    def __init__(self, model):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self.model = model
        # Not read anywhere in this class; kept so external code that inspects
        # ``wrapper.token_dropout`` keeps working.
        self.token_dropout = False

    def forward(self, inputs_embeds):
        """Run the wrapped model on precomputed embeddings.

        Args:
            inputs_embeds: Tensor of shape ``(batch, seq_len, hidden)``.

        Returns:
            The ``logits`` attribute of the wrapped model's output.

        Raises:
            ValueError: If ``inputs_embeds`` is not a rank-3 tensor.
        """
        if inputs_embeds.dim() != 3:
            raise ValueError(
                "expected inputs_embeds of shape (batch, seq_len, hidden), "
                f"got {tuple(inputs_embeds.shape)}"
            )
        # Attend to every position: callers pass unpadded sequences.
        attention_mask = torch.ones(inputs_embeds.shape[:-1], device=inputs_embeds.device)
        # Invoke the module itself (not ``.forward``) so registered hooks run.
        outputs = self.model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        return outputs.logits


# One module-level wrapper, reused by every attribution request. It holds a
# reference to (not a copy of) the globally loaded ``model``.
wrapper = ModelWrapper(model=model)

def analyse_inhibitory_constant_for_delta_opioid_receptor(smiles: str) -> str:
    """
    Analyses the inhibitory constant for a given molecule to delta opioid receptor and gives the contribution of each smiles symbol.

    Args:
        smiles (str): The smiles string to analyze

    Returns:
        str: The inhibitory constant and a markdown table with the smiles symbols and their attribution values for the inhibitory constant prediction
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception as e:
        return f"Error parsing SMILES string: {e}"
    # MolFromSmiles returns None for unparsable input, and an empty string
    # parses to a molecule with zero atoms; neither can be scored meaningfully.
    if mol is None or mol.GetNumAtoms() == 0:
        return "Invalid SMILES string"

    # Shape (1, seq_len), including the tokenizer's special tokens at both ends.
    input_ids = tokenizer(smiles, add_special_tokens=True, return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # Get model prediction (a single regression output).
    with torch.no_grad():
        prediction_log = model(input_ids).logits.squeeze().item()

    # Undo the -log10 transform of the regression target.
    # NOTE(review): this yields nM only if the target was -log10(Ki / nM); if it
    # was pKi (Ki in molar) the nM value would be 10 ** (9 - prediction) —
    # confirm against the training setup.
    binding_affinity_nM = 10 ** (-prediction_log)

    attribution_values = _token_attributions(input_ids)
    symbols = _token_symbols(smiles, input_ids)
    return _format_report(binding_affinity_nM, symbols, attribution_values)


def _token_attributions(input_ids):
    """Return one integrated-gradients score per token, dropping the first and last (special) tokens."""
    embedding_layer = model.model.embeddings.tok_embeddings
    # IG only needs gradients w.r.t. the interpolated inputs (Captum enables
    # those itself), so no graph back to the embedding weights is built here.
    with torch.no_grad():
        inputs_embeds = embedding_layer(input_ids)
        # The baseline keeps the boundary (special) tokens and replaces every
        # content token with padding. Deriving it from input_ids guarantees it
        # has the same length as the input even when a token spans several
        # SMILES characters.
        baseline_ids = input_ids.clone()
        baseline_ids[:, 1:-1] = tokenizer.pad_token_id
        baseline_embeds = embedding_layer(baseline_ids)

    method = IntegratedGradients(wrapper)
    attributions = method.attribute(inputs=inputs_embeds, baselines=baseline_embeds)

    # Mean over the embedding dimension -> one attribution value per token.
    attributions_np = attributions[0].detach().cpu().numpy()
    return np.mean(attributions_np, axis=-1)[1:-1]


def _token_symbols(smiles, input_ids):
    """Return a row label for each non-special token in *input_ids*."""
    content_ids = input_ids[0, 1:-1].tolist()
    # One token per character: label rows with the SMILES characters.
    if len(content_ids) == len(smiles):
        return list(smiles)
    # Otherwise use the tokenizer's own token strings so labels stay aligned
    # with the attribution values.
    return tokenizer.convert_ids_to_tokens(content_ids)


def _format_report(binding_affinity_nM, symbols, values):
    """Render the prediction and the per-symbol attributions as markdown text."""
    lines = [
        f"Inhibitory Constant: {binding_affinity_nM:.2f} nM",
        "",
        "Attribution Values:",
        "",
        "| Smiles Symbol | Attribution Value |",
        "|----------------|------------------|",
    ]
    lines.extend(f"| {symbol} | {value:.4f} |" for symbol, value in zip(symbols, values))
    return "\n".join(lines) + "\n"



# Gradio UI: one SMILES textbox in, the text report (which contains a
# markdown-formatted table) out.
demo = gr.Interface(
    title="Explainable Inhibitory Constant Prediction for Delta Opioid Receptor",
    description="Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol.",
    fn=analyse_inhibitory_constant_for_delta_opioid_receptor,
    inputs=gr.Textbox(placeholder="Enter smiles to analyze..."),
    outputs=gr.Textbox(lines=10),
)

if __name__ == "__main__":
    # Serve the web interface together with the MCP server.
    demo.launch(mcp_server=True)