IlPakoZ committed
Commit 4780493 · verified · 1 Parent(s): a81fdbb

Added multiple SMILES compatibility, edited gradio API and updated README.md

Files changed (2):
  1. README.md +14 -2
  2. app.py +82 -76
README.md CHANGED
@@ -12,6 +12,8 @@ pinned: false
 
 This model predicts drug-target interactions using a novel cross-attention architecture that combines RNA sequence understanding with molecular representation learning. The model processes RNA target sequences and drug SMILES representations to predict binding affinity scores (pKd values).
 
+**Full model repository**: [https://github.com/IlPakoZ/dlrnaberta-dti-prediction](https://github.com/IlPakoZ/dlrnaberta-dti-prediction)
+
 ## Architecture
 
 The model consists of several key components:
@@ -53,6 +55,16 @@ from updated_app import demo
 demo.launch()
 ```
 
+The Gradio interface supports **batch predictions**: enter multiple SMILES strings, one per line, in the drug SMILES input field:
+
+```
+CC(C)CC1=CC=C(C=C1)C(C)C(=O)O
+C1CCCCC1O
+C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2
+```
+
+The model predicts binding affinity for each drug-target pair sequentially. Visualizations are shown only for the last SMILES entry.
+
 ### Programmatic usage
 
 ```python
@@ -101,7 +113,7 @@ with torch.no_grad():
 ## Model inputs
 
 - **Target sequence**: RNA sequence using nucleotides A, U, G, C (string)
-- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)
+- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string or multiple strings, one per line)
 
 ## Model outputs
 
@@ -196,4 +208,4 @@ This model is released under the MIT License.
 year={2024},
 publisher={Oxford University Press}
 }
-```
+```
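
The mapping from the newline-delimited input field to per-drug entries can be sketched as follows; `parse_smiles_input` is a hypothetical helper name, mirroring the `drug_smiles.split("\n")` and `smiles.strip()` calls this commit adds to app.py:

```python
def parse_smiles_input(text: str) -> list[str]:
    """Split a multi-line SMILES field into individual entries.

    Mirrors app.py: the field is split on newlines and each entry is
    stripped before tokenization. Blank lines are dropped here, whereas
    the committed code passes every split segment to the tokenizer.
    """
    return [line.strip() for line in text.split("\n") if line.strip()]

field = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\nC1CCCCC1O\n"
print(parse_smiles_input(field))
# → ['CC(C)CC1=CC=C(C=C1)C(C)C(=O)O', 'C1CCCCC1O']
```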
app.py CHANGED
@@ -118,42 +118,52 @@ class DrugTargetInteractionApp:
             logger.error(f"Error loading model: {str(e)}")
             return False
 
-    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
-        """Predict drug-target interaction"""
-        if self.model is None:
-            return "Error: Model not loaded. Please load a model first."
-
-        try:
-            # Tokenize inputs
-            target_inputs = self.target_tokenizer(
-                target_sequence,
-                padding="max_length",
-                truncation=True,
-                max_length=512,
-                return_tensors="pt"
-            ).to(self.device)
-
+    def get_target_and_smiles(self, target_sequence, drug_smiles):
+        # Tokenize inputs
+        target_inputs = self.target_tokenizer(
+            target_sequence,
+            padding="max_length",
+            truncation=True,
+            max_length=512,
+            return_tensors="pt"
+        ).to(self.device)
+
+        all_smiles = []
+        for smiles in drug_smiles:
             drug_inputs = self.drug_tokenizer(
-                drug_smiles,
+                smiles.strip(),
                 padding="max_length",
                 truncation=True,
                 max_length=512,
                 return_tensors="pt"
             ).to(self.device)
+            all_smiles.append(drug_inputs)
+
+        return target_inputs, all_smiles
+
+    def predict_interaction(self, target_sequence, drug_smiles):
+        """Predict drug-target interaction"""
+        if self.model is None:
+            return "Error: Model not loaded. Please load a model first."
+
+        try:
+            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
+            to_return = []
 
             # Make prediction
             self.model.INTERPR_DISABLE_MODE()
-            with torch.no_grad():
-                prediction = self.model(target_inputs, drug_inputs)
-
-            # Unscale if scaler exists
-            if self.model.scaler is not None:
-                prediction = self.model.unscale(prediction)
-
-            prediction_value = prediction.cpu().numpy()[0][0]
-
-            return f"Predicted Binding Affinity: {prediction_value:.4f}"
-
+            for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
+                with torch.no_grad():
+                    prediction = self.model(target_inputs, drug_inputs)
+
+                # Unscale if scaler exists
+                if self.model.scaler is not None:
+                    prediction = self.model.unscale(prediction)
+
+                prediction_value = prediction.cpu().numpy()[0][0]
+                to_return.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
+            return "\n".join(to_return)
+
         except Exception as e:
             logger.error(f"Prediction error: {str(e)}")
             return f"Error during prediction: {str(e)}"
@@ -173,48 +183,34 @@ class DrugTargetInteractionApp:
             return None, None, None, "Error: Model not loaded. Please load a model first."
 
         try:
-            # Tokenize inputs
-            target_inputs = self.target_tokenizer(
-                target_sequence,
-                padding="max_length",
-                truncation=True,
-                max_length=512,
-                return_tensors="pt"
-            ).to(self.device)
-
-            drug_inputs = self.drug_tokenizer(
-                drug_smiles,
-                padding="max_length",
-                truncation=True,
-                max_length=512,
-                return_tensors="pt"
-            ).to(self.device)
-
-            # Enable interpretation mode
+            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
+            to_return = []
+
+            # Make prediction
             self.model.INTERPR_ENABLE_MODE()
-
-            # Make prediction and extract visualization data
-            with torch.no_grad():
-                prediction = self.model(target_inputs, drug_inputs)
-
-            # Unscale if scaler exists
-            if self.model.scaler is not None:
-                prediction = self.model.unscale(prediction)
-
-            prediction_value = prediction.cpu().numpy()[0][0]
-
-            # Extract data needed for visualizations
-            presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
-            cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)
-
-            # Get model parameters for scaling
-            w = self.model.model.w.squeeze(1)
-            b = self.model.model.b
-            scaler = self.model.model.scaler
-
-            logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
-            logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")
-
+            for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
+                # Make prediction and extract visualization data
+                with torch.no_grad():
+                    prediction = self.model(target_inputs, drug_inputs)
+
+                # Unscale if scaler exists
+                if self.model.scaler is not None:
+                    prediction = self.model.unscale(prediction)
+
+                prediction_value = prediction.cpu().numpy()[0][0]
+
+                # Extract data needed for visualizations
+                presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
+                cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)
+
+                # Get model parameters for scaling
+                w = self.model.model.w.squeeze(1)
+                b = self.model.model.b
+                scaler = self.model.model.scaler
+                to_return.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
+
+            status_msg = "\n".join(to_return)
 
             # Generate visualizations
             try:
                 # 1. Cross-attention heatmap
@@ -318,7 +314,6 @@ class DrugTargetInteractionApp:
                         text="Raw Contribution\nSkipped (pKd ≤ 0)"
                     )
 
-            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
             if prediction_value <= 0:
                 status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
             if cross_attention_weights is None:
@@ -344,14 +339,14 @@ def predict_wrapper(target_seq, drug_smiles):
    if not target_seq.strip() or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."
 
-    return app.predict_interaction(target_seq, drug_smiles)
+    return app.predict_interaction(target_seq, drug_smiles.split("\n"))
 
 def visualize_wrapper(target_seq, drug_smiles):
    """Wrapper function for visualization"""
    if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."
 
-    return app.visualize_interaction(target_seq, drug_smiles)
+    return app.visualize_interaction(target_seq, drug_smiles.split("\n"))
 
 def load_model_wrapper(model_path):
    """Wrapper function to load model"""
@@ -390,8 +385,11 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
 
            drug_input = gr.Textbox(
                label="Drug SMILES",
-                placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
-                lines=2
+                placeholder="Enter SMILES notation for one or more drugs.\n"
+                            "For multiple SMILES, enter each on a new line:\n"
+                            "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\n"
+                            "C1CCCCC1O",
+                lines=3
            )
 
            with gr.Row():
@@ -437,12 +435,16 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
                img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
                # Combine prediction result with visualization status
                combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
+                if len(drug_smiles) > 1:
+                    combined_status += "\nVisualizations are shown only for the last SMILES entry."
+
                return img1, img2, img3, combined_status
 
            visualize_btn.click(
                fn=visualize_and_update,
                inputs=[target_input, drug_input],
-                outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
+                outputs=[viz_state1, viz_state2, viz_state3, prediction_output],
+                api_name="visualize_and_update"  # Make this API accessible
            )
 
        with gr.Tab("📊 Visualizations"):
@@ -587,7 +589,7 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
    gr.Markdown("""
    ## About this application
 
-    This application implements DLRNA-BERTa, a Dual Langauge RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:
+    This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:
 
    - **Target encoder**: Processes RNA sequences using RNA-BERTa
    - **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
@@ -597,18 +599,22 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
    ### Input requirements:
    - **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
    - **Drug SMILES**: Simplified Molecular Input Line Entry System notation
+    - **Batch prediction**: Enter multiple SMILES strings, one per line, to predict binding affinity for multiple drugs against the same target
 
    ### Model features:
    - Cross-attention for drug-target interaction modeling
    - Dropout for regularization
    - Layer normalization for stable training
    - Interpretability mode for contribution and attention visualization
+    - Support for batch predictions with multiple SMILES entries
 
    ### Usage tips:
    1. Load a trained model using the Model Settings tab (optional)
    2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
-    3. Click "Predict Interaction" for binding affinity prediction only
-    4. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab
+    3. For batch predictions, enter multiple SMILES strings (one per line) in the drug SMILES field
+    4. Click "Predict Interaction" for binding affinity prediction only
+    5. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab
+    6. Note: Visualizations are generated only for the last SMILES entry when using batch mode
 
    For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).
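
The per-SMILES result lines that the new loop accumulates (`to_return.append(f"{smile_name} predicted pKd: ...")` followed by `"\n".join(to_return)`) can be sketched in isolation; `format_batch_results` is a hypothetical name, and the pKd values below are placeholders rather than model outputs:

```python
def format_batch_results(smiles_list, predictions):
    """Build one 'SMILES predicted pKd: value' line per drug, joined by
    newlines, matching the output format of the updated predict_interaction."""
    return "\n".join(
        f"{smiles} predicted pKd: {pkd:.4f}"
        for smiles, pkd in zip(smiles_list, predictions)
    )

# Placeholder affinities for illustration only
print(format_batch_results(
    ["C1CCCCC1O", "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"],
    [6.21, 4.87],
))
```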