import gradio as gr
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer
import json
import os
from pathlib import Path
import logging

# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from PIL import Image, ImageDraw, ImageFont

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a placeholder image with gray text centered on it.

    Args:
        width (int): Image width in pixels.
        height (int): Image height in pixels.
        text (str): Text to display. It may contain newlines; each line is centered.
        bg_color (tuple): Background color (R, G, B, A). The default (0, 0, 0, 0)
            is fully transparent.

    Returns:
        PIL.Image.Image: RGBA placeholder image.
    """
    img = Image.new("RGBA", (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Prefer a TrueType font. Arial is often missing (e.g. on Linux), in which case
    # truetype() raises OSError and we fall back to Pillow's built-in font.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None  # Pillow then uses its own default font when drawing

    left = top = 0
    try:
        left, top, right, bottom = draw.multiline_textbbox((0, 0), text, font=font, align="center")
        text_width = right - left
        text_height = bottom - top
    except ValueError:
        # Older Pillow versions only support bbox queries for TrueType fonts:
        # roughly estimate the block size instead.
        lines = text.split("\n")
        text_width = max(len(line) for line in lines) * 8
        text_height = 16 * len(lines)

    # Subtract the bbox origin so the *ink* (not the anchor point) is centered.
    x = (width - text_width) // 2 - left
    y = (height - text_height) // 2 - top

    draw.multiline_text((x, y), text, fill=(128, 128, 128, 255), font=font, align="center")
    return img


class DrugTargetInteractionApp:
    """Holds the DLRNA-BERTa model, its tokenizers and device, and runs inference / visualization."""

    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """
        Load the pre-trained model and tokenizers.

        The drug encoder is a RobertaModel built from the "DeepChem/ChemBERTa-77M-MTR"
        config (no pooling layer). Its weights are presumably restored by
        ``InteractionModelATTNForRegression.from_pretrained``; confirm. The target encoder
        is loaded from "IlPakoZ/RNA-BERTa9700". A StdScaler is loaded only when
        ``scaler.config`` exists in ``model_path``.

        Args:
            model_path (str): Directory with the model config/weights, the
                ``target_tokenizer`` folder and ``drug_tokenizer/vocab.json``.

        Returns:
            bool: True on success. False if any step failed (the error is logged).
        """
        try:
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Drug encoder (ChemBERTa architecture)
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Optional output scaler
            scaler_path = os.path.join(model_path, "scaler.config")
            scaler = None
            if os.path.exists(scaler_path):
                scaler = StdScaler()
                scaler.load(model_path)

            self.model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler,
            )
            self.model.to(self.device)
            self.model.eval()

            self.target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            self.drug_tokenizer = ChembertaTokenizer(vocab_file)

            logger.info("Model and tokenizers loaded successfully!")
            return True

        except Exception as e:
            # Log with traceback: load failures are otherwise hard to diagnose from the UI.
            logger.exception(f"Error loading model: {str(e)}")
            return False

    def get_target_and_smiles(self, target_sequence, drug_smiles):
        """
        Tokenize one target sequence and every SMILES string, moving tensors to ``self.device``.

        Both inputs are padded/truncated to 512 tokens.

        Args:
            target_sequence (str): RNA sequence.
            drug_smiles (list[str]): SMILES strings. Each one is stripped before tokenization.

        Returns:
            tuple: (target_inputs, list of drug inputs in the same order as ``drug_smiles``)
        """
        target_inputs = self.target_tokenizer(
            target_sequence,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        all_smiles = [
            self.drug_tokenizer(
                smiles.strip(),
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)
            for smiles in drug_smiles
        ]
        return target_inputs, all_smiles

    def _predict_pkd(self, target_inputs, drug_inputs):
        """Run one forward pass and return the first output value, unscaled if a scaler is attached."""
        with torch.no_grad():
            prediction = self.model(target_inputs, drug_inputs)
        if self.model.scaler is not None:
            prediction = self.model.unscale(prediction)
        return prediction.cpu().numpy()[0][0]

    def predict_interaction(self, target_sequence, drug_smiles):
        """
        Predict the pKd of each SMILES string against one target sequence.

        Args:
            target_sequence (str): RNA sequence.
            drug_smiles (list[str]): SMILES strings.

        Returns:
            str: One "<smiles> predicted pKd: <value>" line per SMILES, or an error message.
        """
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."

        try:
            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)

            # Interpretability mode is only needed for the visualizations.
            self.model.INTERPR_DISABLE_MODE()
            result_lines = []
            for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
                prediction_value = self._predict_pkd(target_inputs, drug_inputs)
                result_lines.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
            return "\n".join(result_lines)

        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            return f"Error during prediction: {str(e)}"

    def _plot_cross_attention(self, target_inputs, drug_inputs, cross_attention_weights):
        """Render the cross-attention heatmap for one drug. Returns None if unavailable or on failure."""
        if cross_attention_weights is None:
            logger.warning("Cross-attention weights are None")
            return None

        logger.info(
            f"Cross-attention weights shape: {getattr(cross_attention_weights, 'shape', 'No shape attr')}"
        )
        try:
            # [0, 0] picks the first batch element and first head, assuming a
            # (batch, heads, ...) layout; TODO confirm against the model.
            cross_attn_matrix = cross_attention_weights[0, 0]
        except (IndexError, TypeError, AttributeError) as e:
            logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
            return None

        try:
            return plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer,
            )
        except Exception as e:
            # Best effort: a failed plot is replaced by a placeholder image by the caller.
            logger.error(f"Cross-attention visualization error: {str(e)}")
            return None

    def _plot_contribution(self, target_inputs, presum_values, scaler, w, b, raw_affinities):
        """Render a per-target-token pKd contribution plot. Returns None if unavailable or on failure."""
        kind = "Raw" if raw_affinities else "Normalized"
        if presum_values is None:
            logger.warning(f"Cannot generate {kind.lower()} contribution visualization: presum values are None")
            return None
        try:
            # Tensors are detached so plotting never touches the autograd graph.
            return plot_presum(
                target_inputs,
                presum_values.detach(),
                scaler,
                w.detach(),
                b.detach(),
                self.target_tokenizer,
                raw_affinities=raw_affinities,
            )
        except Exception as e:
            logger.error(f"{kind} contribution visualization error: {str(e)}")
            return None

    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Predict pKd for every SMILES and build interpretability images for the LAST one.

        The raw (signed) contribution plot is drawn only when the predicted pKd is > 0.
        The normalized contribution plot and the cross-attention heatmap are always attempted.
        A failed or skipped plot is replaced by a placeholder image.

        Args:
            target_sequence (str): RNA sequence.
            drug_smiles (list[str]): SMILES strings. Visualizations use the last entry.

        Returns:
            tuple: (cross_attention_image, raw_contribution_image,
                    normalized_contribution_image, status_message).
                On error, the three images are None.
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."
        if not drug_smiles:
            return None, None, None, "Error during visualization: no drug SMILES provided."

        try:
            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)

            try:
                # Interpretability mode makes the inner model expose the tensors read below.
                self.model.INTERPR_ENABLE_MODE()

                result_lines = []
                for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
                    prediction_value = self._predict_pkd(target_inputs, drug_inputs)
                    result_lines.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
                status_msg = "\n".join(result_lines)

                # After the loop, `drug_inputs`, `prediction_value` and the inner model's
                # attributes all come from the last forward pass, i.e. the last SMILES.
                inner = self.model.model
                presum_values = inner.presum_layer
                cross_attention_weights = inner.crossattention_weights
                w = inner.w.squeeze(1)
                b = inner.b
                scaler = inner.scaler

                # 1. Cross-attention heatmap
                cross_attention_img = self._plot_cross_attention(
                    target_inputs, drug_inputs, cross_attention_weights
                )
                # 2. Normalized (non-negative) contributions: drawn for any pKd
                normalized_img = self._plot_contribution(
                    target_inputs, presum_values, scaler, w, b, raw_affinities=False
                )
                # 3. Raw (signed) contributions: only drawn when pKd > 0
                if prediction_value > 0:
                    raw_img = self._plot_contribution(
                        target_inputs, presum_values, scaler, w, b, raw_affinities=True
                    )
                else:
                    logger.info("Skipping raw contribution visualization as pKd <= 0")
                    raw_img = None
            finally:
                # Always return the model to plain inference mode, even on failure.
                self.model.INTERPR_DISABLE_MODE()

            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None:
                if prediction_value > 0:
                    raw_img = create_placeholder_image(text="Raw Contribution\nFailed to generate")
                else:
                    raw_img = create_placeholder_image(text="Raw Contribution\nSkipped (pKd ≤ 0)")

            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg

        except Exception as e:
            logger.exception(f"Visualization error: {str(e)}")
            return None, None, None, f"Error during visualization: {str(e)}"


# Initialize the app
app = DrugTargetInteractionApp()


def smiles_preprocessing(drug_smiles, remove_dupl, max_count=2000):
    """
    Turn the SMILES textbox content into a list of SMILES strings.

    Each line is stripped of surrounding whitespace (including '\\r' from Windows line
    endings), and blank lines are dropped. Duplicates are optionally removed, keeping
    first-occurrence order. The result is truncated to ``max_count`` entries.

    Args:
        drug_smiles (str): Newline-separated SMILES.
        remove_dupl (bool): Whether to drop repeated SMILES.
        max_count (int): Maximum number of entries returned (the UI advertises 2000).

    Returns:
        list[str]: The cleaned SMILES list.
    """
    drugs = [line.strip() for line in drug_smiles.splitlines()]
    drugs = [smiles for smiles in drugs if smiles]

    if remove_dupl:
        n_before = len(drugs)
        # dict preserves insertion order, so this is an O(n) order-preserving de-duplication.
        drugs = list(dict.fromkeys(drugs))
        logger.info(f"{n_before - len(drugs)} duplicate smiles removed!")

    if len(drugs) > max_count:
        logger.info(f"Only the first {max_count} of {len(drugs)} SMILES entries will be processed.")
    return drugs[:max_count]


def predict_wrapper(target_seq, drug_smiles, remove_dups):
    """Gradio handler for prediction: validate inputs, preprocess SMILES, run batch prediction."""
    if not target_seq.strip() or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."

    drugs = smiles_preprocessing(drug_smiles, remove_dups)
    return app.predict_interaction(target_seq.strip(), drugs)


def visualize_wrapper(target_seq, drug_smiles, remove_dups):
    """Gradio handler for visualization: validate inputs, preprocess SMILES, build the images."""
    if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."

    drugs = smiles_preprocessing(drug_smiles, remove_dups)
    return app.visualize_interaction(target_seq.strip(), drugs)


def load_model_wrapper(model_path):
    """Gradio handler for the Load Model button. Returns a human-readable status string."""
    if app.load_model(model_path):
        return "Model loaded successfully!"
    return "Failed to load model. Check the path and files."
# Create Gradio interface
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
        <div>
            <h1>🧬 Drug-Target Interaction Predictor</h1>
            <p>Predict binding affinity between drugs and target RNA sequences using deep learning</p>
        </div>
        """
    )

    # State variables used to hand the generated images from the Prediction tab
    # to the Visualizations tab.
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=5,
                    max_lines=5,
                )
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation for one or more drugs.\n"
                    "For multiple SMILES, enter each on a new line (max 2000):\n"
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\n"
                    "C1CCCCC1O",
                    lines=5,
                    max_lines=5,
                )
                remove_dups_checkbox = gr.Checkbox(
                    label="Remove duplicate SMILES",
                    value=False,
                )

                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")

            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4,
                )

        # Example inputs
        gr.HTML("<h3>📚 Example Inputs:</h3>")
        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2",
                ],
            ],
            inputs=[target_input, drug_input, remove_dups_checkbox],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False,
        )

        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input, remove_dups_checkbox],
            outputs=prediction_output,
        )

        def visualize_and_update(target_seq, drug_smiles, remove_dups):
            """Generate the visualizations, push them to the shared state and report status."""
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles, remove_dups)

            # All images None means validation or visualization failed: show the message as-is
            # instead of claiming the analysis completed.
            if img1 is None and img2 is None and img3 is None:
                return img1, img2, img3, status

            combined_status = (
                status
                + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            )
            # `drug_smiles` is the raw textbox string, so count non-blank lines (entries),
            # not characters.
            n_entries = sum(1 for line in drug_smiles.splitlines() if line.strip())
            if n_entries > 1:
                combined_status += "\nVisualizations are shown only for the last SMILES entry."
            return img1, img2, img3, combined_status

        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input, remove_dups_checkbox],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output],
            api_name="visualize_and_update",
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML(
            """
            <div>
                <h2>🔬 Interaction Analysis &amp; Visualizations</h2>
                <p>Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab</p>
            </div>
            """
        )

        # Visualization outputs - large and vertically aligned
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)"),
        )
        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)"),
        )
        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)"),
        )

        # Mirror each state value into its image component whenever it changes.
        viz_state1.change(fn=lambda x: x, inputs=viz_state1, outputs=viz_image1)
        viz_state2.change(fn=lambda x: x, inputs=viz_state2, outputs=viz_image2)
        viz_state3.change(fn=lambda x: x, inputs=viz_state3, outputs=viz_image3)

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3>Model Configuration</h3>")

        model_path_input = gr.Textbox(
            label="Model Path",
            value="./",
            placeholder="Path to model directory",
        )
        load_model_btn = gr.Button("📥 Load Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value="No model loaded",
        )

        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=model_path_input,
            outputs=model_status,
        )

    with gr.Tab("📊 Dataset"):
        # NOTE(review): 2,991 positive + 2,538 negative = 5,529, not the stated 5,534
        # (the per-type composition below does sum to 5,534) — confirm against the source data.
        gr.Markdown(
            """
            ## Training and Test Datasets

            ### Fine-tuning Dataset (Training)

            The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:

            - **759 unique compounds** (SMILES representations)
            - **294 unique RNA sequences**
            - Dissociation constants (pKd values) for binding affinity prediction

            **RNA Sequence Distribution by Type:**

            | RNA Sequence Type | Number of Interactions |
            |-------------------|------------------------|
            | Aptamers          | 520                    |
            | Ribosomal         | 295                    |
            | Viral RNAs        | 281                    |
            | miRNAs            | 146                    |
            | Riboswitches      | 100                    |
            | Repeats           | 97                     |
            | **Total**         | **1,439**              |

            ### External Evaluation Dataset (Test)

            Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:

            - **2,991 positive interactions**
            - **2,538 negative interactions**

            **Test Dataset Composition:**

            - **1,617 aptamer pairs** (5 unique RNA sequences)
            - **1,828 viral RNA pairs** (6 unique RNA sequences)
            - **1,459 riboswitch pairs** (5 unique RNA sequences)
            - **630 miRNA pairs** (3 unique RNA sequences)

            ### Dataset Downloads

            - [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
            - [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

            ### Citation

            Original datasets published by:

            **Krishnan et al.** - Available on the RSAPred website in PDF format.

            *Reference:*

            ```bibtex
            @article{krishnan2024reliable,
              title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
              author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
              journal={Briefings in Bioinformatics},
              volume={25},
              number={2},
              pages={bbae002},
              year={2024},
              publisher={Oxford University Press}
            }
            ```
            """
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown(
            """
            ## About this application

            This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug-to-RNA target interactions. The architecture combines:

            - **Target encoder**: RNA-BERTa for processing RNA sequences
            - **Drug encoder**: ChemBERTa for SMILES representation
            - **Cross-attention mechanism**: Captures interactions between drug and target
            - **Regression head**: Predicts binding affinity (pKd)

            ### Input requirements

            - **Target sequence**: RNA sequence (A, U, G, C)
            - **Drug SMILES**: One or more SMILES strings
              - For batch mode, enter each SMILES on a new line (up to 2000 entries)
              - A checkbox option allows automatic removal of duplicate SMILES before prediction

            ### Model features

            - Cross-attention for drug-target interaction modeling
            - Regularization via dropout
            - Layer normalization for stable training
            - Dedicated interpretability mode for visualization
            - Batch prediction with optional de-duplication

            ### Usage tips

            1. Load a model (optional) in the Model Settings tab
            2. Enter an RNA sequence and one or more SMILES strings
            3. Use the **“Remove duplicate SMILES”** checkbox if you want duplicates filtered automatically
            4. Click *Predict Interaction* for affinity scores
            5. Click *Generate Visualizations* for interpretability plots
            6. Visualizations are produced only for the final SMILES entry in batch mode

            For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).

            ### Visualization features:

            - **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
            - **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
            - **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

            ### Performance metrics:

            - Training on diverse drug-target interaction datasets
            - Evaluated using RMSE, Pearson correlation, and Concordance Index
            - Optimized for both predictive accuracy and interpretability

            ### GitHub repository:

            - The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

            ### Contribution:

            - Special thanks to Umut Onur Özcan for help in developing this space:)

            ### Contact:

            - Ziaurrehman Tanoli (ziaurrehman.tanoli@helsinki.fi)
              Principal investigator at Institute for Molecular Medicine Finland HiLIFE, University of Helsinki, Finland.
            """
        )

# Launch the app
if __name__ == "__main__":
    # Try to load the model on startup when a config is present in the working directory.
    if os.path.exists("./config.json"):
        app.load_model("./")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )