IlPakoZ committed on
Commit
c0031a7
·
verified ·
1 Parent(s): 0774611

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -643
app.py DELETED
@@ -1,643 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
5
- from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
6
- from configuration_dlmberta import InteractionModelATTNConfig
7
- from chemberta import ChembertaTokenizer
8
- import json
9
- import os
10
- from pathlib import Path
11
- import logging
12
-
13
- # Import visualization functions
14
- from analysis import plot_crossattention_weights, plot_presum
15
- from PIL import Image, ImageDraw, ImageFont
16
-
17
# Configure root logging at INFO level and create the module-level logger
# used by every function and method in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
20
-
21
def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a placeholder image with a centered gray text message.

    Args:
        width (int): Image width in pixels
        height (int): Image height in pixels
        text (str): Text to display (may contain newlines)
        bg_color (tuple): Background color (R, G, B, A) - (0,0,0,0) for transparent

    Returns:
        PIL.Image: RGBA placeholder image
    """
    # Create image with the requested (by default fully transparent) background
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Prefer Arial; fall back to Pillow's built-in font, then to no font at all.
    # `except Exception` (not a bare `except`) keeps this best-effort while
    # letting KeyboardInterrupt / SystemExit propagate.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except Exception:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    # Get text size for centering
    if font:
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        # Rough estimation if no font is available: ~8 px per character of the
        # longest line and 16 px per line, so multi-line text is not
        # over-estimated by counting the newline characters.
        lines = text.split("\n")
        text_width = max(len(line) for line in lines) * 8
        text_height = 16 * len(lines)

    x = (width - text_width) // 2
    y = (height - text_height) // 2

    # Draw text in opaque gray
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)

    return img
64
-
65
class DrugTargetInteractionApp:
    """Holds the interaction model, its tokenizers, and the inference helpers.

    The instance starts empty; ``load_model`` must succeed before
    ``predict_interaction`` or ``visualize_interaction`` can produce results.
    """

    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers.

        Everything is built in local variables first and only assigned to the
        instance once every component has loaded, so a failure cannot leave
        the app with a model but no tokenizers.

        Args:
            model_path (str): Directory containing the model weights/config,
                the optional ``scaler.config`` and the ``target_tokenizer`` /
                ``drug_tokenizer`` sub-directories.

        Returns:
            bool: True on success, False if any step raised (error is logged).
        """
        try:
            # Load configuration
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Build the drug encoder (ChemBERTa architecture, no pooling layer)
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Load target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Load scaler if it exists
            scaler = None
            if os.path.exists(os.path.join(model_path, "scaler.config")):
                scaler = StdScaler()
                scaler.load(model_path)

            model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler,
            )
            model.to(self.device)
            model.eval()

            # Load tokenizers
            target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            drug_tokenizer = ChembertaTokenizer(vocab_file)

        except Exception as e:
            logger.exception("Error loading model: %s", e)
            return False

        # Publish the fully loaded state in one go
        self.model = model
        self.scaler = scaler
        self.target_tokenizer = target_tokenizer
        self.drug_tokenizer = drug_tokenizer

        logger.info("Model and tokenizers loaded successfully!")
        return True

    def _tokenize(self, tokenizer, text, max_length=512):
        """Tokenize ``text`` padded/truncated to ``max_length`` and move it to the device."""
        return tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)

    def _forward(self, target_inputs, drug_inputs):
        """Run the model without gradients and return the (unscaled) scalar prediction."""
        with torch.no_grad():
            prediction = self.model(target_inputs, drug_inputs)

        # Unscale if scaler exists
        if self.model.scaler is not None:
            prediction = self.model.unscale(prediction)

        return prediction.cpu().numpy()[0][0]

    def _disable_interpretation(self):
        """Switch interpretation mode off, logging (not raising) on failure."""
        try:
            self.model.INTERPR_DISABLE_MODE()
        except Exception as e:
            logger.warning("Could not disable interpretation mode: %s", e)

    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target interaction.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation
            max_length (int): Padding/truncation length used for both tokenizers

        Returns:
            str: Human-readable prediction, or an error message.
        """
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."

        try:
            # Tokenize inputs (honouring the caller-supplied max_length)
            target_inputs = self._tokenize(self.target_tokenizer, target_sequence, max_length)
            drug_inputs = self._tokenize(self.drug_tokenizer, drug_smiles, max_length)

            # Make prediction
            self.model.INTERPR_DISABLE_MODE()
            prediction_value = self._forward(target_inputs, drug_inputs)

            return f"Predicted Binding Affinity: {prediction_value:.4f}"

        except Exception as e:
            logger.exception("Prediction error: %s", e)
            return f"Error during prediction: {str(e)}"

    def _plot_cross_attention(self, target_inputs, drug_inputs, cross_attention_weights):
        """Render the cross-attention heatmap; return None if it cannot be produced."""
        if cross_attention_weights is None:
            logger.warning("Cross-attention weights are None")
            return None

        try:
            logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
            logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")

            # First batch element, first head
            cross_attn_matrix = cross_attention_weights[0, 0]
            if cross_attn_matrix is None:
                logger.warning("Could not extract valid cross-attention matrix")
                return None

            logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
            logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
            logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")

            return plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer,
            )
        except Exception as e:
            logger.error(f"Cross-attention visualization error: {str(e)}")
            return None

    def _plot_contribution(self, target_inputs, presum_values, scaler, w, b, raw_affinities):
        """Render a per-token contribution plot; return None if it cannot be produced."""
        label = "Raw" if raw_affinities else "Normalized"
        if presum_values is None:
            logger.warning(f"Cannot generate {label.lower()} contribution visualization: presum values are None")
            return None

        try:
            return plot_presum(
                target_inputs,
                presum_values.detach(),
                scaler,
                w.detach(),
                b.detach(),
                self.target_tokenizer,
                raw_affinities=raw_affinities,
            )
        except Exception as e:
            logger.error(f"{label} contribution visualization error: {str(e)}")
            return None

    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction

        The normalized contribution plot is always attempted; the raw
        contribution plot is only attempted when the predicted value is > 0.
        Any plot that could not be produced is replaced by a placeholder image.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation

        Returns:
            tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message).
                The three images are None when the model is not loaded or the
                whole visualization failed.
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."

        try:
            # Tokenize inputs
            target_inputs = self._tokenize(self.target_tokenizer, target_sequence)
            drug_inputs = self._tokenize(self.drug_tokenizer, drug_smiles)

            # Interpretation mode is enabled only for the duration of this
            # block; `finally` guarantees it is switched off again on every
            # exit path.
            self.model.INTERPR_ENABLE_MODE()
            try:
                prediction_value = self._forward(target_inputs, drug_inputs)

                # Extract data needed for visualizations
                presum_values = self.model.model.presum_layer  # original note: shape (1, seq_len)
                cross_attention_weights = self.model.model.crossattention_weights  # original note: (batch, heads, seq_len, seq_len)

                # Get model parameters for scaling
                w = self.model.model.w.squeeze(1)
                b = self.model.model.b
                scaler = self.model.model.scaler

                logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
                logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")

                # 1. Cross-attention heatmap
                cross_attention_img = self._plot_cross_attention(
                    target_inputs, drug_inputs, cross_attention_weights
                )

                # 2. Normalized contribution visualization (always attempted)
                normalized_img = self._plot_contribution(
                    target_inputs, presum_values, scaler, w, b, raw_affinities=False
                )

                # 3. Raw contribution visualization (only if pKd > 0)
                raw_img = None
                if prediction_value > 0:
                    raw_img = self._plot_contribution(
                        target_inputs, presum_values, scaler, w, b, raw_affinities=True
                    )
                else:
                    logger.info("Skipping raw contribution visualization as pKd <= 0")
            finally:
                self._disable_interpretation()

            # Create placeholder images if generation failed or was skipped
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None and prediction_value > 0:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nFailed to generate"
                )
            elif raw_img is None:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nSkipped (pKd ≤ 0)"
                )

            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg

        except Exception as e:
            logger.exception("Visualization error: %s", e)
            return None, None, None, f"Error during visualization: {str(e)}"
337
-
338
-
339
# Create the single module-level application instance that the Gradio
# callbacks below call into. No model is loaded at this point.
app = DrugTargetInteractionApp()
341
-
342
def predict_wrapper(target_seq, drug_smiles):
    """Validate the two text inputs and run a prediction for the Gradio interface.

    Args:
        target_seq (str | None): RNA sequence entered by the user
        drug_smiles (str | None): Drug SMILES entered by the user

    Returns:
        str: Prediction text from ``app.predict_interaction``, or a prompt to
        fill in both fields when either is missing or blank.
    """
    # Treat None like an empty string so a missing value cannot crash on `.strip()`.
    if not (target_seq or "").strip() or not (drug_smiles or "").strip():
        return "Please provide both target sequence and drug SMILES."

    return app.predict_interaction(target_seq, drug_smiles)
348
-
349
def visualize_wrapper(target_seq, drug_smiles):
    """Validate the two text inputs and generate the visualizations.

    Args:
        target_seq (str | None): RNA sequence entered by the user
        drug_smiles (str | None): Drug SMILES entered by the user

    Returns:
        tuple: The 4-tuple from ``app.visualize_interaction``, or
        ``(None, None, None, message)`` when either input is missing or blank.
    """
    # Treat None like an empty string so a missing value cannot crash on `.strip()`.
    if not (target_seq or "").strip() or not (drug_smiles or "").strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."

    return app.visualize_interaction(target_seq, drug_smiles)
355
-
356
def load_model_wrapper(model_path):
    """Load the model from ``model_path`` and report the outcome as a status string."""
    loaded = app.load_model(model_path)
    return (
        "Model loaded successfully!"
        if loaded
        else "Failed to load model. Check the path and files."
    )
362
-
363
# Create Gradio interface: a prediction tab, a visualization tab, model
# settings, and two static documentation tabs.
# Markdown bodies are kept flush-left inside their string literals so they
# render the same whether or not the Markdown component dedents its input.
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
    <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
    🧬 Drug-Target Interaction Predictor
    </h1>
    <p style="font-size: 1.2em; color: #666;">
    Predict binding affinity between drugs and target RNA sequences using deep learning
    </p>
    </div>
    """)

    # Create state variables to share images between tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=4,
                    max_lines=6
                )

                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
                    lines=2
                )

                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")

            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )

        # Example inputs
        gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")

        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False
        )

        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input],
            outputs=prediction_output
        )

        def visualize_and_update(target_seq, drug_smiles):
            """Generate visualizations and update both status and state.

            Returns the three images (for the state variables) followed by the
            status text for the prediction box.
            """
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
            # visualize_wrapper returns no images at all when the input was
            # missing or the visualization failed; in that case report only
            # the error instead of claiming the analysis completed.
            if img1 is None and img2 is None and img3 is None:
                return img1, img2, img3, status
            # Combine prediction result with visualization status
            combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            return img1, img2, img3, combined_status

        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
        <h2 style="color: #2E86AB;">🔬 Interaction Analysis & Visualizations</h2>
        <p style="font-size: 1.1em; color: #666;">
        Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab
        </p>
        </div>
        """)

        # Visualization outputs - Large and vertically aligned
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
        )

        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )

        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )

        # Update visualization images when state changes
        viz_state1.change(
            fn=lambda x: x,
            inputs=viz_state1,
            outputs=viz_image1
        )

        viz_state2.change(
            fn=lambda x: x,
            inputs=viz_state2,
            outputs=viz_image2
        )

        viz_state3.change(
            fn=lambda x: x,
            inputs=viz_state3,
            outputs=viz_image3
        )

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")

        model_path_input = gr.Textbox(
            label="Model Path",
            value="./",
            placeholder="Path to model directory"
        )

        load_model_btn = gr.Button("📥 Load Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value="No model loaded"
        )

        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=model_path_input,
            outputs=model_status
        )

    with gr.Tab("📊 Dataset"):
        # NOTE(review): the test-set figures below are inconsistent:
        # 2,991 positive + 2,538 negative = 5,529, while the stated total (and
        # the sum of the per-type pair counts) is 5,534. Confirm against the data.
        gr.Markdown("""
