liuliu2333 committed on
Commit
fc481db
Β·
1 Parent(s): 9201562

Deploy DeepMiRT Gradio demo with model code

Browse files
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: Deepmirt
3
- emoji: πŸš€
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: DeepMiRT
3
+ emoji: 🧬
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.23.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: miRNA target prediction with RNA foundation models
12
  ---
 
 
app.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ DeepMiRT Web Demo β€” Gradio interface for miRNA-target interaction prediction.
4
+
5
+ Run locally:
6
+ python app.py
7
+
8
+ Deploy on Hugging Face Spaces:
9
+ Set sdk: gradio in the Space README.md metadata.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import re
16
+ import tempfile
17
+ from pathlib import Path
18
+
19
+ import gradio as gr
20
+ import numpy as np
21
+ import pandas as pd
22
+ import torch
23
+
24
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Global model (loaded once at startup)
29
+ # ---------------------------------------------------------------------------
30
+ _model = None
31
+ _alphabet = None
32
+ _config = None
33
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+
36
def _load_model():
    """Load model from Hugging Face Hub (cached after first download).

    Populates the module-level globals ``_model``, ``_alphabet`` and ``_config``.
    Safe to call repeatedly: returns immediately once ``_model`` is set, so it
    can be invoked lazily from every Gradio callback.
    """
    global _model, _alphabet, _config

    if _model is not None:
        return

    # Heavy imports are deferred so the app module imports quickly and the
    # download/model cost is only paid on the first prediction request.
    import fm
    import torch
    from huggingface_hub import hf_hub_download

    from deepmirt.evaluation.predict import load_model_from_checkpoint

    repo_id = "liuliu2333/deepmirt"
    # hf_hub_download caches locally, so repeated process restarts reuse the files.
    ckpt_path = hf_hub_download(repo_id=repo_id, filename="epoch=27-val_auroc=0.9612.ckpt")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")

    logger.info("Loading model...")
    _model, _config = load_model_from_checkpoint(ckpt_path, config_path, device=_device)
    # Only the alphabet (tokenizer) is needed here; the RNA-FM weights returned
    # alongside it are discarded by the caller of this tuple.
    _, _alphabet = fm.pretrained.rna_fm_t12()
    logger.info("Model loaded successfully.")
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Validation helpers
61
+ # ---------------------------------------------------------------------------
62
+ _VALID_BASES = set("AUGC")
63
+
64
+
65
def _validate_seq(seq: str, name: str, min_len: int = 1, max_len: int = 200) -> str:
    """Validate and clean an RNA/DNA sequence.

    Strips whitespace, upper-cases, and converts DNA notation (T) to RNA (U).

    Args:
        seq: raw user-supplied sequence.
        name: label used in error messages (e.g. "miRNA").
        min_len: minimum allowed length in nucleotides.
        max_len: maximum allowed length in nucleotides.

    Returns:
        The cleaned RNA sequence (A/U/G/C only).

    Raises:
        gr.Error: if the sequence is empty, out of the allowed length range,
            or contains characters other than A/U/G/C/T.
    """
    seq = seq.strip().upper().replace("T", "U")
    if not seq:
        raise gr.Error(f"{name} sequence is empty.")
    if len(seq) < min_len or len(seq) > max_len:
        # Bug fix: the message previously hardcoded "74,582" as the upper bound
        # instead of reporting the actual max_len constraint.
        raise gr.Error(f"{name} must be {min_len}-{max_len} nt, got {len(seq)} nt.")
    invalid = set(seq) - _VALID_BASES
    if invalid:
        raise gr.Error(f"{name} contains invalid characters: {invalid}. Only A/U/G/C/T allowed.")
    return seq
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # Prediction logic
80
+ # ---------------------------------------------------------------------------
81
def _predict_pair(mirna_seq: str, target_seq: str) -> np.ndarray:
    """Run model inference on a single pair.

    Args:
        mirna_seq: cleaned miRNA sequence (RNA notation).
        target_seq: cleaned target sequence (RNA notation).

    Returns:
        np.ndarray of interaction probabilities; shape (1,) for a single pair.
    """
    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Lazily ensures the global _model/_alphabet are initialized.
    _load_model()

    batch_converter = _alphabet.get_batch_converter()
    padding_idx = _alphabet.padding_idx

    # batch_converter returns (labels, strs, tokens); only the token tensor is
    # used. Tokens include BOS/EOS added by RNA-FM's converter.
    _, _, m_tok = batch_converter([("m", mirna_seq)])
    _, _, t_tok = batch_converter([("t", target_seq)])

    # With a single sequence, pad_sequence just adds the batch dimension; it is
    # used here for symmetry with the batched code path.
    mirna_padded = pad_sequence([m_tok[0]], batch_first=True, padding_value=padding_idx)
    target_stacked = torch.stack([t_tok[0]])

    # miRNA mask marks non-padding positions; the target mask is all ones
    # (single unpadded target sequence).
    attn_mask_mirna = (mirna_padded != padding_idx).long().to(_device)
    attn_mask_target = torch.ones_like(target_stacked, dtype=torch.long).to(_device)
    mirna_padded = mirna_padded.to(_device)
    target_stacked = target_stacked.to(_device)

    with torch.no_grad():
        logits = _model.model(mirna_padded, target_stacked, attn_mask_mirna, attn_mask_target)
        # Sigmoid converts the raw logit to an interaction probability.
        prob = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
    return prob
106
+
107
+
108
def predict_single(mirna_seq: str, target_seq: str):
    """Gradio callback for single prediction.

    Validates both sequences, scores the pair, and returns an HTML result
    card together with a JSON-serializable details dict.
    """
    mirna_rna = _validate_seq(mirna_seq, "miRNA", min_len=15, max_len=30)
    target_rna = _validate_seq(target_seq, "Target", min_len=20, max_len=50)

    p = float(_predict_pair(mirna_rna, target_rna)[0])
    is_hit = p >= 0.5
    label = "INTERACTION" if is_hit else "NO INTERACTION"
    color = "#2ecc71" if is_hit else "#e74c3c"

    details = {
        "probability": round(p, 6),
        "prediction": label,
        "threshold": 0.5,
        "mirna_length": len(mirna_rna),
        "target_length": len(target_rna),
    }
    html = (
        f"<div style='text-align:center;padding:20px;'>"
        f"<span style='font-size:48px;font-weight:bold;color:{color};'>{p:.4f}</span><br>"
        f"<span style='font-size:20px;color:{color};'>{label}</span></div>"
    )
    return html, details
