Fabian commited on
Commit
f203922
·
1 Parent(s): a4e550a

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +99 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ from captum.attr import IntegratedGradients
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from rdkit import Chem
8
+
9
+ # Initialize model and tokenizer as global variables
10
+ tokenizer = AutoTokenizer.from_pretrained("fabikru/molencoder-D3R-simple")
11
+ model = AutoModelForSequenceClassification.from_pretrained("fabikru/molencoder-D3R-simple")
12
+
13
+ class ModelWrapper(nn.Module):
14
+
15
+ def __init__(self, model):
16
+ super(ModelWrapper, self).__init__()
17
+ self.model = model
18
+ self.token_dropout = False
19
+
20
+ def forward(self, inputs_embeds):
21
+ input_shape = inputs_embeds.size()[:-1]
22
+ batch_size, seq_length = input_shape
23
+ attention_mask = torch.ones(input_shape, device=inputs_embeds.device)
24
+ outputs = self.model.forward(
25
+ inputs_embeds=inputs_embeds,
26
+ attention_mask=attention_mask,
27
+ )
28
+ return outputs.logits
29
+
30
+
31
+ # Initialize the wrapped model as a global variable
32
+ wrapper = ModelWrapper(model)
33
+
34
+ def get_attribution_values(smiles: str) -> str:
35
+ """
36
+ Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol.
37
+
38
+ Args:
39
+ smiles (str): The smiles string to analyze
40
+
41
+ Returns:
42
+ str: A markdown table with the smiles symbols and their attribution values
43
+ """
44
+ try:
45
+ mol = Chem.MolFromSmiles(smiles)
46
+ if mol is None:
47
+ return "Invalid SMILES string"
48
+ except Exception as e:
49
+ return f"Error parsing SMILES string: {e}"
50
+
51
+ # get input embedding tensors
52
+ embedding_layer = model.model.embeddings.tok_embeddings
53
+ ids = tokenizer.batch_encode_plus([smiles], add_special_tokens=True, is_split_into_words=False)
54
+ input_ids = torch.tensor([ids['input_ids'][0]])
55
+ inputs_embeds = embedding_layer(input_ids)
56
+ inputs_embeds.requires_grad_(True)
57
+
58
+ baseline_ids = [tokenizer.cls_token_id] + [tokenizer.pad_token_id] * len(smiles) + [tokenizer.sep_token_id]
59
+ baseline_ids = torch.tensor([baseline_ids])
60
+ baseline_embeds = embedding_layer(baseline_ids)
61
+
62
+ # Get model prediction
63
+ with torch.no_grad():
64
+ logits = model(input_ids).logits
65
+ prediction_log = logits.squeeze().item()
66
+
67
+ # Convert from -log to nM (reverse transformation)
68
+ binding_affinity_nM = 10 ** (-prediction_log)
69
+
70
+ method = IntegratedGradients(wrapper)
71
+ attributions = method.attribute(inputs=inputs_embeds, baselines=baseline_embeds)
72
+
73
+ # mean over embedding size to get one attribution value per input token (excluding special tokens)
74
+ attributions_np = attributions.squeeze().cpu().detach().numpy()
75
+ attributions_aggregated = np.mean(attributions_np, axis=1)
76
+ attribution_values = attributions_aggregated[1:-1]
77
+
78
+ # Format output with prediction and attribution table
79
+ output = f"Inhibitory Constant: {binding_affinity_nM:.2f} nM\n\n"
80
+ output += "Attribution Values:\n\n"
81
+ output += "| Smiles Symbol | Attribution Value |\n|----------------|------------------|\n"
82
+ for i, value in enumerate(attribution_values):
83
+ output += f"| {smiles[i]} | {value:.4f} |\n"
84
+ return output
85
+
86
+
87
+
88
+ # Create the Gradio interface
89
+ demo = gr.Interface(
90
+ fn=get_attribution_values,
91
+ inputs=gr.Textbox(placeholder="Enter smiles to analyze..."),
92
+ outputs=gr.Textbox(lines=10), # Changed from gr.JSON() to gr.Textbox()
93
+ title="Explainable Inhibitory Constant Prediction for Delta Opioid Receptor",
94
+ description="Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol."
95
+ )
96
+
97
+ # Launch the interface and MCP server
98
+ if __name__ == "__main__":
99
+ demo.launch(mcp_server=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio[mcp]
2
+ rdkit
3
+ captum
4
+ transformers
5
+ torch
6
+ numpy