## Training and Test Datasets

### Fine-tuning Dataset (Training)

The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:
- **759 unique compounds** (SMILES representations)
- **294 unique RNA sequences**
- Dissociation constants (pKd values) for binding affinity prediction

**RNA Sequence Distribution by Type:**

| RNA Sequence Type | Number of Interactions |
|-------------------|------------------------|
| Aptamers | 520 |
| Ribosomal | 295 |
| Viral RNAs | 281 |
| miRNAs | 146 |
| Riboswitches | 100 |
| Repeats | 97 |
| **Total** | **1,439** |

### External Evaluation Dataset (Test)

Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:
- **2,991 positive interactions**
- **2,538 negative interactions**

**Test Dataset Composition:**
- **1,617 aptamer pairs** (5 unique RNA sequences)
- **1,828 viral RNA pairs** (6 unique RNA sequences)
- **1,459 riboswitch pairs** (5 unique RNA sequences)
- **630 miRNA pairs** (3 unique RNA sequences)

### Dataset Downloads

- [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
- [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

### Citation

Original datasets published by:
**Krishnan et al.** - Available on the RSAPred website in PDF format.

*Reference:*
```bibtex
@article{krishnan2024reliable,
title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
journal={Briefings in Bioinformatics},
volume={25},
number={2},
pages={bbae002},
year={2024},
publisher={Oxford University Press}
}
```
""")
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## About this application

This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:

- **Target encoder**: Processes RNA sequences using RNA-BERTa
- **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
- **Cross-attention mechanism**: Captures interactions between drugs and targets
- **Regression head**: Predicts binding affinity scores (pKd values)

### Input requirements:
- **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation

### Model features:
- Cross-attention for drug-target interaction modeling
- Dropout for regularization
- Layer normalization for stable training
- Interpretability mode for contribution and attention visualization

### Usage tips:
1. Load a trained model using the Model Settings tab (optional)
2. Enter a RNA sequence and drug SMILES in the Prediction & Analysis tab
3. Click "Predict Interaction" for binding affinity prediction only
4. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab

For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).

### Visualization features:
- **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
- **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
- **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

### Performance metrics:
- Training on diverse drug-target interaction datasets
- Evaluated using RMSE, Pearson correlation, and Concordance Index
- Optimized for both predictive accuracy and interpretability

### GitHub repository:
- The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

### Contribution:
- Special thanks to Umut Onur Özcan for help in developing this space:)
""")
632
# Script entry point: optionally pre-load the model, then serve the UI.
if __name__ == "__main__":
    # Only attempt a startup load when a config file sits next to the app.
    startup_config = "./config.json"
    if os.path.exists(startup_config):
        app.load_model("./")

    # Listen on all interfaces on port 7860, without a public share link,
    # and surface errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)