129
+
130
+
131
def predict_batch(file):
    """Gradio callback for batch prediction.

    Reads a CSV whose column names contain 'mirna' and 'target', scores every
    pair in mini-batches, and returns (path to results CSV, 20-row preview).

    Args:
        file: uploaded file object from gr.File (uses ``file.name`` as path).

    Returns:
        Tuple of (results CSV path as str, pandas DataFrame of first 20 rows).

    Raises:
        gr.Error: if no file was uploaded, required columns are missing, the
            CSV has no rows, or a row contains an empty/invalid sequence.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file.")

    _load_model()

    df = pd.read_csv(file.name)

    # Locate the miRNA / target columns by case-insensitive substring match.
    mirna_col = None
    target_col = None
    for col in df.columns:
        cl = col.lower().strip()
        if "mirna" in cl:
            mirna_col = col
        elif "target" in cl:
            target_col = col

    if mirna_col is None or target_col is None:
        raise gr.Error(
            "CSV must contain a column with 'mirna' and a column with 'target' in the name. "
            f"Found columns: {list(df.columns)}"
        )

    if df.empty:
        # Guard: np.concatenate below would fail on an empty list of batches.
        raise gr.Error("CSV contains no data rows.")

    mirna_seqs = df[mirna_col].astype(str).tolist()
    target_seqs = df[target_col].astype(str).tolist()

    # Validate and convert DNA (T) to RNA (U) notation.
    cleaned_mirna = []
    cleaned_target = []
    for i, (m, t) in enumerate(zip(mirna_seqs, target_seqs)):
        m = m.strip().upper().replace("T", "U")
        t = t.strip().upper().replace("T", "U")
        if not m or not t:
            raise gr.Error(f"Row {i}: empty sequence.")
        invalid_m = set(m) - _VALID_BASES
        invalid_t = set(t) - _VALID_BASES
        if invalid_m or invalid_t:
            raise gr.Error(f"Row {i}: invalid characters in sequences.")
        cleaned_mirna.append(m)
        cleaned_target.append(t)

    # Batch inference
    import torch
    from torch.nn.utils.rnn import pad_sequence

    batch_converter = _alphabet.get_batch_converter()
    padding_idx = _alphabet.padding_idx
    all_probs = []
    batch_size = 128

    with torch.no_grad():
        for start in range(0, len(cleaned_mirna), batch_size):
            batch_m = cleaned_mirna[start : start + batch_size]
            batch_t = cleaned_target[start : start + batch_size]

            m_toks = []
            t_toks = []
            for ms, ts in zip(batch_m, batch_t):
                _, _, mt = batch_converter([("m", ms)])
                _, _, tt = batch_converter([("t", ts)])
                m_toks.append(mt[0])
                t_toks.append(tt[0])

            mirna_padded = pad_sequence(m_toks, batch_first=True, padding_value=padding_idx)
            # Bug fix: user-supplied targets may have different lengths within a
            # batch, in which case torch.stack raised. Pad targets like the
            # miRNAs and derive the mask from padding positions. For the
            # original fixed-length case this mask is identical to all-ones,
            # since real RNA-FM tokens never equal padding_idx.
            target_padded = pad_sequence(t_toks, batch_first=True, padding_value=padding_idx)
            attn_mask_mirna = (mirna_padded != padding_idx).long().to(_device)
            attn_mask_target = (target_padded != padding_idx).long().to(_device)

            logits = _model.model(
                mirna_padded.to(_device),
                target_padded.to(_device),
                attn_mask_mirna,
                attn_mask_target,
            )
            probs = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
            all_probs.append(probs)

    all_probs = np.concatenate(all_probs)
    df["probability"] = all_probs
    df["prediction"] = (all_probs >= 0.5).astype(int)

    # Save to temp file for download
    out_path = Path(tempfile.mkdtemp()) / "deepmirt_predictions.csv"
    df.to_csv(str(out_path), index=False)
    return str(out_path), df.head(20)
215
+
216
+
217
+ # ---------------------------------------------------------------------------
218
+ # Examples
219
+ # ---------------------------------------------------------------------------
220
+ EXAMPLES = [
221
+ # [miRNA, target_40nt] - real miRNA-target pairs
222
+ ["UGAGGUAGUAGGUUGUAUAGUU", "ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU"], # let-7a / lin-41
223
+ ["UAAAGUGCUUAUAGUGCAGGUAG", "GCAGCAUUGUACAGGGCUAUCAGAAACUAUUGACACUAAAA"], # miR-20a / E2F1
224
+ ["UAGCAGCACGUAAAUAUUGGCG", "GCAAUGUUUUCCACAGUGCUUACACAGAAAUAGCAACUUUA"], # miR-16 / BCL2
225
+ ["CAUCAAAGUGGAGGCCCUCUCU", "AAUGCUUCUAAAUUGAAUCCAAACUGCAGUUUAUUAGUGGU"], # miR-198 (negative)
226
+ ["UGGAAUGUAAAGAAGUAUGUAU", "UCGAAUCCAUGCAAAACAGCUUGAUUUGUUAGUACACGAAU"], # miR-1 / HAND2
227
+ ]
228
+
229
+
230
+ # ---------------------------------------------------------------------------
231
+ # Gradio UI
232
+ # ---------------------------------------------------------------------------
233
def build_demo():
    """Construct and return the Gradio Blocks UI (single, batch, and about tabs)."""
    with gr.Blocks(
        title="DeepMiRT: miRNA Target Prediction",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            """
            # DeepMiRT: miRNA Target Prediction with RNA Foundation Models

            Predict miRNA-target interactions using RNA-FM embeddings and cross-attention.
            Ranked **#1** on eCLIP benchmarks (AUROC 0.75) and achieves **AUROC 0.96** on our comprehensive test set.

            **Paper:** *coming soon* | **GitHub:** [DeepMiRT](https://github.com/zichengll/DeepMiRT) | **Model:** [Hugging Face](https://huggingface.co/liuliu2333/deepmirt)
            """
        )

        with gr.Tab("Single Prediction"):
            with gr.Row():
                with gr.Column():
                    mirna_input = gr.Textbox(
                        label="miRNA Sequence",
                        placeholder="e.g., UGAGGUAGUAGGUUGUAUAGUU",
                        # Fix: previously said "18-25 nt", contradicting the
                        # enforced validation range (15-30 nt in predict_single).
                        info="15-30 nt. DNA (T) or RNA (U) format accepted.",
                    )
                    target_input = gr.Textbox(
                        label="Target Sequence",
                        placeholder="e.g., ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU",
                        # Fix: state the enforced 20-50 nt range, not only the
                        # recommended length.
                        info="20-50 nt (40 nt recommended). DNA (T) or RNA (U) format accepted.",
                    )
                    predict_btn = gr.Button("Predict", variant="primary")

                with gr.Column():
                    result_html = gr.HTML(label="Prediction Result")
                    result_json = gr.JSON(label="Details")

            predict_btn.click(
                predict_single,
                inputs=[mirna_input, target_input],
                outputs=[result_html, result_json],
            )

            gr.Examples(
                examples=EXAMPLES,
                inputs=[mirna_input, target_input],
                outputs=[result_html, result_json],
                fn=predict_single,
                cache_examples=False,
            )

        with gr.Tab("Batch Prediction"):
            gr.Markdown(
                """
                Upload a CSV file with columns containing **mirna** and **target** in the column names.

                Example format:
                | mirna_seq | target_seq |
                |-----------|------------|
                | UGAGGUAGUAGGUUGUAUAGUU | ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU |
                """
            )
            csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
            batch_btn = gr.Button("Run Batch Prediction", variant="primary")
            csv_output = gr.File(label="Download Results")
            preview = gr.Dataframe(label="Preview (first 20 rows)")

            batch_btn.click(
                predict_batch,
                inputs=[csv_input],
                outputs=[csv_output, preview],
            )

        with gr.Tab("About"):
            gr.Markdown(
                """
                ## Model Architecture

                DeepMiRT uses a **shared RNA-FM encoder** (12-layer Transformer, pre-trained on 23M non-coding RNAs)
                to embed both miRNA and target sequences into the same representation space.
                A **cross-attention module** (2 layers, 8 heads) allows the target to attend to the miRNA,
                capturing interaction patterns. The attended representations are pooled and classified
                by an **MLP head** (640 → 256 → 64 → 1).

                ```
                miRNA  → [RNA-FM Encoder] → miRNA embedding ─────────┐
                                                                     ↓
                Target → [RNA-FM Encoder] → target embedding → [Cross-Attention] → Pool → [MLP] → probability
                ```

                ## Training

                - **Data:** miRNA-target interactions from multiple databases and literature mining
                - **Two-phase training:** Phase 1 (frozen backbone) → Phase 2 (unfreeze top 3 RNA-FM layers)
                - **Hardware:** 2× NVIDIA L20 GPUs, mixed-precision (fp16)
                - **Best checkpoint:** epoch 27, val AUROC = 0.9612

                ## Performance

                | Benchmark | AUROC | Rank |
                |-----------|-------|------|
                | miRBench eCLIP (Klimentova 2022) | 0.7511 | #1/12 |
                | miRBench eCLIP (Manakov 2022) | 0.7543 | #1/12 |
                | miRBench CLASH (Hejret 2023) | 0.6952 | #5/12 |
                | Our test set (813K samples, 16 methods) | 0.9606 | #1/16 |

                ## Citation

                If you use DeepMiRT in your research, please cite:
                ```
                @software{liu2026deepmirt,
                  title={DeepMiRT: miRNA Target Prediction with RNA Foundation Models},
                  author={Liu, Zicheng},
                  year={2026},
                  url={https://github.com/zichengll/DeepMiRT}
                }
                ```

                ## License

                MIT License. See [LICENSE](https://github.com/zichengll/DeepMiRT/blob/main/LICENSE).
                """
            )

    return demo
356
+
357
+
358
+ if __name__ == "__main__":
359
+ demo = build_demo()
360
+ demo.launch()
deepmirt/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """DeepMiRT: miRNA target prediction using RNA foundation models and cross-attention."""
2
+
3
+ __version__ = "1.0.0"
4
+
5
+ from deepmirt.predict import predict as predict
deepmirt/data_module/__init__.py ADDED
File without changes
deepmirt/data_module/datamodule.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ miRNA-Target PyTorch Lightning DataModule
4
+
5
+ [Lightning DataModule Lifecycle]
6
+ Lightning DataModule encapsulates data loading logic into a reusable module.
7
+ Its lifecycle is as follows:
8
+
9
+ 1. prepare_data() β€” download data (runs only on main process; not needed in this project)
10
+ 2. setup(stage) β€” create Dataset instances (runs on every process)
11
+ - stage='fit' β†’ create train_dataset + val_dataset
12
+ - stage='test' β†’ create test_dataset
13
+ - stage='predict' β†’ create predict_dataset
14
+ 3. train_dataloader() β€” return training DataLoader
15
+ 4. val_dataloader() β€” return validation DataLoader
16
+ 5. test_dataloader() β€” return test DataLoader
17
+
18
+ [Why use DataModule instead of manually creating DataLoaders?]
19
+ - Centralizes all data-related logic (paths, batch size, tokenizer, data splits)
20
+ - Lightning Trainer automatically calls the correct methods, reducing boilerplate
21
+ - Makes it easy to reuse the same data configuration across different experiments
22
+
23
+ [collate_fn Explained β€” The Core Difficulty of This Module]
24
+ Since miRNA sequence lengths are variable (15-30nt β†’ 17-32 tokens),
25
+ samples in the same batch may have mirna_tokens of different lengths.
26
+ PyTorch's default collate cannot stack variable-length tensors,
27
+ so we need a custom collate_fn to:
28
+ 1. Find the longest miRNA sequence in the batch
29
+ 2. Pad all miRNA sequences to the same length
30
+ 3. Generate an attention mask indicating which positions are real tokens vs. padding
31
+
32
+ Target sequences are fixed at 40nt (β†’ 42 tokens) and do not require additional padding.
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ import os
38
+
39
+ import fm
40
+ import pytorch_lightning as pl
41
+ import torch
42
+ from torch.nn.utils.rnn import pad_sequence
43
+ from torch.utils.data import DataLoader
44
+
45
+ from deepmirt.data_module.dataset import MiRNATargetDataset
46
+
47
+
48
class MiRNATargetDataModule(pl.LightningDataModule):
    """
    Lightning DataModule for miRNA-target pairs.

    [Responsibilities]
    - Manage creation and DataLoader configuration for train / val / test datasets
    - Provide a custom collate_fn to handle variable-length miRNA sequence padding
    - Encapsulate RNA-FM alphabet loading to avoid redundant initialization in multiple places
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int = 128,
        num_workers: int = 8,
        pin_memory: bool = True,
    ):
        """
        Initialize the DataModule.

        Args:
            data_dir (str): path to the directory containing train.csv / val.csv / test.csv
            batch_size (int): number of samples per batch, default 128
            num_workers (int): number of DataLoader worker processes, default 8
                # Design decision: num_workers controls data prefetching parallelism
                # - 0 = load in main process (for debugging, slow but easy to troubleshoot)
                # - 8 = 8 subprocesses load in parallel (for training, fully utilize multi-core CPU)
                # - Rule of thumb: set to half of CPU cores or GPU count x 4
                # - Too many will cause memory overhead and process switching overhead
            pin_memory (bool): whether to pin data to page-locked memory, default True
                # Design decision: pin_memory accelerates CPU→GPU data transfer
                # - True: data is first copied to pinned memory, then transferred to GPU via DMA
                #   Eliminates one memory copy, improving throughput by ~2x
                # - False: data is in pageable memory and must be copied to pinned memory before transfer
                # - Only meaningful when using GPU; set to False for CPU training
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # Dataset instances, created in setup()
        self.train_dataset: MiRNATargetDataset | None = None
        self.val_dataset: MiRNATargetDataset | None = None
        self.test_dataset: MiRNATargetDataset | None = None

        # Load RNA-FM alphabet in the main process (before DDP fork)
        # This way the alphabet is loaded only once, avoiding redundant full model loading on each DDP rank
        # NOTE(review): fm.pretrained.rna_fm_t12() still builds the full model
        # once here just to obtain the alphabet; presumably acceptable at
        # startup — confirm if __init__ latency matters.
        _model, alphabet = fm.pretrained.rna_fm_t12()
        del _model  # Free model weights, keep only the alphabet (tokenizer)
        self._alphabet = alphabet
        self._padding_idx = alphabet.padding_idx  # padding_idx = 1

    def setup(self, stage: str | None = None) -> None:
        """
        Create Dataset instances.

        Lightning automatically calls this method before training/validation/testing begins.
        Each process (including multi-GPU DDP scenarios) calls setup() independently.

        Args:
            stage: 'fit' (train+val), 'test', 'predict', or None (all)

        NOTE(review): stage == 'predict' is documented above but no predict
        dataset is created here — confirm whether predict support is intended.
        """
        # alphabet was already loaded in __init__() (before DDP fork, loaded only once)
        alphabet = self._alphabet

        if stage == "fit" or stage is None:
            self.train_dataset = MiRNATargetDataset(
                os.path.join(self.data_dir, "train.csv"), alphabet
            )
            self.val_dataset = MiRNATargetDataset(
                os.path.join(self.data_dir, "val.csv"), alphabet
            )

        if stage == "test" or stage is None:
            self.test_dataset = MiRNATargetDataset(
                os.path.join(self.data_dir, "test.csv"), alphabet
            )

    def train_dataloader(self) -> DataLoader:
        """Return the training DataLoader (shuffle=True to randomize data order)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self._collate_fn,
            # drop_last=True keeps every training batch full-sized.
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation DataLoader (shuffle=False to preserve order for reproducible evaluation)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self._collate_fn,
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test DataLoader."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self._collate_fn,
        )

    def _collate_fn(self, batch: list[dict]) -> dict:
        """
        Custom batch collation function — handles padding of variable-length miRNA sequences.

        [Why is a custom collate_fn needed?]
        PyTorch's default collate_fn attempts to stack all sample tensors.
        But miRNA sequence lengths are variable (15-30nt → 17-32 tokens), and direct stacking fails:
            RuntimeError: stack expects each tensor to be equal size

        [Why does miRNA need padding but target does not?]
        - miRNA has variable length: 15-30 nucleotides → 17-32 tokens after adding BOS+EOS
          A single batch may contain lengths of both 17 and 32, which must be aligned
        - Target has fixed length: all samples are 40 nucleotides → 42 tokens
          Naturally aligned, no padding needed

        [Role of attention_mask]
        - Tells the model which positions are real tokens (1) and which are padding (0)
        - The Transformer's self-attention uses the mask to block padding positions
        - Prevents padding tokens from participating in attention computation, avoiding noise

        # Design decision: use pad_sequence instead of manual loop padding
        # pad_sequence is a PyTorch built-in utility, optimized in C++, faster than Python loops
        # It automatically finds the maximum length and pads shorter sequences with the specified value

        Args:
            batch: list of dicts, each dict from MiRNATargetDataset.__getitem__

        Returns:
            dict: containing the following key-value pairs:
                - 'mirna_tokens': (batch_size, max_mirna_len) LongTensor
                - 'target_tokens': (batch_size, 42) LongTensor
                - 'labels': (batch_size,) float32 Tensor
                - 'attention_mask_mirna': (batch_size, max_mirna_len) LongTensor
                - 'attention_mask_target': (batch_size, 42) LongTensor
        """
        # ── 1. Collect individual fields ──
        mirna_list = [sample["mirna_tokens"] for sample in batch]
        target_list = [sample["target_tokens"] for sample in batch]
        label_list = [sample["label"] for sample in batch]

        # ── 2. Pad miRNA sequences ──
        # pad_sequence converts list of 1D tensors → 2D tensor (batch, max_len)
        # batch_first=True ensures the batch dimension comes first
        # padding_value=1 is RNA-FM's <pad> token ID
        mirna_padded = pad_sequence(
            mirna_list, batch_first=True, padding_value=self._padding_idx
        )

        # ── 3. Stack target sequences (fixed 42 tokens, no padding needed) ──
        target_stacked = torch.stack(target_list)

        # ── 4. Stack labels ──
        labels = torch.stack(label_list)

        # ── 5. Generate attention masks ──
        # miRNA mask: non-padding positions = 1, padding positions = 0
        attention_mask_mirna = (mirna_padded != self._padding_idx).long()

        # target mask: all positions are real tokens, so all 1s
        # Because target is fixed at 40nt with no padding, every position is valid
        attention_mask_target = torch.ones_like(target_stacked, dtype=torch.long)

        # ── 6. Collect metadata (for stratified analysis during evaluation) ──
        # Each metadata field is collected as list[str], kept on CPU
        metadata_keys = batch[0].get("metadata", {}).keys()
        metadata = {
            key: [sample["metadata"][key] for sample in batch]
            for key in metadata_keys
        } if metadata_keys else {}

        return {
            "mirna_tokens": mirna_padded,                    # (B, max_mirna_len)
            "target_tokens": target_stacked,                 # (B, 42)
            "labels": labels,                                # (B,)
            "attention_mask_mirna": attention_mask_mirna,    # (B, max_mirna_len)
            "attention_mask_target": attention_mask_target,  # (B, 42)
            "metadata": metadata,                            # dict[str, list[str]]
        }
