import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients
from rdkit import Chem
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Model and tokenizer are loaded once at import time and shared globally.
tokenizer = AutoTokenizer.from_pretrained("fabikru/molencoder-D3R-simple")
model = AutoModelForSequenceClassification.from_pretrained("fabikru/molencoder-D3R-simple")


class ModelWrapper(nn.Module):
    """Adapter so Captum can attribute over input embeddings.

    IntegratedGradients needs a forward whose first positional argument is
    the tensor being attributed (token embeddings here) and whose return is
    a plain tensor, hence this thin wrapper around the classifier.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.token_dropout = False

    def forward(self, inputs_embeds):
        """Run the wrapped model on embeddings and return the raw logits."""
        input_shape = inputs_embeds.size()[:-1]  # (batch, seq_len)
        # Every position is attended to — inputs built here are never padded.
        attention_mask = torch.ones(input_shape, device=inputs_embeds.device)
        outputs = self.model.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        return outputs.logits


# Wrapped model used by the attribution method below.
wrapper = ModelWrapper(model)


def analyse_inhibitory_constant_for_delta_opioid_receptor(smiles: str) -> str:
    """Analyses the inhibitory constant for a given molecule to delta opioid
    receptor and gives the contribution of each smiles symbol.

    Args:
        smiles (str): The smiles string to analyze

    Returns:
        str: The inhibitory constant and a markdown table with the smiles
            symbols and their attribution values for the inhibitory constant
            prediction
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
    except Exception as e:
        return f"Error parsing SMILES string: {e}"

    # Tokenize and embed the input (special tokens included).
    embedding_layer = model.model.embeddings.tok_embeddings
    ids = tokenizer.batch_encode_plus([smiles], add_special_tokens=True, is_split_into_words=False)
    input_ids = torch.tensor([ids['input_ids'][0]])
    inputs_embeds = embedding_layer(input_ids)
    inputs_embeds.requires_grad_(True)

    # BUGFIX: the baseline must have exactly as many positions as the input.
    # The tokenizer can merge several characters into one token (e.g. "Cl"),
    # so len(smiles) is NOT the token count; using it could make inputs and
    # baselines different shapes and break IntegratedGradients. Build the
    # baseline as [CLS] PAD ... PAD [SEP] with the tokenized length instead.
    seq_len = input_ids.shape[1]
    baseline_ids = (
        [tokenizer.cls_token_id]
        + [tokenizer.pad_token_id] * (seq_len - 2)
        + [tokenizer.sep_token_id]
    )
    baseline_embeds = embedding_layer(torch.tensor([baseline_ids]))

    # Point prediction — no gradients needed for this pass.
    with torch.no_grad():
        logits = model(input_ids).logits
    prediction_log = logits.squeeze().item()
    # Convert from -log to nM (reverse transformation).
    binding_affinity_nM = 10 ** (-prediction_log)

    method = IntegratedGradients(wrapper)
    attributions = method.attribute(inputs=inputs_embeds, baselines=baseline_embeds)

    # Mean over the embedding dimension gives one attribution value per input
    # token; drop the special tokens at both ends.
    attributions_np = attributions.squeeze().cpu().detach().numpy()
    attributions_aggregated = np.mean(attributions_np, axis=1)
    attribution_values = attributions_aggregated[1:-1]

    # BUGFIX: attributions are per *token*, not per character. Labelling rows
    # with smiles[i] misaligns every row after a multi-character token (and
    # can index past the end of the string); use the actual tokens instead.
    tokens = tokenizer.convert_ids_to_tokens(ids['input_ids'][0])[1:-1]

    # Format output with prediction and attribution table.
    output = f"Inhibitory Constant: {binding_affinity_nM:.2f} nM\n\n"
    output += "Attribution Values:\n\n"
    output += "| Smiles Symbol | Attribution Value |\n|----------------|------------------|\n"
    for token, value in zip(tokens, attribution_values):
        output += f"| {token} | {value:.4f} |\n"

    return output


# Create the Gradio interface: one textbox in, formatted text out.
demo = gr.Interface(
    fn=analyse_inhibitory_constant_for_delta_opioid_receptor,
    inputs=gr.Textbox(placeholder="Enter smiles to analyze..."),
    outputs=gr.Textbox(lines=10),
    title="Explainable Inhibitory Constant Prediction for Delta Opioid Receptor",
    description="Predicts the inhibitory constant for a given molecule and gives the contribution of each smiles symbol.",
)

# Launch the interface and MCP server
if __name__ == "__main__":
    demo.launch(mcp_server=True)