Added multiple SMILES compatibility, edited gradio API and updated README.md
README.md
CHANGED
````diff
@@ -12,6 +12,8 @@ pinned: false
 
 This model predicts drug-target interactions using a novel cross-attention architecture that combines RNA sequence understanding with molecular representation learning. The model processes RNA target sequences and drug SMILES representations to predict binding affinity scores (pKd values).
 
+**Full model repository**: [https://github.com/IlPakoZ/dlrnaberta-dti-prediction](https://github.com/IlPakoZ/dlrnaberta-dti-prediction)
+
 ## Architecture
 
 The model consists of several key components:
@@ -53,6 +55,16 @@ from updated_app import demo
 demo.launch()
 ```
 
+The Gradio interface supports **batch predictions** by allowing multiple SMILES entries. Simply enter each SMILES string on a new line in the drug SMILES input field:
+
+```
+CC(C)CC1=CC=C(C=C1)C(C)C(=O)O
+C1CCCCC1O
+C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2
+```
+
+The model will predict binding affinity for each drug-target pair sequentially. For visualizations, the results will display only the last SMILES entry.
+
 ### Programmatic usage
 
 ```python
@@ -101,7 +113,7 @@ with torch.no_grad():
 ## Model inputs
 
 - **Target sequence**: RNA sequence using nucleotides A, U, G, C (string)
-- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)
+- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string or multiple strings, one per line)
 
 ## Model outputs
 
@@ -196,4 +208,4 @@ This model is released under the MIT License.
   year={2024},
   publisher={Oxford University Press}
 }
-```
+```
````
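The README's one-SMILES-per-line convention can be sketched as a small parsing helper. This is a minimal sketch: `parse_smiles_field` is a hypothetical name, not part of the repository, and unlike the app's plain `split("\n")` it also drops blank lines.

```python
# Hypothetical helper illustrating the one-SMILES-per-line convention the
# README describes; blank lines are dropped, and each entry is stripped of
# surrounding whitespace (as the app does before tokenizing).
def parse_smiles_field(text: str) -> list:
    """Split a drug-SMILES textbox value into one entry per non-empty line."""
    return [line.strip() for line in text.split("\n") if line.strip()]

field = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\nC1CCCCC1O\nC1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
print(parse_smiles_field(field))  # three SMILES entries, in input order
```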
app.py
CHANGED
````diff
@@ -118,42 +118,52 @@ class DrugTargetInteractionApp:
             logger.error(f"Error loading model: {str(e)}")
             return False
 
-    def predict_interaction(self, target_sequence, drug_smiles):
-        """Predict drug-target interaction"""
-        if self.model is None:
-            return "Error: Model not loaded. Please load a model first."
-
-        try:
-            # Tokenize inputs
-            target_inputs = self.target_tokenizer(
-                target_sequence,
-                padding="max_length",
-                truncation=True,
-                max_length=512,
-                return_tensors="pt"
-            ).to(self.device)
-
+    def get_target_and_smiles(self, target_sequence, drug_smiles):
+        # Tokenize inputs
+        target_inputs = self.target_tokenizer(
+            target_sequence,
+            padding="max_length",
+            truncation=True,
+            max_length=512,
+            return_tensors="pt"
+        ).to(self.device)
+
+        all_smiles = []
+        for smiles in drug_smiles:
             drug_inputs = self.drug_tokenizer(
-                drug_smiles,
+                smiles.strip(),
                 padding="max_length",
                 truncation=True,
                 max_length=512,
                 return_tensors="pt"
             ).to(self.device)
+            all_smiles.append(drug_inputs)
+
+        return target_inputs, all_smiles
+
+    def predict_interaction(self, target_sequence, drug_smiles):
+        """Predict drug-target interaction"""
+        if self.model is None:
+            return "Error: Model not loaded. Please load a model first."
+
+        try:
+            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
+            to_return = []
 
             # Make prediction
             self.model.INTERPR_DISABLE_MODE()
-            with torch.no_grad():
-                prediction = self.model(target_inputs, drug_inputs)
-
-            # Unscale if scaler exists
-            if self.model.scaler is not None:
-                prediction = self.model.unscale(prediction)
-
-            prediction_value = prediction.cpu().numpy()[0][0]
-
-            return f"Predicted Binding Affinity: {prediction_value:.4f}"
-
+            for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
+                with torch.no_grad():
+                    prediction = self.model(target_inputs, drug_inputs)
+
+                # Unscale if scaler exists
+                if self.model.scaler is not None:
+                    prediction = self.model.unscale(prediction)
+
+                prediction_value = prediction.cpu().numpy()[0][0]
+                to_return.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
+            return "\n".join(to_return)
+
         except Exception as e:
             logger.error(f"Prediction error: {str(e)}")
             return f"Error during prediction: {str(e)}"
@@ -173,48 +183,34 @@ class DrugTargetInteractionApp:
             return None, None, None, "Error: Model not loaded. Please load a model first."
 
         try:
-            # Tokenize inputs
-            target_inputs = self.target_tokenizer(
-                target_sequence,
-                padding="max_length",
-                truncation=True,
-                max_length=512,
-                return_tensors="pt"
-            ).to(self.device)
-
-            drug_inputs = self.drug_tokenizer(
-                drug_smiles,
-                padding="max_length",
-                truncation=True,
-                max_length=512,
-                return_tensors="pt"
-            ).to(self.device)
-
-            # Enable interpretation mode
+            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
+            to_return = []
+
+            # Make prediction
             self.model.INTERPR_ENABLE_MODE()
+            for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
+                # Make prediction and extract visualization data
+                with torch.no_grad():
+                    prediction = self.model(target_inputs, drug_inputs)
+
+                # Unscale if scaler exists
+                if self.model.scaler is not None:
+                    prediction = self.model.unscale(prediction)
+
+                prediction_value = prediction.cpu().numpy()[0][0]
+
+                # Extract data needed for visualizations
+                presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
+                cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)
+
+                # Get model parameters for scaling
+                w = self.model.model.w.squeeze(1)
+                b = self.model.model.b
+                scaler = self.model.model.scaler
+                to_return.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
+
+            status_msg = "\n".join(to_return)
 
-            # Make prediction and extract visualization data
-            with torch.no_grad():
-                prediction = self.model(target_inputs, drug_inputs)
-
-            # Unscale if scaler exists
-            if self.model.scaler is not None:
-                prediction = self.model.unscale(prediction)
-
-            prediction_value = prediction.cpu().numpy()[0][0]
-
-            # Extract data needed for visualizations
-            presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
-            cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)
-
-            # Get model parameters for scaling
-            w = self.model.model.w.squeeze(1)
-            b = self.model.model.b
-            scaler = self.model.model.scaler
-
-            logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
-            logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")
-
             # Generate visualizations
             try:
                 # 1. Cross-attention heatmap
@@ -318,7 +314,6 @@ class DrugTargetInteractionApp:
                 text="Raw Contribution\nSkipped (pKd ≤ 0)"
             )
 
-            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
             if prediction_value <= 0:
                 status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
             if cross_attention_weights is None:
@@ -344,14 +339,14 @@ def predict_wrapper(target_seq, drug_smiles):
     if not target_seq.strip() or not drug_smiles.strip():
         return "Please provide both target sequence and drug SMILES."
 
-    return app.predict_interaction(target_seq, drug_smiles)
+    return app.predict_interaction(target_seq, drug_smiles.split("\n"))
 
 def visualize_wrapper(target_seq, drug_smiles):
     """Wrapper function for visualization"""
     if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."
 
-    return app.visualize_interaction(target_seq, drug_smiles)
+    return app.visualize_interaction(target_seq, drug_smiles.split("\n"))
 
 def load_model_wrapper(model_path):
     """Wrapper function to load model"""
@@ -390,8 +385,11 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
 
             drug_input = gr.Textbox(
                 label="Drug SMILES",
-                placeholder="Enter SMILES notation…
-
+                placeholder="Enter SMILES notation for one or more drugs.\n"
+                "For multiple SMILES, enter each on a new line:\n"
+                "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\n"
+                "C1CCCCC1O",
+                lines=3
             )
 
             with gr.Row():
@@ -437,12 +435,16 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
         img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
         # Combine prediction result with visualization status
        combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
+        if len(drug_smiles) > 1:
+            combined_status += "\nVisualizations are shown only for the last SMILES entry."
+
         return img1, img2, img3, combined_status
 
     visualize_btn.click(
         fn=visualize_and_update,
         inputs=[target_input, drug_input],
-        outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
+        outputs=[viz_state1, viz_state2, viz_state3, prediction_output],
+        api_name="visualize_and_update"  # Make this API accessible
     )
 
     with gr.Tab("📊 Visualizations"):
@@ -587,7 +589,7 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
         gr.Markdown("""
         ## About this application
 
-        This application implements DLRNA-BERTa, a Dual…
+        This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:
 
         - **Target encoder**: Processes RNA sequences using RNA-BERTa
        - **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
@@ -597,18 +599,22 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
         ### Input requirements:
         - **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
         - **Drug SMILES**: Simplified Molecular Input Line Entry System notation
+        - **Batch prediction**: Enter multiple SMILES strings, one per line, to predict binding affinity for multiple drugs against the same target
 
         ### Model features:
         - Cross-attention for drug-target interaction modeling
         - Dropout for regularization
         - Layer normalization for stable training
         - Interpretability mode for contribution and attention visualization
+        - Support for batch predictions with multiple SMILES entries
 
         ### Usage tips:
         1. Load a trained model using the Model Settings tab (optional)
         2. Enter a RNA sequence and drug SMILES in the Prediction & Analysis tab
-        3. …
-        4. Click "…
+        3. For batch predictions, enter multiple SMILES strings (one per line) in the drug SMILES field
+        4. Click "Predict Interaction" for binding affinity prediction only
+        5. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab
+        6. Note: Visualizations are generated only for the last SMILES entry when using batch mode
 
         For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).
 
````
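The per-SMILES loop this commit adds to `predict_interaction` can be exercised without loading any model by stubbing the scoring call. A sketch under that assumption: `format_predictions` and `predict_fn` are hypothetical stand-ins for the method and the (unscaled) model output, but the report format matches the `predicted pKd` lines the diff introduces.

```python
# Stand-alone sketch of the batch-prediction report added in this commit.
# format_predictions and predict_fn are hypothetical stand-ins for
# DrugTargetInteractionApp.predict_interaction and the model forward pass.
def format_predictions(drug_smiles, predict_fn):
    """Score each SMILES against one target and join the results, one line per drug."""
    to_return = []
    for smiles in drug_smiles:
        value = predict_fn(smiles)  # stands in for model(target_inputs, drug_inputs) + unscale
        to_return.append(f"{smiles} predicted pKd: {value:.4f}")
    return "\n".join(to_return)

report = format_predictions(["C1CCCCC1O", "CCO"], predict_fn=lambda s: 5.0 + 0.1 * len(s))
print(report)
```

As in the committed code, each drug is scored against the same tokenized target, so batch mode scales linearly in the number of SMILES entries.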