deepmirt/data_module/dataset.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ miRNA-Target Pair Dataset β€” PyTorch Dataset Implementation
4
+
5
+ [Data Flow ASCII Diagram]
6
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
7
+ β”‚ MiRNATargetDataset Data Flow β”‚
8
+ β”‚ β”‚
9
+ β”‚ CSV file (train.csv / val.csv / test.csv) β”‚
10
+ β”‚ β”‚ β”‚
11
+ β”‚ β–Ό β”‚
12
+ β”‚ pd.read_csv() ─→ DataFrame (loaded entirely into memory) β”‚
13
+ β”‚ β”‚ β”‚
14
+ β”‚ β–Ό β”‚
15
+ β”‚ __getitem__(idx) ─→ retrieve row idx β”‚
16
+ β”‚ β”‚ β”‚
17
+ β”‚ β”œβ”€β†’ mirna_seq: "ATCGATCG" β”‚
18
+ β”‚ β”‚ β”‚ β”‚
19
+ β”‚ β”‚ β–Ό β”‚
20
+ β”‚ β”‚ dna_to_rna() ─→ "AUCGAUCG" (Tβ†’U conversion) β”‚
21
+ β”‚ β”‚ β”‚ β”‚
22
+ β”‚ β”‚ β–Ό β”‚
23
+ β”‚ β”‚ batch_converter([("mirna", "AUCGAUCG")]) β”‚
24
+ β”‚ β”‚ β”‚ β”‚
25
+ β”‚ β”‚ β–Ό β”‚
26
+ β”‚ β”‚ tokens: tensor([0, 4, 7, 5, 6, ...., 2]) β”‚
27
+ β”‚ β”‚ ^^BOS ^^EOS β”‚
28
+ β”‚ β”‚ β”‚
29
+ β”‚ β”œβ”€β†’ target_fragment_40nt: "TAGCTAGC..." β”‚
30
+ β”‚ β”‚ β”‚ (same dna_to_rna + batch_converter pipeline) β”‚
31
+ β”‚ β”‚ β–Ό β”‚
32
+ β”‚ β”‚ tokens: tensor([0, ..., 2]) (fixed 42 tokens: BOS+40nt+EOS)β”‚
33
+ β”‚ β”‚ β”‚
34
+ β”‚ └─→ return dict: β”‚
35
+ β”‚ { β”‚
36
+ β”‚ 'mirna_tokens': 1D LongTensor (variable 17-32) β”‚
37
+ β”‚ 'target_tokens': 1D LongTensor (fixed 42) β”‚
38
+ β”‚ 'label': float32 scalar (0.0 or 1.0) β”‚
39
+ β”‚ 'metadata': dict (species, mirna_name, ...) β”‚
40
+ β”‚ } β”‚
41
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
42
+
43
+ [RNA-FM batch_converter Input/Output Format]
44
+ - Input: List[Tuple[str, str]] = [("label_name", "RNA_sequence")]
45
+ e.g.: [("mirna", "AUCGAUCG")]
46
+
47
+ - Output: Tuple[List[str], List[str], Tensor]
48
+ - labels: ["mirna"] β€” label list (not used by us)
49
+ - strs: ["AUCGAUCG"] β€” raw sequences (not used by us)
50
+ - tokens: tensor([[0, 4, 7, 5, 6, 4, 7, 5, 6, 2]])
51
+ shape = (batch=1, seq_len)
52
+ where 0=BOS(<cls>), 2=EOS(<eos>), 1=PAD(<pad>)
53
+ A=4, C=5, G=6, U=7
54
+
55
+ - Important: batch_converter already adds BOS and EOS for us!
56
+ So 22nt miRNA β†’ 24 tokens (BOS + 22nt + EOS)
57
+ 40nt target β†’ 42 tokens (BOS + 40nt + EOS)
58
+ """
59
+
60
+ from __future__ import annotations
61
+
62
+ import pandas as pd
63
+ import torch
64
+ from torch.utils.data import Dataset
65
+
66
+ from deepmirt.data_module.preprocessing import dna_to_rna
67
+
68
+
69
class MiRNATargetDataset(Dataset):
    """
    PyTorch Dataset for miRNA-target pairs.

    [Overview]
    Loads miRNA-target sequence pairs from a CSV file, tokenizes them using
    the RNA-FM alphabet, and returns token tensors and labels for training.

    [Usage]
    >>> import fm
    >>> _, alphabet = fm.pretrained.rna_fm_t12()
    >>> ds = MiRNATargetDataset('path/to/train.csv', alphabet)
    >>> sample = ds[0]
    >>> sample['mirna_tokens']  # tensor([0, 4, 7, 5, ..., 2])
    >>> sample['label']         # tensor(1.)

    [Why inherit from torch.utils.data.Dataset?]
    - It is the standard PyTorch interface for data loading
    - After defining __len__ and __getitem__, it can be used with DataLoader
    - DataLoader automatically handles batching, multi-process loading, shuffling, etc.
    """

    def __init__(
        self,
        csv_path: str,
        alphabet,
        max_mirna_len: int = 30,
        max_target_len: int = 40,
    ):
        """
        Initialize the dataset.

        Args:
            csv_path (str): Path to the CSV file, which must contain the following columns:
                - mirna_seq: miRNA sequence (DNA notation)
                - target_fragment_40nt: target fragment sequence (DNA notation)
                - label: binary label (0 or 1)
                - species, mirna_name, target_gene_name: metadata columns
            alphabet: RNA-FM alphabet object that provides tokenization capability
            max_mirna_len (int): maximum nucleotide length for miRNA, default 30
                (actual token count = max_mirna_len + 2, due to BOS and EOS)
            max_target_len (int): maximum nucleotide length for target, default 40
                (actual token count = max_target_len + 2 = 42)

        NOTE(review): max_mirna_len / max_target_len are stored but not
        enforced anywhere in this class (no truncation or validation in
        __getitem__) — presumably the CSV already guarantees lengths; confirm.

        [Design Decision: Memory Strategy]
        We use pd.read_csv() to load the entire CSV into a DataFrame at once.
        This is the simplest approach — for our data scale (~5.4 million training rows),
        the DataFrame occupies approximately 2-3 GB of memory.

        The current system has 1TB RAM, so this is not an issue at all.

        # Design decision: if memory is limited (e.g., 8GB), consider these alternatives:
        # 1. Byte-offset indexing: first pass records byte positions of each row in the file,
        #    __getitem__ uses file.seek(offset) to jump to and read that row
        # 2. Memory mapping (mmap): open the file with mmap, read on demand
        # 3. Chunked reading: load in chunks, combined with LRU cache
        # These methods sacrifice code simplicity for lower memory usage
        """
        super().__init__()

        # Save configuration parameters
        self.csv_path = csv_path
        self.alphabet = alphabet
        self.max_mirna_len = max_mirna_len
        self.max_target_len = max_target_len

        # Get batch_converter for tokenization
        # batch_converter is the tokenization tool provided by RNA-FM, converting RNA strings to token IDs
        self.batch_converter = alphabet.get_batch_converter()

        # Design decision: load entire CSV into memory (see docstring above for details)
        # On a 1TB RAM system, 5.4 million rows ≈ 2-3 GB, easily affordable
        self.df = pd.read_csv(
            csv_path,
            # Force string dtype so gene names/ids like "0123" are not parsed as numbers.
            dtype={"target_gene_name": str, "target_gene_id": str},
        )

    def __len__(self) -> int:
        """
        Return the number of samples in the dataset.

        DataLoader calls this method to determine how many steps per epoch.
        e.g.: len(dataset)=557521, batch_size=128 → ~4356 steps per epoch
        """
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        """
        Retrieve the idx-th sample, returning a dict of tokenized tensors.

        [Processing Pipeline]
        1. Extract row idx from the DataFrame
        2. Get mirna_seq and target_fragment_40nt
        3. Apply dna_to_rna() for T→U conversion
        4. Tokenize with RNA-FM batch_converter
        5. Assemble and return the dict

        Args:
            idx (int): sample index, range [0, len(self)-1]

        Returns:
            dict: containing the following key-value pairs:
                - 'mirna_tokens': 1D LongTensor, miRNA token sequence
                  shape = (mirna_len+2,), including BOS and EOS
                - 'target_tokens': 1D LongTensor, target token sequence
                  shape = (42,), fixed length (BOS + 40nt + EOS)
                - 'label': float32 scalar tensor (0.0 or 1.0)
                - 'metadata': dict, containing species, mirna_name, target_gene_name
        """
        # ── Step 1: Extract one row from the DataFrame ──
        row = self.df.iloc[idx]

        # ── Step 2: Extract sequences and label ──
        mirna_seq_raw = row["mirna_seq"]
        target_seq_raw = row["target_fragment_40nt"]
        label = row["label"]

        # ── Step 3: DNA-to-RNA conversion (T → U) ──
        # Sequences in the dataset use DNA notation (T for thymine),
        # but the RNA-FM model expects RNA notation (U for uridine), so conversion is needed
        mirna_rna = dna_to_rna(mirna_seq_raw)
        target_rna = dna_to_rna(target_seq_raw)

        # ── Step 4: Tokenize using RNA-FM batch_converter ──
        # batch_converter input format: List[Tuple[label, sequence]]
        # It automatically adds BOS(<cls>=0) and EOS(<eos>=2) tokens around the sequence
        #
        # e.g.: [("mirna", "AUCG")]
        #       output tokens: tensor([[0, 4, 7, 5, 6, 2]])
        #                       BOS=0  A  U  C  G  EOS=2
        #
        # Here we process only 1 sequence at a time (batch_size=1),
        # so we use tokens[0] to extract the first one, yielding a 1D tensor

        # Tokenize miRNA
        _, _, mirna_tokens = self.batch_converter([("mirna", mirna_rna)])
        mirna_tokens = mirna_tokens[0]  # (1, seq_len) → (seq_len,)

        # Tokenize target
        _, _, target_tokens = self.batch_converter([("target", target_rna)])
        target_tokens = target_tokens[0]  # (1, 42) → (42,)

        # ── Step 5: Assemble the return dict ──
        # Why use float32 for label?
        # Because training uses BCEWithLogitsLoss (binary cross-entropy),
        # which requires both target and prediction to be float type.
        # If label is int/long, PyTorch will raise a type mismatch error.
        return {
            "mirna_tokens": mirna_tokens,    # 1D LongTensor, variable (17-32)
            "target_tokens": target_tokens,  # 1D LongTensor, fixed 42
            "label": torch.tensor(label, dtype=torch.float32),  # scalar float32
            "metadata": {
                "species": row["species"],
                "mirna_name": row["mirna_name"],
                "target_gene_name": row["target_gene_name"],
                # .get() tolerates CSVs that lack these optional columns.
                "evidence_type": row.get("evidence_type", ""),
                "source_database": row.get("source_database", ""),
            },
        }
deepmirt/data_module/preprocessing.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Data Preprocessing Utilities β€” RNA Sequence Format Conversion Module
4
+
5
+ This module converts DNA-notation sequences in the dataset to the RNA notation
6
+ format required by the RNA-FM model.
7
+
8
+ [Why is this conversion needed?]
9
+ - The RNA-FM model was trained on RNA sequences and expects input in RNA notation: A, U, G, C
10
+ - Our dataset stores sequences in DNA notation: A, T, G, C (where T replaces U)
11
+ - During training, DNA notation T must be converted to RNA notation U to match the model's expected input format
12
+
13
+ [Architecture Position]
14
+ - This module is called by Dataset.__getitem__() during training
15
+ - The conversion happens at the data loading stage without modifying the original CSV files
16
+ - Reference: finalize_dataset.py:86-93 performs the reverse operation (U→T) for data export
17
+
18
+ [Design Decisions]
19
+ - Conversion is performed online (in the Dataset) rather than preprocessing the CSV, to preserve original data integrity
20
+ - All sequences are converted to uppercase to ensure format consistency
21
+ - The character N (representing ambiguous bases) is allowed; RNA-FM can handle ambiguous bases
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+
27
def dna_to_rna(seq: str) -> str:
    """
    Convert a DNA-notation sequence to an RNA-notation sequence.

    [Description]
    - Converts T (thymine, DNA) to U (uridine, RNA)
    - Converts to uppercase
    - Removes all whitespace characters
    - Idempotent: sequences already in RNA format remain unchanged

    [Design Decisions]
    - Why convert online? To keep the original CSV data intact for auditing and reproducibility
    - Why uppercase? To ensure consistency with the RNA-FM model's expected input format
    - Why allow N? RNA-FM's tokenizer can handle ambiguous bases

    Args:
        seq (str): DNA-notation sequence string, may contain A, T, G, C, N and whitespace

    Returns:
        str: RNA-notation sequence string, containing A, U, G, C, N (uppercase, no whitespace)

    Example:
        >>> dna_to_rna('ATCGATCG')
        'AUCGAUCG'
        >>> dna_to_rna('atcg')  # mixed case
        'AUCG'
        >>> dna_to_rna('AUCGAUCG')  # already RNA format (idempotent)
        'AUCGAUCG'
        >>> dna_to_rna('ATC NGATCG')  # contains N and whitespace
        'AUCNGAUCG'
        >>> dna_to_rna(' ATC G ')  # leading/trailing whitespace
        'AUCG'
    """
    # Step 1: Convert to uppercase (so lowercase 't' is also converted below)
    seq = str(seq).upper()

    # Step 2: Remove ALL whitespace. str.split() with no argument splits on any
    # whitespace run (space, tab, \n, \r, \f, \v, ...), so joining the pieces
    # removes every whitespace character — the previous chained .replace()
    # calls missed form-feed and vertical-tab, contradicting the docstring.
    seq = "".join(seq.split())

    # Step 3: Convert T (DNA thymine) to U (RNA uridine)
    return seq.replace("T", "U")
70
+
71
+
72
def validate_rna_sequence(seq: str, min_len: int = 5, max_len: int = 100) -> bool:
    """
    Validate whether a sequence is in valid RNA format.

    [Description]
    - Checks that the sequence contains only valid RNA characters: A, U, G, C, N
    - Checks that the sequence length is within the specified range
    - If it contains T, the DNA-to-RNA conversion was not performed; returns False

    [Design Decisions]
    - Why reject T? It serves as an indicator of conversion failure, aiding data flow debugging.
      T is simply not in the valid character set, so no special-case branch is needed
      (the original code had a dead `if char == "T"` branch before an unconditional return).
    - Why allow N? RNA-FM's tokenizer supports ambiguous bases
    - Why impose length limits? To prevent abnormally long sequences from causing memory overflow

    Args:
        seq (str): the sequence string to validate
        min_len (int): minimum length (inclusive), default 5
        max_len (int): maximum length (inclusive), default 100

    Returns:
        bool: True if the sequence is valid, False otherwise

    Example:
        >>> validate_rna_sequence('AUCGAUCG', 5, 30)
        True
        >>> validate_rna_sequence('ATCGATCG', 5, 30)  # contains T (DNA notation)
        False
        >>> validate_rna_sequence('AU', 5, 30)  # too short
        False
        >>> validate_rna_sequence('A' * 31, 5, 30)  # too long
        False
        >>> validate_rna_sequence('AUCNGAUCG', 5, 30)  # contains N (valid)
        True
    """
    # Length must be within [min_len, max_len]
    if not min_len <= len(seq) <= max_len:
        return False

    # Every character must come from the valid RNA alphabet.
    # Set containment covers all invalid characters, including DNA-notation T.
    return set(seq) <= {"A", "U", "G", "C", "N"}
123
+
124
+
125
def prepare_rnafm_input(mirna_seq: str, target_seq: str) -> tuple[str, str]:
    """
    Prepare an input sequence pair for the RNA-FM model.

    [Description]
    - Converts both the miRNA and the target sequence to RNA notation
    - Returns two separate strings (not concatenated)
    - RNA-FM uses a shared encoder that processes each sequence independently

    [Design Decisions]
    - Why not concatenate? The dual-encoder runs each sequence through its own
      forward pass; joining them would break the architectural design
    - Returning a tuple makes unpacking in Dataset.__getitem__() straightforward

    Args:
        mirna_seq (str): miRNA sequence (DNA notation)
        target_seq (str): target sequence (DNA notation)

    Returns:
        tuple[str, str]: (mirna_rna, target_rna), both in RNA notation

    Example:
        >>> mirna_rna, target_rna = prepare_rnafm_input('ATCG', 'TAGC')
        >>> mirna_rna
        'AUCG'
        >>> target_rna
        'UAGC'
    """
    # Each sequence is converted independently; no concatenation happens here.
    return dna_to_rna(mirna_seq), dna_to_rna(target_seq)
158
+
159
+
160
def compute_sequence_stats(csv_path: str, sample_n: int = 10000) -> dict:
    """
    Compute statistics for sequences in a CSV file.

    [Description]
    - Samples a specified number of rows from the CSV file
    - Computes sequence length distributions, character frequencies, DNA notation detection, etc.
    - Used for data quality checks and analysis

    [Design Decisions]
    - Why lazy-import pandas? To avoid introducing a heavy dependency at module load time
    - Sampling instead of full processing speeds up statistics computation
    - All counts are cast to plain Python int so the result dict is JSON-serializable
      (pandas/numpy reductions otherwise return numpy scalar types)

    Args:
        csv_path (str): path to the CSV file; must contain 'mirna_seq' and
            'target_fragment_40nt' columns
        sample_n (int): number of rows to sample, default 10000. If the file has
            fewer rows, all rows are used

    Returns:
        dict: statistics dictionary containing the following keys:
            - 'total_rows': total number of rows in the file (excluding header)
            - 'sample_rows': actual number of sampled rows
            - 'mirna_length_min' / 'mirna_length_max' / 'mirna_length_mean'
            - 'target_length_min' / 'target_length_max' / 'target_length_mean'
            - 'mirna_char_freq': miRNA character frequency dictionary
            - 'target_char_freq': target sequence character frequency dictionary
            - 'mirna_with_t_count': number of miRNA sequences containing T
            - 'target_with_t_count': number of target sequences containing T

    Example:
        >>> stats = compute_sequence_stats('deepmirt/data/training/train.csv', sample_n=100)
        >>> print(f"Total rows: {stats['total_rows']}")
        >>> print(f"miRNA length range: {stats['mirna_length_min']}-{stats['mirna_length_max']}")
    """
    # Lazy imports: keep heavy/optional dependencies out of module import time.
    from collections import Counter

    import pandas as pd

    df = pd.read_csv(csv_path)
    total_rows = len(df)

    # Cap the sample size at the actual number of rows.
    actual_sample_n = min(sample_n, total_rows)
    if actual_sample_n < total_rows:
        # Fixed random_state keeps the statistics reproducible across runs.
        sample_df = df.sample(n=actual_sample_n, random_state=42)
    else:
        sample_df = df

    stats = {
        'total_rows': total_rows,
        'sample_rows': len(sample_df),
    }

    # miRNA sequence length statistics
    mirna_lengths = sample_df['mirna_seq'].str.len()
    stats['mirna_length_min'] = int(mirna_lengths.min())
    stats['mirna_length_max'] = int(mirna_lengths.max())
    stats['mirna_length_mean'] = float(mirna_lengths.mean())

    # Target sequence length statistics
    target_lengths = sample_df['target_fragment_40nt'].str.len()
    stats['target_length_min'] = int(target_lengths.min())
    stats['target_length_max'] = int(target_lengths.max())
    stats['target_length_mean'] = float(target_lengths.mean())

    def compute_char_freq(seq_series):
        """Count character occurrences across all (uppercased) sequences."""
        # Counter.update over a string counts each character in C code —
        # clearer and faster than a hand-rolled dict loop.
        freq = Counter()
        for seq in seq_series:
            freq.update(str(seq).upper())
        return dict(freq)

    stats['mirna_char_freq'] = compute_char_freq(sample_df['mirna_seq'])
    stats['target_char_freq'] = compute_char_freq(sample_df['target_fragment_40nt'])

    # Count sequences still containing T (DNA notation, conversion not yet applied).
    # int(...) converts the numpy int64 sum to a plain Python int.
    stats['mirna_with_t_count'] = int(
        sample_df['mirna_seq'].str.contains('T', case=False, na=False).sum()
    )
    stats['target_with_t_count'] = int(
        sample_df['target_fragment_40nt'].str.contains('T', case=False, na=False).sum()
    )

    return stats
deepmirt/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """miRNA target prediction model β€” comprehensive evaluation framework."""
deepmirt/evaluation/predict.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Inference engine: load checkpoint and generate prediction DataFrame on the test set.
4
+
5
+ Independent of Lightning trainer.test(), performs batch inference directly and
6
+ retains all metadata. Prediction results are cached as parquet to avoid repeated inference.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import torch
17
+ import yaml
18
+ from torch.utils.data import DataLoader
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
def load_model_from_checkpoint(
    ckpt_path: str,
    config_path: str,
    device: str = "cuda",
):
    """
    Load a trained model from checkpoint.

    Reads the training YAML config, restores the Lightning module from the
    checkpoint onto the requested device, and switches it to eval mode.

    Args:
        ckpt_path: path to the checkpoint file
        config_path: path to the training config YAML
        device: inference device

    Returns:
        (model, config) tuple
    """
    # Imported lazily to avoid a hard dependency at module import time.
    from deepmirt.training.lightning_module import MiRNATargetLitModule

    with open(config_path) as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    model = MiRNATargetLitModule.load_from_checkpoint(
        ckpt_path, config=cfg, map_location=device
    )
    # Inference-only: disable dropout/batchnorm updates and move to device.
    model.eval()
    model.to(device)
    return model, cfg
50
+
51
+
52
def run_inference(
    ckpt_path: str,
    config_path: str,
    test_csv_path: str,
    batch_size: int = 256,
    num_workers: int = 8,
    device: str = "cuda",
    cache_path: str | None = None,
) -> pd.DataFrame:
    """
    Run model inference on the test set, returning a DataFrame with predictions and metadata.

    If cache_path exists, loads cached results directly (no inference is run).

    Args:
        ckpt_path: path to the checkpoint
        config_path: path to the config YAML
        test_csv_path: path to test.csv
        batch_size: inference batch size
        num_workers: number of DataLoader worker processes
        device: inference device
        cache_path: cache file path (parquet or csv, chosen by extension);
            None disables caching

    Returns:
        DataFrame with columns:
            mirna_seq, target_fragment_40nt, label, prob, pred, logit,
            species, mirna_name, target_gene_name, evidence_type, source_database
    """
    # Check cache (supports both parquet and csv formats, chosen by file extension)
    if cache_path and Path(cache_path).exists():
        logger.info(f"Loading cached predictions from {cache_path}")
        if cache_path.endswith(".parquet"):
            return pd.read_parquet(cache_path)
        else:
            return pd.read_csv(cache_path)

    logger.info(f"Loading model from {ckpt_path}")
    lit_model, config = load_model_from_checkpoint(ckpt_path, config_path, device)

    # Load data (using DataModule approach for consistency).
    # Imported lazily so this module can be imported without fm installed.
    import fm

    from deepmirt.data_module.datamodule import MiRNATargetDataModule
    from deepmirt.data_module.dataset import MiRNATargetDataset

    # Only the alphabet/tokenizer is needed here; discard the second RNA-FM
    # model instance immediately to free its memory.
    _, alphabet = fm.pretrained.rna_fm_t12()
    del _
    padding_idx = alphabet.padding_idx

    dataset = MiRNATargetDataset(test_csv_path, alphabet)

    # Use the DataModule's collate_fn logic without running its __init__
    # (which may set up training-only state). NOTE(review): this assumes
    # _collate_fn only reads _padding_idx from self — confirm if the
    # DataModule changes.
    dm = MiRNATargetDataModule.__new__(MiRNATargetDataModule)
    dm._padding_idx = padding_idx

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # preserve CSV row order so results align with raw_df below
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=dm._collate_fn,
    )

    # Inference accumulators: logits/labels as tensors, metadata as string lists.
    all_logits = []
    all_labels = []
    all_metadata = {
        "species": [],
        "mirna_name": [],
        "target_gene_name": [],
        "evidence_type": [],
        "source_database": [],
    }

    logger.info(f"Running inference on {len(dataset)} samples...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            mirna_tokens = batch["mirna_tokens"].to(device)
            target_tokens = batch["target_tokens"].to(device)
            labels = batch["labels"]  # stays on CPU; only collected, not used on GPU
            attn_mask_mirna = batch["attention_mask_mirna"].to(device)
            attn_mask_target = batch["attention_mask_target"].to(device)

            logits = lit_model.model(
                mirna_tokens, target_tokens, attn_mask_mirna, attn_mask_target
            )
            logits = logits.squeeze(-1).cpu()

            all_logits.append(logits)
            all_labels.append(labels)

            # Pad metadata with empty strings when a key is absent from the batch
            # so all metadata lists stay the same length as the predictions.
            metadata = batch.get("metadata", {})
            for key in all_metadata:
                if key in metadata:
                    all_metadata[key].extend(metadata[key])
                else:
                    all_metadata[key].extend([""] * len(labels))

            if (batch_idx + 1) % 500 == 0:
                logger.info(
                    f"  Processed {(batch_idx + 1) * batch_size} / {len(dataset)}"
                )

    all_logits = torch.cat(all_logits).numpy()
    all_labels = torch.cat(all_labels).numpy()
    all_probs = 1.0 / (1.0 + np.exp(-all_logits))  # sigmoid
    all_preds = (all_probs >= 0.5).astype(int)  # fixed 0.5 decision threshold

    # Build raw sequence columns (read directly from CSV).
    # NOTE(review): assumes MiRNATargetDataset preserves CSV row order, so the
    # prediction arrays line up with raw_df — verify against the Dataset class.
    raw_df = pd.read_csv(
        test_csv_path,
        usecols=["mirna_seq", "target_fragment_40nt"],
        dtype=str,
    )

    result_df = pd.DataFrame(
        {
            "mirna_seq": raw_df["mirna_seq"].values,
            "target_fragment_40nt": raw_df["target_fragment_40nt"].values,
            "label": all_labels.astype(int),
            "prob": all_probs,
            "pred": all_preds,
            "logit": all_logits,
            "species": all_metadata["species"],
            "mirna_name": all_metadata["mirna_name"],
            "target_gene_name": all_metadata["target_gene_name"],
            "evidence_type": all_metadata["evidence_type"],
            "source_database": all_metadata["source_database"],
        }
    )

    # Cache results (prefer parquet, fallback to csv)
    if cache_path:
        Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
        try:
            if cache_path.endswith(".parquet"):
                result_df.to_parquet(cache_path, index=False)
            else:
                result_df.to_csv(cache_path, index=False)
        except ImportError:
            # pandas raises ImportError from to_parquet when no parquet engine
            # (pyarrow/fastparquet) is installed — fall back to csv.
            csv_path = cache_path.replace(".parquet", ".csv")
            result_df.to_csv(csv_path, index=False)
            logger.info(f"pyarrow not available, saved as CSV: {csv_path}")
            cache_path = csv_path
        logger.info(f"Predictions cached to {cache_path}")

    logger.info(
        f"Inference complete: {len(result_df)} samples, "
        f"pos={result_df['label'].sum()}, neg={(result_df['label'] == 0).sum()}"
    )
    return result_df
205
+
206
+
207
def predict_on_sequences(
    ckpt_path: str,
    config_path: str,
    mirna_seqs: list[str],
    target_seqs: list[str],
    batch_size: int = 256,
    device: str = "cuda",
    _lit_model=None,
    _alphabet=None,
) -> np.ndarray:
    """
    Run inference on arbitrary miRNA + target sequence pairs.

    Used to run our model on external data such as miRBench standard benchmark datasets.
    Sequences are automatically converted to RNA format (T->U).

    Args:
        ckpt_path: path to the checkpoint (ignored when _lit_model is given)
        config_path: path to the config YAML (ignored when _lit_model is given)
        mirna_seqs: list of miRNA sequences (DNA or RNA format accepted)
        target_seqs: list of target sequences (DNA or RNA format, should be 40nt);
            must be the same length as mirna_seqs (pairs are zipped positionally)
        batch_size: inference batch size
        device: inference device
        _lit_model: pre-loaded model (internal use, for caching across calls)
        _alphabet: pre-loaded alphabet (internal use, for caching across calls)

    Returns:
        numpy array of predicted probabilities, shape (n_samples,)
    """
    # Lazy imports keep fm optional at module import time.
    import fm
    from torch.nn.utils.rnn import pad_sequence

    # Reuse a caller-supplied model if provided to avoid reloading the checkpoint.
    if _lit_model is not None:
        lit_model = _lit_model
    else:
        logger.info(f"Loading model from {ckpt_path}")
        lit_model, config = load_model_from_checkpoint(ckpt_path, config_path, device)

    if _alphabet is not None:
        alphabet = _alphabet
    else:
        # Only the alphabet is needed; drop the RNA-FM model instance right away.
        _, alphabet = fm.pretrained.rna_fm_t12()
        del _
    batch_converter = alphabet.get_batch_converter()
    padding_idx = alphabet.padding_idx

    def _to_rna(seq: str) -> str:
        # NOTE(review): unlike preprocessing.dna_to_rna, this does not strip
        # whitespace from the input — confirm callers pass clean sequences.
        return seq.upper().replace("T", "U")

    all_probs = []
    n_samples = len(mirna_seqs)
    logger.info(f"Running inference on {n_samples} sequences...")

    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch_mirna = mirna_seqs[i : i + batch_size]
            batch_target = target_seqs[i : i + batch_size]

            # Tokenize each pair; batch_converter wraps sequences with BOS/EOS,
            # and [0] extracts the single tokenized sequence as a 1D tensor.
            mirna_tokens_list = []
            target_tokens_list = []
            for m_seq, t_seq in zip(batch_mirna, batch_target):
                m_rna = _to_rna(str(m_seq))
                t_rna = _to_rna(str(t_seq))
                _, _, m_tok = batch_converter([("m", m_rna)])
                _, _, t_tok = batch_converter([("t", t_rna)])
                mirna_tokens_list.append(m_tok[0])
                target_tokens_list.append(t_tok[0])

            # miRNAs have variable length and must be padded; targets are
            # stacked directly (torch.stack requires them to be equal-length,
            # i.e. all 40nt as documented above).
            mirna_padded = pad_sequence(
                mirna_tokens_list, batch_first=True, padding_value=padding_idx
            )
            target_stacked = torch.stack(target_tokens_list)

            # 1 = real token, 0 = padding; targets are unpadded so all-ones.
            attn_mask_mirna = (mirna_padded != padding_idx).long()
            attn_mask_target = torch.ones_like(target_stacked, dtype=torch.long)

            mirna_padded = mirna_padded.to(device)
            target_stacked = target_stacked.to(device)
            attn_mask_mirna = attn_mask_mirna.to(device)
            attn_mask_target = attn_mask_target.to(device)

            logits = lit_model.model(
                mirna_padded, target_stacked, attn_mask_mirna, attn_mask_target
            )
            # Model emits raw logits; apply sigmoid here to get probabilities.
            probs = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
            all_probs.append(probs)

            if (i // batch_size + 1) % 100 == 0:
                logger.info(f"  Processed {min(i + batch_size, n_samples)} / {n_samples}")

    return np.concatenate(all_probs)
deepmirt/model/__init__.py ADDED
File without changes
deepmirt/model/classifier.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # pyright: basic, reportMissingImports=false
3
+ """
4
+ MLP classifier head (maps sequence representations to binary classification logits).
5
+
6
+ Architecture diagram:
7
+
8
+ pooled_feature (B, 640)
9
+ |
10
+ v
11
+ Linear(640 -> 256)
12
+ |
13
+ v
14
+ BatchNorm + ReLU + Dropout(0.3)
15
+ |
16
+ v
17
+ Linear(256 -> 64) + ReLU + Dropout(0.2)
18
+ |
19
+ v
20
+ Linear(64 -> 1)
21
+ |
22
+ v
23
+ logits (B, 1)
24
+
25
+ Note:
26
+ - The output is logits (raw scores); do not apply sigmoid inside the model.
27
+ - During training, use BCEWithLogitsLoss which applies sigmoid internally for numerical stability.
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ from collections.abc import Sequence
33
+
34
+ from torch import Tensor, nn
35
+
36
+
37
+ class MLPClassifier(nn.Module):
38
+ """MLP head for binary classification, outputting a single logit."""
39
+
40
+ def __init__(
41
+ self,
42
+ input_dim: int = 640,
43
+ hidden_dims: Sequence[int] | None = None,
44
+ dropout: float = 0.3,
45
+ ) -> None:
46
+ super().__init__()
47
+ dims = list(hidden_dims) if hidden_dims is not None else [256, 64]
48
+ if len(dims) != 2:
49
+ raise ValueError("hidden_dims must contain exactly two elements, e.g. [256, 64].")
50
+
51
+ hidden1, hidden2 = int(dims[0]), int(dims[1])
52
+ in_dim = int(input_dim)
53
+
54
+ # Design decision: [256, 64] balances expressiveness and overfitting risk,
55
+ # suitable for small-to-medium scale biological data.
56
+ # Design decision: first layer uses BatchNorm + Dropout; second layer retains
57
+ # a smaller Dropout for lightweight regularization.
58
+ self.layers = nn.Sequential(
59
+ nn.Linear(in_dim, hidden1),
60
+ nn.BatchNorm1d(hidden1),
61
+ nn.ReLU(),
62
+ nn.Dropout(dropout),
63
+ nn.Linear(hidden1, hidden2),
64
+ nn.ReLU(),
65
+ nn.Dropout(0.2),
66
+ nn.Linear(hidden2, 1),
67
+ )
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ """
71
+ Args:
72
+ x: Pooled sequence representation, shape `(batch, input_dim)`.
73
+
74
+ Returns:
75
+ Logits, shape `(batch, 1)`.
76
+ """
77
+ return self.layers(x)
deepmirt/model/cross_attention.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # pyright: basic, reportMissingImports=false
3
+ """
4
+ Cross-Attention interaction module.
5
+
6
+ Data flow diagram (target as Query, miRNA as Key/Value)::
7
+
8
+ target_emb (B, T, D) -------------------------------> Q
9
+ |
10
+ | Multi-Head Cross Attention
11
+ | (batch_first=True)
12
+ |
13
+ miRNA_emb (B, M, D) ---> K, V -------------------->
14
+
15
+ Output: context_target (B, T, D)
16
+
17
+ Why target=Q and miRNA=K/V:
18
+ - Our task is to determine whether a target is regulated by a given miRNA.
19
+ - Having each target position query miRNA information aligns with the semantics
20
+ of locating potential binding sites on the target.
21
+
22
+ Mask convention:
23
+ - key_padding_mask=True indicates a padding position that should be ignored.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import torch
29
+ from torch import Tensor, nn
30
+
31
+
32
+ class CrossAttentionBlock(nn.Module):
33
+ """Interaction module composed of stacked Cross-Attention + FFN layers."""
34
+
35
+ def __init__(
36
+ self,
37
+ embed_dim: int = 640,
38
+ num_heads: int = 8,
39
+ dropout: float = 0.1,
40
+ num_layers: int = 2,
41
+ ) -> None:
42
+ super().__init__()
43
+ self.embed_dim = int(embed_dim)
44
+ self.num_heads = int(num_heads)
45
+ self.num_layers = int(num_layers)
46
+
47
+ self.layers = nn.ModuleList()
48
+ for _ in range(self.num_layers):
49
+ layer = nn.ModuleDict(
50
+ {
51
+ "cross_attn": nn.MultiheadAttention(
52
+ embed_dim=self.embed_dim,
53
+ num_heads=self.num_heads,
54
+ dropout=dropout,
55
+ batch_first=True,
56
+ ),
57
+ "dropout_attn": nn.Dropout(dropout),
58
+ "norm1": nn.LayerNorm(self.embed_dim),
59
+ "ffn": nn.Sequential(
60
+ nn.Linear(self.embed_dim, self.embed_dim * 4),
61
+ nn.ReLU(),
62
+ nn.Dropout(dropout),
63
+ nn.Linear(self.embed_dim * 4, self.embed_dim),
64
+ ),
65
+ "norm2": nn.LayerNorm(self.embed_dim),
66
+ }
67
+ )
68
+ self.layers.append(layer)
69
+
70
+ # Design decision: 2 layers by default is a lightweight yet effective trade-off;
71
+ # establish a trainable baseline first, then deepen based on data scale.
72
+ # Design decision: 8 attention heads by default improves interaction modeling across
73
+ # different subspaces while keeping GPU memory overhead manageable.
74
+
75
+ def forward(
76
+ self,
77
+ query: Tensor,
78
+ key_value: Tensor,
79
+ key_padding_mask: Tensor | None = None,
80
+ ) -> Tensor:
81
+ """
82
+ Args:
83
+ query: Target representation, shape `(batch, target_len, embed_dim)`.
84
+ key_value: miRNA representation, shape `(batch, mirna_len, embed_dim)`.
85
+ key_padding_mask: miRNA padding mask, shape `(batch, mirna_len)`,
86
+ where True indicates positions to ignore.
87
+
88
+ Returns:
89
+ Updated target representation, shape `(batch, target_len, embed_dim)`.
90
+ """
91
+ hidden = query
92
+ attn_mask = key_padding_mask
93
+ if attn_mask is not None and attn_mask.dtype is not torch.bool:
94
+ attn_mask = attn_mask.to(dtype=torch.bool)
95
+
96
+ for layer in self.layers:
97
+ # Step 1: Cross-Attention (target queries miRNA)
98
+ attn_out, _ = layer["cross_attn"](
99
+ query=hidden,
100
+ key=key_value,
101
+ value=key_value,
102
+ key_padding_mask=attn_mask,
103
+ need_weights=False,
104
+ )
105
+
106
+ # Step 2: Residual + LayerNorm to stabilize deep training and mitigate vanishing gradients
107
+ hidden = layer["norm1"](hidden + layer["dropout_attn"](attn_out))
108
+
109
+ # Step 3: Feed-forward network refines channel-wise features
110
+ ffn_out = layer["ffn"](hidden)
111
+
112
+ # Step 4: Residual + LayerNorm
113
+ hidden = layer["norm2"](hidden + ffn_out)
114
+
115
+ return hidden
deepmirt/model/mirna_target_model.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # pyright: basic, reportMissingImports=false
3
+ """
4
+ Full miRNA-target model: shared RNA-FM encoder + Cross-Attention + MLP classifier head.
5
+
6
+ Complete data flow (with tensor shapes):
7
+
8
+ miRNA tokens (B, M_tok) ---> [RNA-FM Encoder] ---> miRNA_emb (B, M, D) ---┐
9
+ |
10
+ v
11
+ target tokens (B, T_tok) ---> [RNA-FM Encoder] ---> target_emb (B, T, D) --> [Cross-Attention]
12
+ |
13
+ v
14
+ cross_out (B, T, D)
15
+ |
16
+ v
17
+ masked mean pool
18
+ |
19
+ v
20
+ (B, D)
21
+ |
22
+ v
23
+ [MLP Head]
24
+ |
25
+ v
26
+ logits
27
+ (B, 1)
28
+
29
+ Where D is automatically inferred from RNA-FM (typically 640) to avoid hard-coding.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ from collections.abc import Sequence
35
+
36
+ import torch
37
+ from torch import Tensor, nn
38
+
39
+ from .classifier import MLPClassifier
40
+ from .cross_attention import CrossAttentionBlock
41
+ from .rnafm_encoder import RNAFMEncoder
42
+
43
+
44
class MiRNATargetModel(nn.Module):
    """End-to-end model for miRNA-target binary classification.

    Flow: shared RNA-FM encoder for both sequences -> cross-attention
    (target as Query, miRNA as Key/Value) -> masked mean pooling over the
    target axis -> MLP head -> logits of shape `(batch, 1)`.
    """

    def __init__(
        self,
        freeze_backbone: bool = True,
        cross_attn_heads: int = 8,
        cross_attn_layers: int = 2,
        classifier_hidden: Sequence[int] | None = None,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        head_dims = [256, 64] if classifier_hidden is None else list(classifier_hidden)

        self.encoder = RNAFMEncoder(freeze_backbone=freeze_backbone)
        dim = self.encoder.embed_dim  # inferred from RNA-FM (typically 640)

        # The interaction stack uses a reduced dropout (~1/3 of the head's)
        # so attention signals are preserved while keeping basic regularization.
        self.cross_attention = CrossAttentionBlock(
            embed_dim=dim,
            num_heads=cross_attn_heads,
            dropout=dropout * 0.33,
            num_layers=cross_attn_layers,
        )
        self.classifier = MLPClassifier(
            input_dim=dim,
            hidden_dims=head_dims,
            dropout=dropout,
        )

    def forward(
        self,
        mirna_tokens: Tensor,
        target_tokens: Tensor,
        attention_mask_mirna: Tensor | None = None,
        attention_mask_target: Tensor | None = None,
    ) -> Tensor:
        """
        Forward pass:
            1) Encode miRNA and target with the shared RNA-FM encoder.
            2) Derive the key_padding_mask (attention_mask: 1=real, 0=padding).
            3) Cross-attention: target (Q) queries miRNA (K/V).
            4) Masked mean pooling over the target sequence axis.
            5) Classifier head emits raw logits of shape `(batch, 1)`.
        """
        # One shared encoder keeps both sequences in the same representation space.
        mirna_repr = self.encoder(mirna_tokens)
        target_repr = self.encoder(target_tokens)

        # nn.MultiheadAttention convention: True marks positions to ignore.
        pad_mask = None if attention_mask_mirna is None else attention_mask_mirna == 0

        interacted = self.cross_attention(
            query=target_repr,
            key_value=mirna_repr,
            key_padding_mask=pad_mask,
        )

        # Masked mean pooling over the target dimension: weight each position
        # by its mask (all-ones when no mask is supplied) and normalize by the
        # number of real tokens, clamped to avoid division by zero.
        if attention_mask_target is None:
            weights = interacted.new_ones(
                interacted.size(0), interacted.size(1), 1
            )
        else:
            weights = attention_mask_target.to(dtype=interacted.dtype).unsqueeze(-1)

        pooled = (interacted * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1e-6)

        # Raw logits — sigmoid is applied by the loss during training and
        # explicitly at inference time.
        return self.classifier(pooled)
deepmirt/model/rnafm_encoder.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # pyright: basic, reportMissingImports=false
3
+ """
4
+ RNA-FM encoder wrapper (Shared Encoder).
5
+
6
+ Architecture diagram (single-path encoding):
7
+
8
+ Input tokens (B, L)
9
+ |
10
+ v
11
+ [RNA-FM: 12-layer Transformer]
12
+ |
13
+ v
14
+ representations[12] (B, L, D)
15
+ D is typically 640
16
+
17
+ Training strategy diagram (freeze / staged unfreezing):
18
+
19
+ Frozen phase: [L1][L2][L3]...[L12] all requires_grad=False
20
+ Unfrozen phase: [L1]...[L9][L10][L11][L12]
21
+ ^^^^^^^^
22
+ only unfreeze top N layers (e.g., N=3)
23
+
24
+ Notes:
25
+ - Both miRNA and target are RNA sequences, so sharing a single RNA-FM encoder is the most natural approach.
26
+ - `repr_layers=[12]` extracts the 12th (final) layer output as the contextualized representation.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ from collections.abc import Sequence
32
+
33
+ import fm
34
+ from torch import Tensor, nn
35
+
36
+
37
+ class RNAFMEncoder(nn.Module):
38
+ """Lightweight wrapper around RNA-FM providing forward encoding, freezing, and staged unfreezing."""
39
+
40
+ def __init__(self, freeze_backbone: bool = True) -> None:
41
+ super().__init__()
42
+ self.model, self.alphabet = fm.pretrained.rna_fm_t12()
43
+ self.num_layers = len(self.model.layers)
44
+ self.embed_dim = self._infer_embed_dim(default=640)
45
+
46
+ # Design decision: freeze backbone by default to first stabilize training of the
47
+ # upper interaction module and classifier head, avoiding catastrophic forgetting
48
+ # from full fine-tuning on small datasets.
49
+ if freeze_backbone:
50
+ self.freeze()
51
+
52
+ def _infer_embed_dim(self, default: int = 640) -> int:
53
+ """Try to infer the embedding dimension from the RNA-FM model; fall back to default on failure."""
54
+ model_embed_dim = getattr(self.model, "embed_dim", None)
55
+ if model_embed_dim is not None:
56
+ return int(model_embed_dim)
57
+
58
+ model_args = getattr(self.model, "args", None)
59
+ if model_args is not None and hasattr(model_args, "embed_dim"):
60
+ return int(model_args.embed_dim)
61
+
62
+ embed_tokens = getattr(self.model, "embed_tokens", None)
63
+ if embed_tokens is not None and hasattr(embed_tokens, "embedding_dim"):
64
+ return int(embed_tokens.embedding_dim)
65
+
66
+ return int(default)
67
+
68
+ def forward(self, tokens: Tensor, repr_layers: Sequence[int] | None = None) -> Tensor:
69
+ """
70
+ Encode an RNA token sequence.
71
+
72
+ Args:
73
+ tokens: Token tensor of shape `(batch, seq_len)`.
74
+ repr_layers: List of layer indices to extract. Defaults to `[12]` (final layer).
75
+
76
+ Returns:
77
+ Contextualized representations of shape `(batch, seq_len, embed_dim)`.
78
+ """
79
+ if repr_layers is None:
80
+ # Design decision: use the final layer representation by default (most semantically
81
+ # complete), consistent with common pre-trained model usage.
82
+ repr_layers = [self.num_layers]
83
+
84
+ layer_ids = list(repr_layers)
85
+ if not layer_ids:
86
+ raise ValueError("repr_layers must not be empty; provide at least one layer index.")
87
+
88
+ outputs = self.model(tokens, repr_layers=layer_ids)
89
+ # Note: typically repr_layers=[12] is passed, so this retrieves representations[12].
90
+ final_layer_id = max(layer_ids)
91
+ return outputs["representations"][final_layer_id]
92
+
93
+ def freeze(self) -> None:
94
+ """Freeze all RNA-FM backbone parameters (requires_grad=False)."""
95
+ for param in self.model.parameters():
96
+ param.requires_grad = False
97
+
98
+ def unfreeze(self, num_layers: int = 3) -> None:
99
+ """
100
+ Unfreeze only the per-layer parameters of the top N Transformer layers.
101
+
102
+ Example: when `num_layers=3`, unfreezes layer[9], layer[10], layer[11].
103
+
104
+ Note: global LayerNorm (e.g., emb_layer_norm_after) is NOT unfrozen,
105
+ because unfreezing it would shift the output distribution of all layers at once,
106
+ leading to training instability.
107
+ """
108
+ # Design decision: always freeze all first, then selectively unfreeze, ensuring the
109
+ # set of trainable parameters is controllable and reproducible.
110
+ self.freeze()
111
+
112
+ n = max(0, min(int(num_layers), self.num_layers))
113
+ if n > 0:
114
+ start = self.num_layers - n
115
+ for layer in self.model.layers[start:]:
116
+ for param in layer.parameters():
117
+ param.requires_grad = True
deepmirt/predict.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Public prediction API for DeepMiRT.
4
+
5
+ Provides simple interfaces for miRNA-target interaction prediction:
6
+ - predict(): Python API for sequence pairs
7
+ - predict_from_csv(): Batch prediction from CSV files
8
+ - cli_main(): Command-line entry point
9
+
10
+ Model weights are automatically downloaded from Hugging Face Hub on first use.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import logging
17
+ import re
18
+ import sys
19
+ import warnings
20
+ from pathlib import Path
21
+
22
+ import numpy as np
23
+ import pandas as pd
24
+
25
logger = logging.getLogger(__name__)

# Hugging Face Hub repository that hosts the trained DeepMiRT weights.
HF_REPO_ID = "liuliu2333/deepmirt"
# Checkpoint filename on the Hub (epoch with the best validation AUROC).
HF_CKPT_FILENAME = "epoch=27-val_auroc=0.9612.ckpt"
# Model/training configuration stored alongside the checkpoint.
HF_CONFIG_FILENAME = "config.yaml"

# Accepts DNA (T) or RNA (U) alphabets, upper or lower case; sequences are
# uppercased before matching, and any T->U conversion happens downstream.
_VALID_BASES = re.compile(r"^[AUGCTaugct]+$")

# Module-level cache keyed by device string ("cpu"/"cuda") so the ~495 MB
# model is loaded at most once per device per process.
_model_cache: dict[str, tuple] = {}
37
+
38
+
39
def _get_model_files() -> tuple[str, str]:
    """Fetch the checkpoint and config from the Hugging Face Hub.

    Returns:
        ``(ckpt_path, config_path)`` local file paths. ``hf_hub_download``
        caches on disk, so repeated calls do not re-download.
    """
    from huggingface_hub import hf_hub_download

    ckpt_path, config_path = (
        hf_hub_download(repo_id=HF_REPO_ID, filename=name)
        for name in (HF_CKPT_FILENAME, HF_CONFIG_FILENAME)
    )
    return ckpt_path, config_path
46
+
47
+
48
def _get_cached_model(device: str):
    """Return ``(lit_model, alphabet, ckpt_path, config_path)`` for ``device``.

    The first call per device downloads and loads everything (slow); the
    result is memoised in the module-level ``_model_cache`` so later calls
    are instant.
    """
    cached = _model_cache.get(device)
    if cached is None:
        import fm

        from deepmirt.evaluation.predict import load_model_from_checkpoint

        ckpt_path, config_path = _get_model_files()
        logger.info("Loading DeepMiRT model (first call, will be cached)...")
        lit_model, _config = load_model_from_checkpoint(ckpt_path, config_path, device)
        _, alphabet = fm.pretrained.rna_fm_t12()
        cached = (lit_model, alphabet, ckpt_path, config_path)
        _model_cache[device] = cached
        logger.info("Model loaded and cached.")

    return cached
63
+
64
+
65
def _validate_sequences(
    mirna_seqs: list[str], target_seqs: list[str]
) -> tuple[list[str], list[str]]:
    """Normalise (strip + uppercase) both sequence lists and reject bad input.

    Raises:
        ValueError: On empty sequences or characters outside A/U/G/C/T.

    Emits non-fatal warnings for unusual sequence lengths (miRNA far from
    the typical 18-25 nt, target not exactly 40 nt).
    """
    cleaned_mirna: list[str] = []
    cleaned_target: list[str] = []

    for i, (m_raw, t_raw) in enumerate(zip(mirna_seqs, target_seqs)):
        m = str(m_raw).strip().upper()
        t = str(t_raw).strip().upper()

        # Empty checks first (both sequences), then alphabet checks — this
        # ordering decides which error wins when a pair has multiple problems.
        if not m:
            raise ValueError(f"Empty miRNA sequence at index {i}")
        if not t:
            raise ValueError(f"Empty target sequence at index {i}")

        for label, seq in (("miRNA", m), ("Target", t)):
            if not _VALID_BASES.match(seq):
                bad = set(seq) - set("AUGCT")
                raise ValueError(
                    f"{label} at index {i} contains invalid characters: {bad}. "
                    f"Only A/U/G/C/T are allowed."
                )

        cleaned_mirna.append(m)
        cleaned_target.append(t)

    # Length checks are advisory only — predictions still run, just flagged.
    if any(not (15 <= len(s) <= 30) for s in cleaned_mirna):
        warnings.warn(
            "Some miRNA sequences have unusual length (expected 18-25 nt). "
            "Results may be less reliable.",
            stacklevel=3,
        )
    if any(len(s) != 40 for s in cleaned_target):
        warnings.warn(
            "Some target sequences are not 40 nt. The model was trained on 40-nt "
            "target fragments. Results may be less reliable for other lengths.",
            stacklevel=3,
        )

    return cleaned_mirna, cleaned_target
114
+
115
+
116
def predict(
    mirna_seqs: list[str],
    target_seqs: list[str],
    device: str = "cpu",
    batch_size: int = 256,
) -> np.ndarray:
    """
    Score miRNA-target pairs; returns interaction probabilities in [0, 1].

    Model weights are pulled from the Hugging Face Hub on the first call and
    cached in memory afterwards. Inputs may use DNA (T) or RNA (U) letters;
    conversion is handled internally.

    Args:
        mirna_seqs: miRNA sequences (typically 18-25 nt).
        target_seqs: Target site sequences (40 nt recommended); must be the
            same length as ``mirna_seqs``.
        device: Inference device ("cpu" or "cuda").
        batch_size: Batch size for inference.

    Returns:
        Numpy array of shape ``(n_samples,)`` with values from 0 (no
        interaction) to 1 (strong interaction).

    Raises:
        ValueError: On mismatched list lengths or invalid sequences.

    Example:
        >>> from deepmirt import predict
        >>> probs = predict(
        ...     mirna_seqs=["UGAGGUAGUAGGUUGUAUAGUU"],
        ...     target_seqs=["ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU"],
        ... )
        >>> print(f"Interaction probability: {probs[0]:.4f}")
    """
    if len(mirna_seqs) != len(target_seqs):
        raise ValueError(
            f"mirna_seqs and target_seqs must have the same length, "
            f"got {len(mirna_seqs)} and {len(target_seqs)}"
        )
    if not mirna_seqs:
        # Nothing to score — avoid touching the (heavy) model at all.
        return np.array([])

    mirna_seqs, target_seqs = _validate_sequences(mirna_seqs, target_seqs)

    from deepmirt.evaluation.predict import predict_on_sequences

    lit_model, alphabet, ckpt_path, config_path = _get_cached_model(device)

    return predict_on_sequences(
        ckpt_path=ckpt_path,
        config_path=config_path,
        mirna_seqs=mirna_seqs,
        target_seqs=target_seqs,
        batch_size=batch_size,
        device=device,
        _lit_model=lit_model,
        _alphabet=alphabet,
    )
171
+
172
+
173
+ def predict_from_csv(
174
+ csv_path: str,
175
+ output_path: str | None = None,
176
+ device: str = "cpu",
177
+ batch_size: int = 256,
178
+ mirna_col: str = "mirna_seq",
179
+ target_col: str = "target_seq",
180
+ ) -> pd.DataFrame:
181
+ """
182
+ Batch prediction from a CSV file.
183
+
184
+ The CSV must contain columns for miRNA and target sequences.
185
+
186
+ Args:
187
+ csv_path: Path to input CSV file.
188
+ output_path: Path to save results CSV. If None, results are only returned.
189
+ device: Inference device ("cpu" or "cuda").
190
+ batch_size: Batch size for inference.
191
+ mirna_col: Column name for miRNA sequences.
192
+ target_col: Column name for target sequences.
193
+
194
+ Returns:
195
+ DataFrame with original columns plus 'probability' and 'prediction'.
196
+ """
197
+ df = pd.read_csv(csv_path)
198
+
199
+ if mirna_col not in df.columns or target_col not in df.columns:
200
+ raise ValueError(
201
+ f"CSV must contain columns '{mirna_col}' and '{target_col}'. "
202
+ f"Found columns: {list(df.columns)}"
203
+ )
204
+
205
+ mirna_seqs = df[mirna_col].astype(str).tolist()
206
+ target_seqs = df[target_col].astype(str).tolist()
207
+
208
+ probs = predict(mirna_seqs, target_seqs, device=device, batch_size=batch_size)
209
+
210
+ df["probability"] = probs
211
+ df["prediction"] = (probs >= 0.5).astype(int)
212
+
213
+ if output_path:
214
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
215
+ df.to_csv(output_path, index=False)
216
+ logger.info(f"Results saved to {output_path}")
217
+
218
+ return df
219
+
220
+
221
def scan_targets(
    mirna_fasta: str | dict[str, str],
    target_fasta: str,
    output_prefix: str | None = None,
    device: str = "cpu",
    batch_size: int = 512,
    prob_threshold: float = 0.5,
    scan_mode: str = "hybrid",
    stride: int = 20,
    top_k: int | None = None,
) -> list:
    """
    Scan target sequences for miRNA binding sites genome-wide.

    Candidate positions are proposed by seed matching and/or sliding windows
    (depending on ``scan_mode``) and each window is scored by the DeepMiRT
    model.

    Args:
        mirna_fasta: miRNA FASTA path, or a ``{id: sequence}`` mapping.
        target_fasta: Target FASTA path (e.g. 3'UTRs or transcripts).
        output_prefix: When given, writes {prefix}_details.txt,
            {prefix}_hits.tsv, and {prefix}_summary.tsv.
        device: Inference device ("cpu" or "cuda").
        batch_size: Batch size for GPU inference.
        prob_threshold: Minimum probability to report a hit (default 0.5).
        scan_mode: "seed" (fastest), "hybrid" (default), or "exhaustive"
            (slowest, stride-1).
        stride: Window stride for hybrid/exhaustive modes (default 20).
        top_k: If set, keep only the top-K hits per miRNA-target pair.

    Returns:
        List of TargetScanResult objects, one per miRNA-target pair with
        hits; each holds ScanHit entries with position, probability, seed
        type, and the 40nt window sequence.

    Example:
        >>> from deepmirt import scan_targets
        >>> results = scan_targets(
        ...     mirna_fasta={"let-7": "UGAGGUAGUAGGUUGUAUAGUU"},
        ...     target_fasta="3utrs.fa",
        ...     output_prefix="results/scan",
        ...     device="cuda",
        ... )
        >>> for r in results:
        ...     for hit in r.hits:
        ...         print(f"{r.target_id} pos={hit.position} prob={hit.probability:.3f}")
    """
    from deepmirt.scanning.scanner import TargetScanner

    scanner_opts = dict(
        device=device,
        batch_size=batch_size,
        prob_threshold=prob_threshold,
        scan_mode=scan_mode,
        stride=stride,
        top_k=top_k,
    )
    return TargetScanner(**scanner_opts).scan(mirna_fasta, target_fasta, output_prefix)
279
+
280
+
281
def _build_parser() -> argparse.ArgumentParser:
    """Construct the `deepmirt-predict` argument parser (single/batch/scan)."""
    parser = argparse.ArgumentParser(
        prog="deepmirt-predict",
        description="DeepMiRT: Predict miRNA-target interactions",
    )
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Single prediction
    single = subparsers.add_parser("single", help="Predict a single miRNA-target pair")
    single.add_argument("--mirna", required=True, help="miRNA sequence")
    single.add_argument("--target", required=True, help="Target sequence (40 nt)")
    single.add_argument("--device", default="cpu", help="Device (cpu or cuda)")

    # Batch prediction
    batch = subparsers.add_parser("batch", help="Batch prediction from CSV")
    batch.add_argument("--input", required=True, help="Input CSV path")
    batch.add_argument("--output", required=True, help="Output CSV path")
    batch.add_argument("--device", default="cpu", help="Device (cpu or cuda)")
    batch.add_argument("--batch-size", type=int, default=256, help="Batch size")
    batch.add_argument("--mirna-col", default="mirna_seq", help="miRNA column name")
    batch.add_argument("--target-col", default="target_seq", help="Target column name")

    # Genome-wide scanning
    scan = subparsers.add_parser(
        "scan", help="Scan target sequences for miRNA binding sites"
    )
    scan_input = scan.add_mutually_exclusive_group(required=True)
    scan_input.add_argument("--mirna-fasta", help="miRNA FASTA file")
    scan_input.add_argument("--mirna", help="Single miRNA sequence (use with --mirna-id)")
    scan.add_argument("--mirna-id", default="query_mirna", help="miRNA ID (with --mirna)")
    scan.add_argument("--target-fasta", required=True, help="Target FASTA file")
    scan.add_argument("--output", required=True, help="Output prefix")
    scan.add_argument("--device", default="cpu", help="Device (cpu or cuda)")
    scan.add_argument("--batch-size", type=int, default=512, help="Batch size")
    scan.add_argument("--threshold", type=float, default=0.5, help="Probability threshold")
    scan.add_argument(
        "--scan-mode", default="hybrid", choices=["seed", "hybrid", "exhaustive"],
        help="Scanning mode (default: hybrid)",
    )
    scan.add_argument("--stride", type=int, default=20, help="Window stride (default: 20)")
    scan.add_argument("--top-k", type=int, default=None, help="Keep top-K hits per target")

    return parser


def _run_single(args: argparse.Namespace) -> None:
    """Handle `single`: score one miRNA-target pair and print the verdict."""
    probs = predict([args.mirna], [args.target], device=args.device)
    prob = probs[0]
    label = "INTERACTION" if prob >= 0.5 else "NO INTERACTION"
    print(f"Probability: {prob:.4f}")
    print(f"Prediction: {label}")


def _run_batch(args: argparse.Namespace) -> None:
    """Handle `batch`: run CSV prediction and report where results went."""
    df = predict_from_csv(
        csv_path=args.input,
        output_path=args.output,
        device=args.device,
        batch_size=args.batch_size,
        mirna_col=args.mirna_col,
        target_col=args.target_col,
    )
    print(f"Processed {len(df)} samples. Results saved to {args.output}")


def _run_scan(args: argparse.Namespace) -> None:
    """Handle `scan`: run genome-wide scanning and summarise the hit counts."""
    if args.mirna:
        # Inline sequence mode: wrap the single sequence as a one-entry dict.
        mirna_input: str | dict[str, str] = {args.mirna_id: args.mirna}
    else:
        mirna_input = args.mirna_fasta

    results = scan_targets(
        mirna_fasta=mirna_input,
        target_fasta=args.target_fasta,
        output_prefix=args.output,
        device=args.device,
        batch_size=args.batch_size,
        prob_threshold=args.threshold,
        scan_mode=args.scan_mode,
        stride=args.stride,
        top_k=args.top_k,
    )
    total_hits = sum(len(r.hits) for r in results)
    print(
        f"Scan complete: {len(results)} miRNA-target pairs, "
        f"{total_hits} hits above threshold {args.threshold}"
    )
    print(f"Results: {args.output}_details.txt, _hits.tsv, _summary.tsv")


def cli_main(argv: list[str] | None = None) -> None:
    """Command-line entry point for deepmirt-predict.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``. Exposed
            as a parameter (backward-compatible) so the CLI can be driven
            programmatically and unit-tested.

    Raises:
        SystemExit: With code 1 when no subcommand is given (after printing
            help); argparse exits with code 2 on invalid arguments.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    if args.command == "single":
        _run_single(args)
    elif args.command == "batch":
        _run_batch(args)
    elif args.command == "scan":
        _run_scan(args)
    else:
        parser.print_help()
        sys.exit(1)
370
+
371
+
372
if __name__ == "__main__":
    # Allow direct execution (python deepmirt/predict.py) in addition to the
    # installed deepmirt-predict console-script entry point.
    cli_main()
deepmirt/training/__init__.py ADDED
File without changes
deepmirt/training/lightning_module.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # pyright: basic, reportMissingImports=false
3
+ """
4
+ PyTorch Lightning training module for miRNA-target prediction.
5
+
6
+ [Lightning Training Loop Overview -- Full Lifecycle of One Epoch]
7
+
8
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
9
+ β”‚ Lifecycle of One Epoch β”‚
10
+ β”‚ β”‚
11
+ β”‚ on_train_epoch_start() β”‚
12
+ β”‚ β”‚ β”‚
13
+ β”‚ v β”‚
14
+ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
15
+ β”‚ β”‚ for batch in train_dataloader: β”‚ β”‚
16
+ β”‚ β”‚ training_step(batch) β”‚ ← forward + loss β”‚
17
+ β”‚ β”‚ backward() [automatic] β”‚ ← backpropagation β”‚
18
+ β”‚ β”‚ optimizer.step() [automatic] β”‚ ← update params β”‚
19
+ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
20
+ β”‚ β”‚ β”‚
21
+ β”‚ v β”‚
22
+ β”‚ on_train_epoch_end() β”‚
23
+ β”‚ β”‚ β”‚
24
+ β”‚ v β”‚
25
+ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
26
+ β”‚ β”‚ for batch in val_dataloader: β”‚ β”‚
27
+ β”‚ β”‚ validation_step(batch) β”‚ ← forward only, no β”‚
28
+ β”‚ β”‚ β”‚ param updates β”‚
29
+ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
30
+ β”‚ β”‚ β”‚
31
+ β”‚ v β”‚
32
+ β”‚ on_validation_epoch_end() β”‚
33
+ β”‚ β”‚ β”‚
34
+ β”‚ v β”‚
35
+ β”‚ lr_scheduler.step() [automatic] β”‚
36
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
37
+
38
+ Things Lightning handles automatically (no manual code needed):
39
+ - loss.backward()
40
+ - optimizer.zero_grad()
41
+ - optimizer.step()
42
+ - Switching to model.eval() and torch.no_grad() during validation
43
+ - Gradient accumulation (if accumulate_grad_batches is configured)
44
+ - Multi-GPU distributed synchronization (if using DDP)
45
+
46
+ You only need to focus on:
47
+ - training_step(): return the loss
48
+ - validation_step(): compute validation metrics
49
+ - configure_optimizers(): define the optimizer and learning rate scheduler
50
+
51
+ [Key Design Decisions]
52
+
53
+ 1. BCEWithLogitsLoss vs BCELoss:
54
+ - BCEWithLogitsLoss = Sigmoid + BCELoss, using the log-sum-exp trick internally
55
+ - Numerical stability: directly computing log(sigmoid(x)) can produce log(0) at
56
+ extreme values. BCEWithLogitsLoss uses the equivalent formula
57
+ max(x,0) - x*y + log(1+exp(-|x|)) to avoid overflow
58
+ - Therefore the model outputs raw logits (no sigmoid); the loss function handles it
59
+
60
+ 2. Differential Learning Rate:
61
+ - Backbone (RNA-FM): base_lr x 0.01 -- pretrained weights encode rich RNA knowledge;
62
+ a large learning rate would cause catastrophic forgetting of this knowledge
63
+ - Cross-attention layers: base_lr x 0.1 -- new module but needs stable attention
64
+ pattern learning
65
+ - Classifier head: base_lr x 1.0 -- learning from scratch, needs the highest
66
+ learning rate for fast convergence
67
+
68
+ 3. Evaluation Metric Selection:
69
+ - AUROC (Area Under ROC Curve): measures the model's ranking ability, i.e., the
70
+ probability of ranking a positive sample above a negative one. Threshold-independent.
71
+ - AUPRC (Average Precision / PR-AUC): measures the precision-recall tradeoff;
72
+ more sensitive than AUROC on class-imbalanced data (biological data often has
73
+ positive:negative ratios of 1:10+)
74
+ - Accuracy: intuitive but can be misleading on imbalanced data (predicting all
75
+ negatives still yields 90% accuracy)
76
+ - F1: harmonic mean of precision and recall, balancing both
77
+
78
+ 4. Logging Strategy -- on_step=False, on_epoch=True:
79
+ - Training loss: fluctuates heavily per step; step-level logging aids debugging
80
+ - Evaluation metrics: require full epoch data to be statistically meaningful,
81
+ hence on_epoch=True
82
+ - prog_bar=True: displays key metrics in the training progress bar for real-time
83
+ monitoring
84
+ """
85
+
86
+ from __future__ import annotations
87
+
88
+ import pytorch_lightning as pl
89
+ import torch
90
+ import torchmetrics
91
+ from torch import nn
92
+ from torch.optim.lr_scheduler import CosineAnnealingLR
93
+
94
+ from deepmirt.model.mirna_target_model import MiRNATargetModel
95
+
96
+
97
+ class MiRNATargetLitModule(pl.LightningModule):
98
+ """
99
+ Lightning training module for miRNA-target binary classification prediction.
100
+
101
+ Responsibilities:
102
+ - Wraps MiRNATargetModel, managing forward pass / loss / metric computation
103
+ - Configures optimizer with differential learning rates and LR scheduler
104
+ - Provides training_step / validation_step / test_step
105
+
106
+ Args:
107
+ config: Nested dictionary with the following structure:
108
+ {
109
+ 'model': {
110
+ 'freeze_backbone': bool,
111
+ 'cross_attn_heads': int,
112
+ 'cross_attn_layers': int,
113
+ 'classifier_hidden': list[int],
114
+ 'dropout': float,
115
+ },
116
+ 'training': {
117
+ 'lr': float, # base learning rate (used by classifier head)
118
+ 'weight_decay': float, # L2 regularization coefficient
119
+ 'scheduler': str, # 'cosine' or 'onecycle'
120
+ 'max_epochs': int, # total training epochs (needed by scheduler)
121
+ }
122
+ }
123
+ """
124
+
125
+ def __init__(self, config: dict) -> None:
126
+ super().__init__()
127
+
128
+ # Save hyperparameters to the checkpoint for restoring the full config on reload
129
+ # Design decision: save_hyperparameters ensures reproducibility -- checkpoint carries the full config
130
+ self.save_hyperparameters(config)
131
+ self.config = config
132
+
133
+ # ── Extract model parameters from config and instantiate ──
134
+ model_cfg = config["model"]
135
+ self.model = MiRNATargetModel(
136
+ freeze_backbone=model_cfg.get("freeze_backbone", True),
137
+ cross_attn_heads=model_cfg.get("cross_attn_heads", 8),
138
+ cross_attn_layers=model_cfg.get("cross_attn_layers", 2),
139
+ classifier_hidden=model_cfg.get("classifier_hidden", [256, 64]),
140
+ dropout=model_cfg.get("dropout", 0.3),
141
+ )
142
+
143
+ # ── Loss function ──
144
+ # Design decision: BCEWithLogitsLoss is more numerically stable than sigmoid + BCELoss.
145
+ # Internal formula: loss = max(logit, 0) - logit * label + log(1 + exp(-|logit|))
146
+ # This formula avoids numerical overflow from log(sigmoid(x)) at extreme values of x.
147
+ self.loss_fn = nn.BCEWithLogitsLoss()
148
+
149
+ # ── Training metrics ──
150
+ # torchmetrics automatically handles metric aggregation in distributed settings (DDP sync)
151
+ self.train_auroc = torchmetrics.AUROC(task="binary")
152
+
153
+ # ── Validation metrics ──
154
+ self.val_auroc = torchmetrics.AUROC(task="binary")
155
+ self.val_auprc = torchmetrics.AveragePrecision(task="binary")
156
+ self.val_acc = torchmetrics.Accuracy(task="binary")
157
+ self.val_f1 = torchmetrics.F1Score(task="binary")
158
+
159
+ # ── Test metrics (same as validation, but separate instances to avoid state contamination) ──
160
+ self.test_auroc = torchmetrics.AUROC(task="binary")
161
+ self.test_auprc = torchmetrics.AveragePrecision(task="binary")
162
+ self.test_acc = torchmetrics.Accuracy(task="binary")
163
+ self.test_f1 = torchmetrics.F1Score(task="binary")
164
+
165
+ # ─────────────────────────────────────────────────────────────
166
+ # Training step
167
+ # ─────────────────────────────────────────────────────────────
168
+
169
+ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
170
+ """
171
+ Single training step: forward pass -> compute loss -> update metrics.
172
+
173
+ Lightning automatically calls backward() and optimizer.step() on the returned loss.
174
+ There is no need to manually call loss.backward() or optimizer.zero_grad().
175
+
176
+ Args:
177
+ batch: Dictionary output from the DataModule collate_fn, containing:
178
+ - mirna_tokens: (B, max_mirna_len)
179
+ - target_tokens: (B, 42)
180
+ - labels: (B,) float32
181
+ - attention_mask_mirna: (B, max_mirna_len)
182
+ - attention_mask_target: (B, 42)
183
+ batch_idx: Index of the current batch (automatically passed by Lightning)
184
+
185
+ Returns:
186
+ loss: Scalar tensor; Lightning automatically backpropagates through it
187
+ """
188
+ # Step 1: Extract inputs from the batch dictionary
189
+ mirna_tokens = batch["mirna_tokens"]
190
+ target_tokens = batch["target_tokens"]
191
+ labels = batch["labels"]
192
+ attention_mask_mirna = batch["attention_mask_mirna"]
193
+ attention_mask_target = batch["attention_mask_target"]
194
+
195
+ # Step 2: Forward pass -> logits shape (B, 1)
196
+ logits = self.model(
197
+ mirna_tokens, target_tokens, attention_mask_mirna, attention_mask_target
198
+ )
199
+
200
+ # Step 3: Compute loss
201
+ # squeeze(-1) reduces logits from (B, 1) to (B,), aligning with labels (B,)
202
+ loss = self.loss_fn(logits.squeeze(-1), labels)
203
+
204
+ # Step 4: Compute prediction probabilities and update metrics
205
+ # Note: sigmoid is only used for metric computation, not for the loss (BCEWithLogitsLoss includes sigmoid internally)
206
+ probs = torch.sigmoid(logits.squeeze(-1))
207
+ self.train_auroc(probs, labels.long())
208
+
209
+ # Step 5: Logging
210
+ # Design decision: train_loss uses on_step=True to monitor convergence trends,
211
+ # train_auroc uses on_epoch=True because per-step AUROC has little statistical significance.
212
+ self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
213
+ self.log(
214
+ "train_auroc",
215
+ self.train_auroc,
216
+ on_step=False,
217
+ on_epoch=True,
218
+ prog_bar=True,
219
+ )
220
+
221
+ return loss
222
+
223
+ # ─────────────────────────────────────────────────────────────
224
+ # Validation step
225
+ # ─────────────────────────────────────────────────────────────
226
+
227
+ def validation_step(self, batch: dict, batch_idx: int) -> None:
228
+ """
229
+ Single validation step: forward pass -> compute loss and full metric suite.
230
+
231
+ Lightning automatically handles the following during validation:
232
+ - Switches to model.eval() mode (disables Dropout, uses running mean for BatchNorm)
233
+ - Wraps in torch.no_grad(), skipping gradient computation to save memory
234
+
235
+ Args:
236
+ batch: Same as training_step
237
+ batch_idx: Index of the current batch
238
+ """
239
+ mirna_tokens = batch["mirna_tokens"]
240
+ target_tokens = batch["target_tokens"]
241
+ labels = batch["labels"]
242
+ attention_mask_mirna = batch["attention_mask_mirna"]
243
+ attention_mask_target = batch["attention_mask_target"]
244
+
245
+ logits = self.model(
246
+ mirna_tokens, target_tokens, attention_mask_mirna, attention_mask_target
247
+ )
248
+
249
+ loss = self.loss_fn(logits.squeeze(-1), labels)
250
+ probs = torch.sigmoid(logits.squeeze(-1))
251
+
252
+ # Update all validation metrics
253
+ self.val_auroc(probs, labels.long())
254
+ self.val_auprc(probs, labels.long())
255
+ self.val_acc(probs, labels.long())
256
+ self.val_f1(probs, labels.long())
257
+
258
+ # Design decision: all validation metrics use on_epoch=True, as they need full data to be statistically meaningful
259
+ # sync_dist=True automatically aggregates metrics across GPUs in multi-GPU settings
260
+ self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
261
+ self.log("val_auroc", self.val_auroc, on_epoch=True, prog_bar=True)
262
+ self.log("val_auprc", self.val_auprc, on_epoch=True)
263
+ self.log("val_acc", self.val_acc, on_epoch=True)
264
+ self.log("val_f1", self.val_f1, on_epoch=True)
265
+
266
+ # ─────────────────────────────────────────────────────────────
267
+ # Test step
268
+ # ─────────────────────────────────────────────────────────────
269
+
270
+ def test_step(self, batch: dict, batch_idx: int) -> None:
271
+ """
272
+ Single test step: same logic as validation_step, using separate test metric instances.
273
+
274
+ Test metrics are instantiated separately from validation metrics to avoid state
275
+ contamination. For example, val_auroc resets at the end of each validation epoch,
276
+ while test_auroc is only used when trainer.test() is called.
277
+ """
278
+ mirna_tokens = batch["mirna_tokens"]
279
+ target_tokens = batch["target_tokens"]
280
+ labels = batch["labels"]
281
+ attention_mask_mirna = batch["attention_mask_mirna"]
282
+ attention_mask_target = batch["attention_mask_target"]
283
+
284
+ logits = self.model(
285
+ mirna_tokens, target_tokens, attention_mask_mirna, attention_mask_target
286
+ )
287
+
288
+ loss = self.loss_fn(logits.squeeze(-1), labels)
289
+ probs = torch.sigmoid(logits.squeeze(-1))
290
+
291
+ # Update test metrics
292
+ self.test_auroc(probs, labels.long())
293
+ self.test_auprc(probs, labels.long())
294
+ self.test_acc(probs, labels.long())
295
+ self.test_f1(probs, labels.long())
296
+
297
+ self.log("test_loss", loss, on_epoch=True, sync_dist=True)
298
+ self.log("test_auroc", self.test_auroc, on_epoch=True)
299
+ self.log("test_auprc", self.test_auprc, on_epoch=True)
300
+ self.log("test_acc", self.test_acc, on_epoch=True)
301
+ self.log("test_f1", self.test_f1, on_epoch=True)
302
+
303
+ # ─────────────────────────────────────────────────────────────
304
+ # Optimizer and learning rate scheduling
305
+ # ─────────────────────────────────────────────────────────────
306
+
307
+ def configure_optimizers(self) -> dict:
308
+ """
309
+ Configure AdamW optimizer with differential learning rates and cosine annealing scheduler.
310
+
311
+ [Differential Learning Rates -- Why use different learning rates for different modules?]
312
+
313
+ Module Learning Rate Reason
314
+ ───────────── ───────────── ──────────────────────────────────
315
+ RNA-FM backbone base_lrΓ—0.01 Pretrained weights contain rich RNA structure/sequence
316
+ knowledge; a large LR would destroy this knowledge
317
+ (catastrophic forgetting)
318
+ Cross-attention base_lrΓ—0.1 Newly initialized module, but needs to stably learn
319
+ miRNA-target attention patterns
320
+ Classifier head base_lrΓ—1.0 Learns the binary classification decision boundary
321
+ from scratch; needs the highest LR for fast convergence
322
+
323
+ Design decision: The LR ratios [0.01, 0.1, 1.0] follow common transfer learning practice;
324
+ the paper "Universal Language Model Fine-tuning" (Howard & Ruder, 2018)
325
+ calls this "discriminative fine-tuning".
326
+
327
+ [CosineAnnealingLR Scheduler]
328
+ The learning rate decays from its initial value toward 0 following a cosine curve:
329
+ lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T_max))
330
+ Advantage: fast learning early on, fine-grained adjustment later, avoiding instability
331
+ from sudden LR drops.
332
+
333
+ Returns:
334
+ Dictionary containing the optimizer and lr_scheduler
335
+ """
336
+ training_cfg = self.config["training"]
337
+ base_lr = training_cfg["lr"]
338
+ weight_decay = training_cfg.get("weight_decay", 1e-5)
339
+ scheduler_type = training_cfg.get("scheduler", "cosine")
340
+ max_epochs = training_cfg.get("max_epochs", 30)
341
+
342
+ # Design decision: 3 parameter groups correspond to the model's 3 semantic modules;
343
+ # learning rates decrease from downstream to upstream (farther from the task = smaller LR).
344
+ param_groups = [
345
+ {
346
+ "params": list(self.model.encoder.parameters()),
347
+ "lr": base_lr * 0.01,
348
+ "name": "backbone",
349
+ },
350
+ {
351
+ "params": list(self.model.cross_attention.parameters()),
352
+ "lr": base_lr * 0.1,
353
+ "name": "cross_attention",
354
+ },
355
+ {
356
+ "params": list(self.model.classifier.parameters()),
357
+ "lr": base_lr,
358
+ "name": "classifier",
359
+ },
360
+ ]
361
+
362
+ optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
363
+
364
+ # Design decision: CosineAnnealingLR is a safe default choice --
365
+ # it does not require knowing total steps (unlike OneCycleLR), and provides smooth decay.
366
+ if scheduler_type == "cosine":
367
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
368
+ elif scheduler_type == "onecycle":
369
+ # OneCycleLR requires total_steps = steps_per_epoch * max_epochs,
370
+ # but at the configure_optimizers stage the DataLoader has not been created yet,
371
+ # so steps_per_epoch is unavailable. Therefore, fall back to CosineAnnealingLR.
372
+ # If OneCycleLR is needed, it should be configured in train.py via the Trainer's
373
+ # estimated_stepping_batches.
374
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
375
+ else:
376
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
377
+
378
+ return {
379
+ "optimizer": optimizer,
380
+ "lr_scheduler": {
381
+ "scheduler": scheduler,
382
+ # Design decision: interval='epoch' adjusts the learning rate once per epoch,
383
+ # which is more stable than 'step' (adjusting after every batch), suitable for small to medium datasets.
384
+ "interval": "epoch",
385
+ },
386
+ }
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ rna-fm
3
+ pytorch-lightning>=2.0
4
+ torchmetrics
5
+ pyyaml
6
+ scikit-learn
7
+ numpy
8
+ pandas
9
+ huggingface-hub