Spaces:
Sleeping
Sleeping
Commit Β·
fc481db
1
Parent(s): 9201562
Deploy DeepMiRT Gradio demo with model code
Browse files- README.md +6 -7
- app.py +360 -0
- deepmirt/__init__.py +5 -0
- deepmirt/data_module/__init__.py +0 -0
- deepmirt/data_module/datamodule.py +239 -0
- deepmirt/data_module/dataset.py +227 -0
- deepmirt/data_module/preprocessing.py +251 -0
- deepmirt/evaluation/__init__.py +1 -0
- deepmirt/evaluation/predict.py +297 -0
- deepmirt/model/__init__.py +0 -0
- deepmirt/model/classifier.py +77 -0
- deepmirt/model/cross_attention.py +115 -0
- deepmirt/model/mirna_target_model.py +127 -0
- deepmirt/model/rnafm_encoder.py +117 -0
- deepmirt/predict.py +373 -0
- deepmirt/training/__init__.py +0 -0
- deepmirt/training/lightning_module.py +386 -0
- requirements.txt +9 -0
README.md
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 11 |
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: DeepMiRT
|
| 3 |
+
emoji: π§¬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.23.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
+
short_description: miRNA target prediction with RNA foundation models
|
| 12 |
---
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
DeepMiRT Web Demo β Gradio interface for miRNA-target interaction prediction.
|
| 4 |
+
|
| 5 |
+
Run locally:
|
| 6 |
+
python app.py
|
| 7 |
+
|
| 8 |
+
Deploy on Hugging Face Spaces:
|
| 9 |
+
Set sdk: gradio in the Space README.md metadata.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import re
|
| 16 |
+
import tempfile
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import gradio as gr
|
| 20 |
+
import numpy as np
|
| 21 |
+
import pandas as pd
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Global model (loaded once at startup)
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
_model = None
|
| 31 |
+
_alphabet = None
|
| 32 |
+
_config = None
|
| 33 |
+
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _load_model():
|
| 37 |
+
"""Load model from Hugging Face Hub (cached after first download)."""
|
| 38 |
+
global _model, _alphabet, _config
|
| 39 |
+
|
| 40 |
+
if _model is not None:
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
import fm
|
| 44 |
+
import torch
|
| 45 |
+
from huggingface_hub import hf_hub_download
|
| 46 |
+
|
| 47 |
+
from deepmirt.evaluation.predict import load_model_from_checkpoint
|
| 48 |
+
|
| 49 |
+
repo_id = "liuliu2333/deepmirt"
|
| 50 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename="epoch=27-val_auroc=0.9612.ckpt")
|
| 51 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
|
| 52 |
+
|
| 53 |
+
logger.info("Loading model...")
|
| 54 |
+
_model, _config = load_model_from_checkpoint(ckpt_path, config_path, device=_device)
|
| 55 |
+
_, _alphabet = fm.pretrained.rna_fm_t12()
|
| 56 |
+
logger.info("Model loaded successfully.")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
# Validation helpers
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
_VALID_BASES = set("AUGC")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _validate_seq(seq: str, name: str, min_len: int = 1, max_len: int = 200) -> str:
|
| 66 |
+
"""Validate and clean an RNA/DNA sequence."""
|
| 67 |
+
seq = seq.strip().upper().replace("T", "U")
|
| 68 |
+
if not seq:
|
| 69 |
+
raise gr.Error(f"{name} sequence is empty.")
|
| 70 |
+
if len(seq) < min_len or len(seq) > max_len:
|
| 71 |
+
raise gr.Error(f"{name} must be {min_len}-{max_len} nt, got {len(seq)} nt.")
|
| 72 |
+
invalid = set(seq) - _VALID_BASES
|
| 73 |
+
if invalid:
|
| 74 |
+
raise gr.Error(f"{name} contains invalid characters: {invalid}. Only A/U/G/C/T allowed.")
|
| 75 |
+
return seq
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Prediction logic
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
def _predict_pair(mirna_seq: str, target_seq: str) -> np.ndarray:
|
| 82 |
+
"""Run model inference on a single pair."""
|
| 83 |
+
import torch
|
| 84 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 85 |
+
|
| 86 |
+
_load_model()
|
| 87 |
+
|
| 88 |
+
batch_converter = _alphabet.get_batch_converter()
|
| 89 |
+
padding_idx = _alphabet.padding_idx
|
| 90 |
+
|
| 91 |
+
_, _, m_tok = batch_converter([("m", mirna_seq)])
|
| 92 |
+
_, _, t_tok = batch_converter([("t", target_seq)])
|
| 93 |
+
|
| 94 |
+
mirna_padded = pad_sequence([m_tok[0]], batch_first=True, padding_value=padding_idx)
|
| 95 |
+
target_stacked = torch.stack([t_tok[0]])
|
| 96 |
+
|
| 97 |
+
attn_mask_mirna = (mirna_padded != padding_idx).long().to(_device)
|
| 98 |
+
attn_mask_target = torch.ones_like(target_stacked, dtype=torch.long).to(_device)
|
| 99 |
+
mirna_padded = mirna_padded.to(_device)
|
| 100 |
+
target_stacked = target_stacked.to(_device)
|
| 101 |
+
|
| 102 |
+
with torch.no_grad():
|
| 103 |
+
logits = _model.model(mirna_padded, target_stacked, attn_mask_mirna, attn_mask_target)
|
| 104 |
+
prob = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
|
| 105 |
+
return prob
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def predict_single(mirna_seq: str, target_seq: str):
|
| 109 |
+
"""Gradio callback for single prediction."""
|
| 110 |
+
mirna_rna = _validate_seq(mirna_seq, "miRNA", min_len=15, max_len=30)
|
| 111 |
+
target_rna = _validate_seq(target_seq, "Target", min_len=20, max_len=50)
|
| 112 |
+
|
| 113 |
+
prob = _predict_pair(mirna_rna, target_rna)
|
| 114 |
+
p = float(prob[0])
|
| 115 |
+
label = "INTERACTION" if p >= 0.5 else "NO INTERACTION"
|
| 116 |
+
color = "#2ecc71" if p >= 0.5 else "#e74c3c"
|
| 117 |
+
details = {
|
| 118 |
+
"probability": round(p, 6),
|
| 119 |
+
"prediction": label,
|
| 120 |
+
"threshold": 0.5,
|
| 121 |
+
"mirna_length": len(mirna_rna),
|
| 122 |
+
"target_length": len(target_rna),
|
| 123 |
+
}
|
| 124 |
+
return (
|
| 125 |
+
f"<div style='text-align:center;padding:20px;'>"
|
| 126 |
+
f"<span style='font-size:48px;font-weight:bold;color:{color};'>{p:.4f}</span><br>"
|
| 127 |
+
f"<span style='font-size:20px;color:{color};'>{label}</span></div>"
|
| 128 |
+
), details
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def predict_batch(file):
|
| 132 |
+
"""Gradio callback for batch prediction."""
|
| 133 |
+
if file is None:
|
| 134 |
+
raise gr.Error("Please upload a CSV file.")
|
| 135 |
+
|
| 136 |
+
_load_model()
|
| 137 |
+
|
| 138 |
+
df = pd.read_csv(file.name)
|
| 139 |
+
|
| 140 |
+
mirna_col = None
|
| 141 |
+
target_col = None
|
| 142 |
+
for col in df.columns:
|
| 143 |
+
cl = col.lower().strip()
|
| 144 |
+
if "mirna" in cl:
|
| 145 |
+
mirna_col = col
|
| 146 |
+
elif "target" in cl:
|
| 147 |
+
target_col = col
|
| 148 |
+
|
| 149 |
+
if mirna_col is None or target_col is None:
|
| 150 |
+
raise gr.Error(
|
| 151 |
+
"CSV must contain a column with 'mirna' and a column with 'target' in the name. "
|
| 152 |
+
f"Found columns: {list(df.columns)}"
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
mirna_seqs = df[mirna_col].astype(str).tolist()
|
| 156 |
+
target_seqs = df[target_col].astype(str).tolist()
|
| 157 |
+
|
| 158 |
+
# Validate and convert
|
| 159 |
+
cleaned_mirna = []
|
| 160 |
+
cleaned_target = []
|
| 161 |
+
for i, (m, t) in enumerate(zip(mirna_seqs, target_seqs)):
|
| 162 |
+
m = m.strip().upper().replace("T", "U")
|
| 163 |
+
t = t.strip().upper().replace("T", "U")
|
| 164 |
+
invalid_m = set(m) - _VALID_BASES
|
| 165 |
+
invalid_t = set(t) - _VALID_BASES
|
| 166 |
+
if invalid_m or invalid_t:
|
| 167 |
+
raise gr.Error(f"Row {i}: invalid characters in sequences.")
|
| 168 |
+
cleaned_mirna.append(m)
|
| 169 |
+
cleaned_target.append(t)
|
| 170 |
+
|
| 171 |
+
# Batch inference
|
| 172 |
+
import torch
|
| 173 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 174 |
+
|
| 175 |
+
batch_converter = _alphabet.get_batch_converter()
|
| 176 |
+
padding_idx = _alphabet.padding_idx
|
| 177 |
+
all_probs = []
|
| 178 |
+
batch_size = 128
|
| 179 |
+
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
for start in range(0, len(cleaned_mirna), batch_size):
|
| 182 |
+
batch_m = cleaned_mirna[start : start + batch_size]
|
| 183 |
+
batch_t = cleaned_target[start : start + batch_size]
|
| 184 |
+
|
| 185 |
+
m_toks = []
|
| 186 |
+
t_toks = []
|
| 187 |
+
for ms, ts in zip(batch_m, batch_t):
|
| 188 |
+
_, _, mt = batch_converter([("m", ms)])
|
| 189 |
+
_, _, tt = batch_converter([("t", ts)])
|
| 190 |
+
m_toks.append(mt[0])
|
| 191 |
+
t_toks.append(tt[0])
|
| 192 |
+
|
| 193 |
+
mirna_padded = pad_sequence(m_toks, batch_first=True, padding_value=padding_idx)
|
| 194 |
+
target_stacked = torch.stack(t_toks)
|
| 195 |
+
attn_mask_mirna = (mirna_padded != padding_idx).long().to(_device)
|
| 196 |
+
attn_mask_target = torch.ones_like(target_stacked, dtype=torch.long).to(_device)
|
| 197 |
+
|
| 198 |
+
logits = _model.model(
|
| 199 |
+
mirna_padded.to(_device),
|
| 200 |
+
target_stacked.to(_device),
|
| 201 |
+
attn_mask_mirna,
|
| 202 |
+
attn_mask_target,
|
| 203 |
+
)
|
| 204 |
+
probs = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
|
| 205 |
+
all_probs.append(probs)
|
| 206 |
+
|
| 207 |
+
all_probs = np.concatenate(all_probs)
|
| 208 |
+
df["probability"] = all_probs
|
| 209 |
+
df["prediction"] = (all_probs >= 0.5).astype(int)
|
| 210 |
+
|
| 211 |
+
# Save to temp file for download
|
| 212 |
+
out_path = Path(tempfile.mkdtemp()) / "deepmirt_predictions.csv"
|
| 213 |
+
df.to_csv(str(out_path), index=False)
|
| 214 |
+
return str(out_path), df.head(20)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
# Examples
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
EXAMPLES = [
|
| 221 |
+
# [miRNA, target_40nt] - real miRNA-target pairs
|
| 222 |
+
["UGAGGUAGUAGGUUGUAUAGUU", "ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU"], # let-7a / lin-41
|
| 223 |
+
["UAAAGUGCUUAUAGUGCAGGUAG", "GCAGCAUUGUACAGGGCUAUCAGAAACUAUUGACACUAAAA"], # miR-20a / E2F1
|
| 224 |
+
["UAGCAGCACGUAAAUAUUGGCG", "GCAAUGUUUUCCACAGUGCUUACACAGAAAUAGCAACUUUA"], # miR-16 / BCL2
|
| 225 |
+
["CAUCAAAGUGGAGGCCCUCUCU", "AAUGCUUCUAAAUUGAAUCCAAACUGCAGUUUAUUAGUGGU"], # miR-198 (negative)
|
| 226 |
+
["UGGAAUGUAAAGAAGUAUGUAU", "UCGAAUCCAUGCAAAACAGCUUGAUUUGUUAGUACACGAAU"], # miR-1 / HAND2
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
# Gradio UI
|
| 232 |
+
# ---------------------------------------------------------------------------
|
| 233 |
+
def build_demo():
|
| 234 |
+
with gr.Blocks(
|
| 235 |
+
title="DeepMiRT: miRNA Target Prediction",
|
| 236 |
+
theme=gr.themes.Soft(),
|
| 237 |
+
) as demo:
|
| 238 |
+
gr.Markdown(
|
| 239 |
+
"""
|
| 240 |
+
# DeepMiRT: miRNA Target Prediction with RNA Foundation Models
|
| 241 |
+
|
| 242 |
+
Predict miRNA-target interactions using RNA-FM embeddings and cross-attention.
|
| 243 |
+
Ranked **#1** on eCLIP benchmarks (AUROC 0.75) and achieves **AUROC 0.96** on our comprehensive test set.
|
| 244 |
+
|
| 245 |
+
**Paper:** *coming soon* | **GitHub:** [DeepMiRT](https://github.com/zichengll/DeepMiRT) | **Model:** [Hugging Face](https://huggingface.co/liuliu2333/deepmirt)
|
| 246 |
+
"""
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
with gr.Tab("Single Prediction"):
|
| 250 |
+
with gr.Row():
|
| 251 |
+
with gr.Column():
|
| 252 |
+
mirna_input = gr.Textbox(
|
| 253 |
+
label="miRNA Sequence",
|
| 254 |
+
placeholder="e.g., UGAGGUAGUAGGUUGUAUAGUU",
|
| 255 |
+
info="18-25 nt. DNA (T) or RNA (U) format accepted.",
|
| 256 |
+
)
|
| 257 |
+
target_input = gr.Textbox(
|
| 258 |
+
label="Target Sequence",
|
| 259 |
+
placeholder="e.g., ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU",
|
| 260 |
+
info="40 nt recommended. DNA (T) or RNA (U) format accepted.",
|
| 261 |
+
)
|
| 262 |
+
predict_btn = gr.Button("Predict", variant="primary")
|
| 263 |
+
|
| 264 |
+
with gr.Column():
|
| 265 |
+
result_html = gr.HTML(label="Prediction Result")
|
| 266 |
+
result_json = gr.JSON(label="Details")
|
| 267 |
+
|
| 268 |
+
predict_btn.click(
|
| 269 |
+
predict_single,
|
| 270 |
+
inputs=[mirna_input, target_input],
|
| 271 |
+
outputs=[result_html, result_json],
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
gr.Examples(
|
| 275 |
+
examples=EXAMPLES,
|
| 276 |
+
inputs=[mirna_input, target_input],
|
| 277 |
+
outputs=[result_html, result_json],
|
| 278 |
+
fn=predict_single,
|
| 279 |
+
cache_examples=False,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
with gr.Tab("Batch Prediction"):
|
| 283 |
+
gr.Markdown(
|
| 284 |
+
"""
|
| 285 |
+
Upload a CSV file with columns containing **mirna** and **target** in the column names.
|
| 286 |
+
|
| 287 |
+
Example format:
|
| 288 |
+
| mirna_seq | target_seq |
|
| 289 |
+
|-----------|------------|
|
| 290 |
+
| UGAGGUAGUAGGUUGUAUAGUU | ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU |
|
| 291 |
+
"""
|
| 292 |
+
)
|
| 293 |
+
csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
|
| 294 |
+
batch_btn = gr.Button("Run Batch Prediction", variant="primary")
|
| 295 |
+
csv_output = gr.File(label="Download Results")
|
| 296 |
+
preview = gr.Dataframe(label="Preview (first 20 rows)")
|
| 297 |
+
|
| 298 |
+
batch_btn.click(
|
| 299 |
+
predict_batch,
|
| 300 |
+
inputs=[csv_input],
|
| 301 |
+
outputs=[csv_output, preview],
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
with gr.Tab("About"):
|
| 305 |
+
gr.Markdown(
|
| 306 |
+
"""
|
| 307 |
+
## Model Architecture
|
| 308 |
+
|
| 309 |
+
DeepMiRT uses a **shared RNA-FM encoder** (12-layer Transformer, pre-trained on 23M non-coding RNAs)
|
| 310 |
+
to embed both miRNA and target sequences into the same representation space.
|
| 311 |
+
A **cross-attention module** (2 layers, 8 heads) allows the target to attend to the miRNA,
|
| 312 |
+
capturing interaction patterns. The attended representations are pooled and classified
|
| 313 |
+
by an **MLP head** (640 β 256 β 64 β 1).
|
| 314 |
+
|
| 315 |
+
```
|
| 316 |
+
miRNA β [RNA-FM Encoder] β miRNA embedding ββββββββββ
|
| 317 |
+
β
|
| 318 |
+
Target β [RNA-FM Encoder] β target embedding β [Cross-Attention] β Pool β [MLP] β probability
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
## Training
|
| 322 |
+
|
| 323 |
+
- **Data:** miRNA-target interactions from multiple databases and literature mining
|
| 324 |
+
- **Two-phase training:** Phase 1 (frozen backbone) β Phase 2 (unfreeze top 3 RNA-FM layers)
|
| 325 |
+
- **Hardware:** 2Γ NVIDIA L20 GPUs, mixed-precision (fp16)
|
| 326 |
+
- **Best checkpoint:** epoch 27, val AUROC = 0.9612
|
| 327 |
+
|
| 328 |
+
## Performance
|
| 329 |
+
|
| 330 |
+
| Benchmark | AUROC | Rank |
|
| 331 |
+
|-----------|-------|------|
|
| 332 |
+
| miRBench eCLIP (Klimentova 2022) | 0.7511 | #1/12 |
|
| 333 |
+
| miRBench eCLIP (Manakov 2022) | 0.7543 | #1/12 |
|
| 334 |
+
| miRBench CLASH (Hejret 2023) | 0.6952 | #5/12 |
|
| 335 |
+
| Our test set (813K samples, 16 methods) | 0.9606 | #1/16 |
|
| 336 |
+
|
| 337 |
+
## Citation
|
| 338 |
+
|
| 339 |
+
If you use DeepMiRT in your research, please cite:
|
| 340 |
+
```
|
| 341 |
+
@software{liu2026deepmirt,
|
| 342 |
+
title={DeepMiRT: miRNA Target Prediction with RNA Foundation Models},
|
| 343 |
+
author={Liu, Zicheng},
|
| 344 |
+
year={2026},
|
| 345 |
+
url={https://github.com/zichengll/DeepMiRT}
|
| 346 |
+
}
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
## License
|
| 350 |
+
|
| 351 |
+
MIT License. See [LICENSE](https://github.com/zichengll/DeepMiRT/blob/main/LICENSE).
|
| 352 |
+
"""
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
return demo
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
if __name__ == "__main__":
|
| 359 |
+
demo = build_demo()
|
| 360 |
+
demo.launch()
|
deepmirt/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DeepMiRT: miRNA target prediction using RNA foundation models and cross-attention."""
|
| 2 |
+
|
| 3 |
+
__version__ = "1.0.0"
|
| 4 |
+
|
| 5 |
+
from deepmirt.predict import predict as predict
|
deepmirt/data_module/__init__.py
ADDED
|
File without changes
|
deepmirt/data_module/datamodule.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
miRNA-Target PyTorch Lightning DataModule
|
| 4 |
+
|
| 5 |
+
[Lightning DataModule Lifecycle]
|
| 6 |
+
Lightning DataModule encapsulates data loading logic into a reusable module.
|
| 7 |
+
Its lifecycle is as follows:
|
| 8 |
+
|
| 9 |
+
1. prepare_data() β download data (runs only on main process; not needed in this project)
|
| 10 |
+
2. setup(stage) β create Dataset instances (runs on every process)
|
| 11 |
+
- stage='fit' β create train_dataset + val_dataset
|
| 12 |
+
- stage='test' β create test_dataset
|
| 13 |
+
- stage='predict' β create predict_dataset
|
| 14 |
+
3. train_dataloader() β return training DataLoader
|
| 15 |
+
4. val_dataloader() β return validation DataLoader
|
| 16 |
+
5. test_dataloader() β return test DataLoader
|
| 17 |
+
|
| 18 |
+
[Why use DataModule instead of manually creating DataLoaders?]
|
| 19 |
+
- Centralizes all data-related logic (paths, batch size, tokenizer, data splits)
|
| 20 |
+
- Lightning Trainer automatically calls the correct methods, reducing boilerplate
|
| 21 |
+
- Makes it easy to reuse the same data configuration across different experiments
|
| 22 |
+
|
| 23 |
+
[collate_fn Explained β The Core Difficulty of This Module]
|
| 24 |
+
Since miRNA sequence lengths are variable (15-30nt β 17-32 tokens),
|
| 25 |
+
samples in the same batch may have mirna_tokens of different lengths.
|
| 26 |
+
PyTorch's default collate cannot stack variable-length tensors,
|
| 27 |
+
so we need a custom collate_fn to:
|
| 28 |
+
1. Find the longest miRNA sequence in the batch
|
| 29 |
+
2. Pad all miRNA sequences to the same length
|
| 30 |
+
3. Generate an attention mask indicating which positions are real tokens vs. padding
|
| 31 |
+
|
| 32 |
+
Target sequences are fixed at 40nt (β 42 tokens) and do not require additional padding.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from __future__ import annotations
|
| 36 |
+
|
| 37 |
+
import os
|
| 38 |
+
|
| 39 |
+
import fm
|
| 40 |
+
import pytorch_lightning as pl
|
| 41 |
+
import torch
|
| 42 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 43 |
+
from torch.utils.data import DataLoader
|
| 44 |
+
|
| 45 |
+
from deepmirt.data_module.dataset import MiRNATargetDataset
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MiRNATargetDataModule(pl.LightningDataModule):
|
| 49 |
+
"""
|
| 50 |
+
Lightning DataModule for miRNA-target pairs.
|
| 51 |
+
|
| 52 |
+
[Responsibilities]
|
| 53 |
+
- Manage creation and DataLoader configuration for train / val / test datasets
|
| 54 |
+
- Provide a custom collate_fn to handle variable-length miRNA sequence padding
|
| 55 |
+
- Encapsulate RNA-FM alphabet loading to avoid redundant initialization in multiple places
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
data_dir: str,
|
| 61 |
+
batch_size: int = 128,
|
| 62 |
+
num_workers: int = 8,
|
| 63 |
+
pin_memory: bool = True,
|
| 64 |
+
):
|
| 65 |
+
"""
|
| 66 |
+
Initialize the DataModule.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
data_dir (str): path to the directory containing train.csv / val.csv / test.csv
|
| 70 |
+
batch_size (int): number of samples per batch, default 128
|
| 71 |
+
num_workers (int): number of DataLoader worker processes, default 8
|
| 72 |
+
# Design decision: num_workers controls data prefetching parallelism
|
| 73 |
+
# - 0 = load in main process (for debugging, slow but easy to troubleshoot)
|
| 74 |
+
# - 8 = 8 subprocesses load in parallel (for training, fully utilize multi-core CPU)
|
| 75 |
+
# - Rule of thumb: set to half of CPU cores or GPU count x 4
|
| 76 |
+
# - Too many will cause memory overhead and process switching overhead
|
| 77 |
+
pin_memory (bool): whether to pin data to page-locked memory, default True
|
| 78 |
+
# Design decision: pin_memory accelerates CPUβGPU data transfer
|
| 79 |
+
# - True: data is first copied to pinned memory, then transferred to GPU via DMA
|
| 80 |
+
# Eliminates one memory copy, improving throughput by ~2x
|
| 81 |
+
# - False: data is in pageable memory and must be copied to pinned memory before transfer
|
| 82 |
+
# - Only meaningful when using GPU; set to False for CPU training
|
| 83 |
+
"""
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.data_dir = data_dir
|
| 86 |
+
self.batch_size = batch_size
|
| 87 |
+
self.num_workers = num_workers
|
| 88 |
+
self.pin_memory = pin_memory
|
| 89 |
+
|
| 90 |
+
# Dataset instances, created in setup()
|
| 91 |
+
self.train_dataset: MiRNATargetDataset | None = None
|
| 92 |
+
self.val_dataset: MiRNATargetDataset | None = None
|
| 93 |
+
self.test_dataset: MiRNATargetDataset | None = None
|
| 94 |
+
|
| 95 |
+
# Load RNA-FM alphabet in the main process (before DDP fork)
|
| 96 |
+
# This way the alphabet is loaded only once, avoiding redundant full model loading on each DDP rank
|
| 97 |
+
_model, alphabet = fm.pretrained.rna_fm_t12()
|
| 98 |
+
del _model # Free model weights, keep only the alphabet (tokenizer)
|
| 99 |
+
self._alphabet = alphabet
|
| 100 |
+
self._padding_idx = alphabet.padding_idx # padding_idx = 1
|
| 101 |
+
|
| 102 |
+
def setup(self, stage: str | None = None) -> None:
|
| 103 |
+
"""
|
| 104 |
+
Create Dataset instances.
|
| 105 |
+
|
| 106 |
+
Lightning automatically calls this method before training/validation/testing begins.
|
| 107 |
+
Each process (including multi-GPU DDP scenarios) calls setup() independently.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
stage: 'fit' (train+val), 'test', 'predict', or None (all)
|
| 111 |
+
"""
|
| 112 |
+
# alphabet was already loaded in __init__() (before DDP fork, loaded only once)
|
| 113 |
+
alphabet = self._alphabet
|
| 114 |
+
|
| 115 |
+
if stage == "fit" or stage is None:
|
| 116 |
+
self.train_dataset = MiRNATargetDataset(
|
| 117 |
+
os.path.join(self.data_dir, "train.csv"), alphabet
|
| 118 |
+
)
|
| 119 |
+
self.val_dataset = MiRNATargetDataset(
|
| 120 |
+
os.path.join(self.data_dir, "val.csv"), alphabet
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if stage == "test" or stage is None:
|
| 124 |
+
self.test_dataset = MiRNATargetDataset(
|
| 125 |
+
os.path.join(self.data_dir, "test.csv"), alphabet
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def train_dataloader(self) -> DataLoader:
|
| 129 |
+
"""Return the training DataLoader (shuffle=True to randomize data order)."""
|
| 130 |
+
return DataLoader(
|
| 131 |
+
self.train_dataset,
|
| 132 |
+
batch_size=self.batch_size,
|
| 133 |
+
shuffle=True,
|
| 134 |
+
num_workers=self.num_workers,
|
| 135 |
+
pin_memory=self.pin_memory,
|
| 136 |
+
collate_fn=self._collate_fn,
|
| 137 |
+
drop_last=True,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
def val_dataloader(self) -> DataLoader:
|
| 141 |
+
"""Return the validation DataLoader (shuffle=False to preserve order for reproducible evaluation)."""
|
| 142 |
+
return DataLoader(
|
| 143 |
+
self.val_dataset,
|
| 144 |
+
batch_size=self.batch_size,
|
| 145 |
+
shuffle=False,
|
| 146 |
+
num_workers=self.num_workers,
|
| 147 |
+
pin_memory=self.pin_memory,
|
| 148 |
+
collate_fn=self._collate_fn,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def test_dataloader(self) -> DataLoader:
|
| 152 |
+
"""Return the test DataLoader."""
|
| 153 |
+
return DataLoader(
|
| 154 |
+
self.test_dataset,
|
| 155 |
+
batch_size=self.batch_size,
|
| 156 |
+
shuffle=False,
|
| 157 |
+
num_workers=self.num_workers,
|
| 158 |
+
pin_memory=self.pin_memory,
|
| 159 |
+
collate_fn=self._collate_fn,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def _collate_fn(self, batch: list[dict]) -> dict:
|
| 163 |
+
"""
|
| 164 |
+
Custom batch collation function β handles padding of variable-length miRNA sequences.
|
| 165 |
+
|
| 166 |
+
[Why is a custom collate_fn needed?]
|
| 167 |
+
PyTorch's default collate_fn attempts to stack all sample tensors.
|
| 168 |
+
But miRNA sequence lengths are variable (15-30nt β 17-32 tokens), and direct stacking fails:
|
| 169 |
+
RuntimeError: stack expects each tensor to be equal size
|
| 170 |
+
|
| 171 |
+
[Why does miRNA need padding but target does not?]
|
| 172 |
+
- miRNA has variable length: 15-30 nucleotides β 17-32 tokens after adding BOS+EOS
|
| 173 |
+
A single batch may contain lengths of both 17 and 32, which must be aligned
|
| 174 |
+
- Target has fixed length: all samples are 40 nucleotides β 42 tokens
|
| 175 |
+
Naturally aligned, no padding needed
|
| 176 |
+
|
| 177 |
+
[Role of attention_mask]
|
| 178 |
+
- Tells the model which positions are real tokens (1) and which are padding (0)
|
| 179 |
+
- The Transformer's self-attention uses the mask to block padding positions
|
| 180 |
+
- Prevents padding tokens from participating in attention computation, avoiding noise
|
| 181 |
+
|
| 182 |
+
# Design decision: use pad_sequence instead of manual loop padding
|
| 183 |
+
# pad_sequence is a PyTorch built-in utility, optimized in C++, faster than Python loops
|
| 184 |
+
# It automatically finds the maximum length and pads shorter sequences with the specified value
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
batch: list of dicts, each dict from MiRNATargetDataset.__getitem__
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
dict: containing the following key-value pairs:
|
| 191 |
+
- 'mirna_tokens': (batch_size, max_mirna_len) LongTensor
|
| 192 |
+
- 'target_tokens': (batch_size, 42) LongTensor
|
| 193 |
+
- 'labels': (batch_size,) float32 Tensor
|
| 194 |
+
- 'attention_mask_mirna': (batch_size, max_mirna_len) LongTensor
|
| 195 |
+
- 'attention_mask_target': (batch_size, 42) LongTensor
|
| 196 |
+
"""
|
| 197 |
+
# ββ 1. Collect individual fields ββ
|
| 198 |
+
mirna_list = [sample["mirna_tokens"] for sample in batch]
|
| 199 |
+
target_list = [sample["target_tokens"] for sample in batch]
|
| 200 |
+
label_list = [sample["label"] for sample in batch]
|
| 201 |
+
|
| 202 |
+
# ββ 2. Pad miRNA sequences ββ
|
| 203 |
+
# pad_sequence converts list of 1D tensors β 2D tensor (batch, max_len)
|
| 204 |
+
# batch_first=True ensures the batch dimension comes first
|
| 205 |
+
# padding_value=1 is RNA-FM's <pad> token ID
|
| 206 |
+
mirna_padded = pad_sequence(
|
| 207 |
+
mirna_list, batch_first=True, padding_value=self._padding_idx
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# ββ 3. Stack target sequences (fixed 42 tokens, no padding needed) ββ
|
| 211 |
+
target_stacked = torch.stack(target_list)
|
| 212 |
+
|
| 213 |
+
# ββ 4. Stack labels ββ
|
| 214 |
+
labels = torch.stack(label_list)
|
| 215 |
+
|
| 216 |
+
# ββ 5. Generate attention masks ββ
|
| 217 |
+
# miRNA mask: non-padding positions = 1, padding positions = 0
|
| 218 |
+
attention_mask_mirna = (mirna_padded != self._padding_idx).long()
|
| 219 |
+
|
| 220 |
+
# target mask: all positions are real tokens, so all 1s
|
| 221 |
+
# Because target is fixed at 40nt with no padding, every position is valid
|
| 222 |
+
attention_mask_target = torch.ones_like(target_stacked, dtype=torch.long)
|
| 223 |
+
|
| 224 |
+
# ββ 6. Collect metadata (for stratified analysis during evaluation) ββ
|
| 225 |
+
# Each metadata field is collected as list[str], kept on CPU
|
| 226 |
+
metadata_keys = batch[0].get("metadata", {}).keys()
|
| 227 |
+
metadata = {
|
| 228 |
+
key: [sample["metadata"][key] for sample in batch]
|
| 229 |
+
for key in metadata_keys
|
| 230 |
+
} if metadata_keys else {}
|
| 231 |
+
|
| 232 |
+
return {
|
| 233 |
+
"mirna_tokens": mirna_padded, # (B, max_mirna_len)
|
| 234 |
+
"target_tokens": target_stacked, # (B, 42)
|
| 235 |
+
"labels": labels, # (B,)
|
| 236 |
+
"attention_mask_mirna": attention_mask_mirna, # (B, max_mirna_len)
|
| 237 |
+
"attention_mask_target": attention_mask_target, # (B, 42)
|
| 238 |
+
"metadata": metadata, # dict[str, list[str]]
|
| 239 |
+
}
|
deepmirt/data_module/dataset.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
miRNA-Target Pair Dataset β PyTorch Dataset Implementation
|
| 4 |
+
|
| 5 |
+
[Data Flow ASCII Diagram]
|
| 6 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 7 |
+
β MiRNATargetDataset Data Flow β
|
| 8 |
+
β β
|
| 9 |
+
β CSV file (train.csv / val.csv / test.csv) β
|
| 10 |
+
β β β
|
| 11 |
+
β βΌ β
|
| 12 |
+
β pd.read_csv() ββ DataFrame (loaded entirely into memory) β
|
| 13 |
+
β β β
|
| 14 |
+
β βΌ β
|
| 15 |
+
β __getitem__(idx) ββ retrieve row idx β
|
| 16 |
+
β β β
|
| 17 |
+
β βββ mirna_seq: "ATCGATCG" β
|
| 18 |
+
β β β β
|
| 19 |
+
β β βΌ β
|
| 20 |
+
β β dna_to_rna() ββ "AUCGAUCG" (TβU conversion) β
|
| 21 |
+
β β β β
|
| 22 |
+
β β βΌ β
|
| 23 |
+
β β batch_converter([("mirna", "AUCGAUCG")]) β
|
| 24 |
+
β β β β
|
| 25 |
+
β β βΌ β
|
| 26 |
+
β β tokens: tensor([0, 4, 7, 5, 6, ...., 2]) β
|
| 27 |
+
β β ^^BOS ^^EOS β
|
| 28 |
+
β β β
|
| 29 |
+
β βββ target_fragment_40nt: "TAGCTAGC..." β
|
| 30 |
+
β β β (same dna_to_rna + batch_converter pipeline) β
|
| 31 |
+
β β βΌ β
|
| 32 |
+
β β tokens: tensor([0, ..., 2]) (fixed 42 tokens: BOS+40nt+EOS)β
|
| 33 |
+
β β β
|
| 34 |
+
β βββ return dict: β
|
| 35 |
+
β { β
|
| 36 |
+
β 'mirna_tokens': 1D LongTensor (variable 17-32) β
|
| 37 |
+
β 'target_tokens': 1D LongTensor (fixed 42) β
|
| 38 |
+
β 'label': float32 scalar (0.0 or 1.0) β
|
| 39 |
+
β 'metadata': dict (species, mirna_name, ...) β
|
| 40 |
+
β } β
|
| 41 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
+
[RNA-FM batch_converter Input/Output Format]
|
| 44 |
+
- Input: List[Tuple[str, str]] = [("label_name", "RNA_sequence")]
|
| 45 |
+
e.g.: [("mirna", "AUCGAUCG")]
|
| 46 |
+
|
| 47 |
+
- Output: Tuple[List[str], List[str], Tensor]
|
| 48 |
+
- labels: ["mirna"] β label list (not used by us)
|
| 49 |
+
- strs: ["AUCGAUCG"] β raw sequences (not used by us)
|
| 50 |
+
- tokens: tensor([[0, 4, 7, 5, 6, 4, 7, 5, 6, 2]])
|
| 51 |
+
shape = (batch=1, seq_len)
|
| 52 |
+
where 0=BOS(<cls>), 2=EOS(<eos>), 1=PAD(<pad>)
|
| 53 |
+
A=4, C=5, G=6, U=7
|
| 54 |
+
|
| 55 |
+
- Important: batch_converter already adds BOS and EOS for us!
|
| 56 |
+
So 22nt miRNA β 24 tokens (BOS + 22nt + EOS)
|
| 57 |
+
40nt target β 42 tokens (BOS + 40nt + EOS)
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
from __future__ import annotations
|
| 61 |
+
|
| 62 |
+
import pandas as pd
|
| 63 |
+
import torch
|
| 64 |
+
from torch.utils.data import Dataset
|
| 65 |
+
|
| 66 |
+
from deepmirt.data_module.preprocessing import dna_to_rna
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MiRNATargetDataset(Dataset):
|
| 70 |
+
"""
|
| 71 |
+
PyTorch Dataset for miRNA-target pairs.
|
| 72 |
+
|
| 73 |
+
[Overview]
|
| 74 |
+
Loads miRNA-target sequence pairs from a CSV file, tokenizes them using
|
| 75 |
+
the RNA-FM alphabet, and returns token tensors and labels for training.
|
| 76 |
+
|
| 77 |
+
[Usage]
|
| 78 |
+
>>> import fm
|
| 79 |
+
>>> _, alphabet = fm.pretrained.rna_fm_t12()
|
| 80 |
+
>>> ds = MiRNATargetDataset('path/to/train.csv', alphabet)
|
| 81 |
+
>>> sample = ds[0]
|
| 82 |
+
>>> sample['mirna_tokens'] # tensor([0, 4, 7, 5, ..., 2])
|
| 83 |
+
>>> sample['label'] # tensor(1.)
|
| 84 |
+
|
| 85 |
+
[Why inherit from torch.utils.data.Dataset?]
|
| 86 |
+
- It is the standard PyTorch interface for data loading
|
| 87 |
+
- After defining __len__ and __getitem__, it can be used with DataLoader
|
| 88 |
+
- DataLoader automatically handles batching, multi-process loading, shuffling, etc.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
csv_path: str,
|
| 94 |
+
alphabet,
|
| 95 |
+
max_mirna_len: int = 30,
|
| 96 |
+
max_target_len: int = 40,
|
| 97 |
+
):
|
| 98 |
+
"""
|
| 99 |
+
Initialize the dataset.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
csv_path (str): Path to the CSV file, which must contain the following columns:
|
| 103 |
+
- mirna_seq: miRNA sequence (DNA notation)
|
| 104 |
+
- target_fragment_40nt: target fragment sequence (DNA notation)
|
| 105 |
+
- label: binary label (0 or 1)
|
| 106 |
+
- species, mirna_name, target_gene_name: metadata columns
|
| 107 |
+
alphabet: RNA-FM alphabet object that provides tokenization capability
|
| 108 |
+
max_mirna_len (int): maximum nucleotide length for miRNA, default 30
|
| 109 |
+
(actual token count = max_mirna_len + 2, due to BOS and EOS)
|
| 110 |
+
max_target_len (int): maximum nucleotide length for target, default 40
|
| 111 |
+
(actual token count = max_target_len + 2 = 42)
|
| 112 |
+
|
| 113 |
+
[Design Decision: Memory Strategy]
|
| 114 |
+
We use pd.read_csv() to load the entire CSV into a DataFrame at once.
|
| 115 |
+
This is the simplest approach β for our data scale (~5.4 million training rows),
|
| 116 |
+
the DataFrame occupies approximately 2-3 GB of memory.
|
| 117 |
+
|
| 118 |
+
The current system has 1TB RAM, so this is not an issue at all.
|
| 119 |
+
|
| 120 |
+
# Design decision: if memory is limited (e.g., 8GB), consider these alternatives:
|
| 121 |
+
# 1. Byte-offset indexing: first pass records byte positions of each row in the file,
|
| 122 |
+
# __getitem__ uses file.seek(offset) to jump to and read that row
|
| 123 |
+
# 2. Memory mapping (mmap): open the file with mmap, read on demand
|
| 124 |
+
# 3. Chunked reading: load in chunks, combined with LRU cache
|
| 125 |
+
# These methods sacrifice code simplicity for lower memory usage
|
| 126 |
+
"""
|
| 127 |
+
super().__init__()
|
| 128 |
+
|
| 129 |
+
# Save configuration parameters
|
| 130 |
+
self.csv_path = csv_path
|
| 131 |
+
self.alphabet = alphabet
|
| 132 |
+
self.max_mirna_len = max_mirna_len
|
| 133 |
+
self.max_target_len = max_target_len
|
| 134 |
+
|
| 135 |
+
# Get batch_converter for tokenization
|
| 136 |
+
# batch_converter is the tokenization tool provided by RNA-FM, converting RNA strings to token IDs
|
| 137 |
+
self.batch_converter = alphabet.get_batch_converter()
|
| 138 |
+
|
| 139 |
+
# Design decision: load entire CSV into memory (see docstring above for details)
|
| 140 |
+
# On a 1TB RAM system, 5.4 million rows β 2-3 GB, easily affordable
|
| 141 |
+
self.df = pd.read_csv(
|
| 142 |
+
csv_path,
|
| 143 |
+
dtype={"target_gene_name": str, "target_gene_id": str},
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def __len__(self) -> int:
|
| 147 |
+
"""
|
| 148 |
+
Return the number of samples in the dataset.
|
| 149 |
+
|
| 150 |
+
DataLoader calls this method to determine how many steps per epoch.
|
| 151 |
+
e.g.: len(dataset)=557521, batch_size=128 β ~4356 steps per epoch
|
| 152 |
+
"""
|
| 153 |
+
return len(self.df)
|
| 154 |
+
|
| 155 |
+
def __getitem__(self, idx: int) -> dict:
|
| 156 |
+
"""
|
| 157 |
+
Retrieve the idx-th sample, returning a dict of tokenized tensors.
|
| 158 |
+
|
| 159 |
+
[Processing Pipeline]
|
| 160 |
+
1. Extract row idx from the DataFrame
|
| 161 |
+
2. Get mirna_seq and target_fragment_40nt
|
| 162 |
+
3. Apply dna_to_rna() for TβU conversion
|
| 163 |
+
4. Tokenize with RNA-FM batch_converter
|
| 164 |
+
5. Assemble and return the dict
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
idx (int): sample index, range [0, len(self)-1]
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
dict: containing the following key-value pairs:
|
| 171 |
+
- 'mirna_tokens': 1D LongTensor, miRNA token sequence
|
| 172 |
+
shape = (mirna_len+2,), including BOS and EOS
|
| 173 |
+
- 'target_tokens': 1D LongTensor, target token sequence
|
| 174 |
+
shape = (42,), fixed length (BOS + 40nt + EOS)
|
| 175 |
+
- 'label': float32 scalar tensor (0.0 or 1.0)
|
| 176 |
+
- 'metadata': dict, containing species, mirna_name, target_gene_name
|
| 177 |
+
"""
|
| 178 |
+
# ββ Step 1: Extract one row from the DataFrame ββ
|
| 179 |
+
row = self.df.iloc[idx]
|
| 180 |
+
|
| 181 |
+
# ββ Step 2: Extract sequences and label ββ
|
| 182 |
+
mirna_seq_raw = row["mirna_seq"]
|
| 183 |
+
target_seq_raw = row["target_fragment_40nt"]
|
| 184 |
+
label = row["label"]
|
| 185 |
+
|
| 186 |
+
# ββ Step 3: DNA-to-RNA conversion (T β U) ββ
|
| 187 |
+
# Sequences in the dataset use DNA notation (T for thymine),
|
| 188 |
+
# but the RNA-FM model expects RNA notation (U for uridine), so conversion is needed
|
| 189 |
+
mirna_rna = dna_to_rna(mirna_seq_raw)
|
| 190 |
+
target_rna = dna_to_rna(target_seq_raw)
|
| 191 |
+
|
| 192 |
+
# ββ Step 4: Tokenize using RNA-FM batch_converter ββ
|
| 193 |
+
# batch_converter input format: List[Tuple[label, sequence]]
|
| 194 |
+
# It automatically adds BOS(<cls>=0) and EOS(<eos>=2) tokens around the sequence
|
| 195 |
+
#
|
| 196 |
+
# e.g.: [("mirna", "AUCG")]
|
| 197 |
+
# output tokens: tensor([[0, 4, 7, 5, 6, 2]])
|
| 198 |
+
# BOS=0 A U C G EOS=2
|
| 199 |
+
#
|
| 200 |
+
# Here we process only 1 sequence at a time (batch_size=1),
|
| 201 |
+
# so we use tokens[0] to extract the first one, yielding a 1D tensor
|
| 202 |
+
|
| 203 |
+
# Tokenize miRNA
|
| 204 |
+
_, _, mirna_tokens = self.batch_converter([("mirna", mirna_rna)])
|
| 205 |
+
mirna_tokens = mirna_tokens[0] # (1, seq_len) β (seq_len,)
|
| 206 |
+
|
| 207 |
+
# Tokenize target
|
| 208 |
+
_, _, target_tokens = self.batch_converter([("target", target_rna)])
|
| 209 |
+
target_tokens = target_tokens[0] # (1, 42) β (42,)
|
| 210 |
+
|
| 211 |
+
# ββ Step 5: Assemble the return dict ββ
|
| 212 |
+
# Why use float32 for label?
|
| 213 |
+
# Because training uses BCEWithLogitsLoss (binary cross-entropy),
|
| 214 |
+
# which requires both target and prediction to be float type.
|
| 215 |
+
# If label is int/long, PyTorch will raise a type mismatch error.
|
| 216 |
+
return {
|
| 217 |
+
"mirna_tokens": mirna_tokens, # 1D LongTensor, variable (17-32)
|
| 218 |
+
"target_tokens": target_tokens, # 1D LongTensor, fixed 42
|
| 219 |
+
"label": torch.tensor(label, dtype=torch.float32), # scalar float32
|
| 220 |
+
"metadata": {
|
| 221 |
+
"species": row["species"],
|
| 222 |
+
"mirna_name": row["mirna_name"],
|
| 223 |
+
"target_gene_name": row["target_gene_name"],
|
| 224 |
+
"evidence_type": row.get("evidence_type", ""),
|
| 225 |
+
"source_database": row.get("source_database", ""),
|
| 226 |
+
},
|
| 227 |
+
}
|
deepmirt/data_module/preprocessing.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Data Preprocessing Utilities β RNA Sequence Format Conversion Module
|
| 4 |
+
|
| 5 |
+
This module converts DNA-notation sequences in the dataset to the RNA notation
|
| 6 |
+
format required by the RNA-FM model.
|
| 7 |
+
|
| 8 |
+
[Why is this conversion needed?]
|
| 9 |
+
- The RNA-FM model was trained on RNA sequences and expects input in RNA notation: A, U, G, C
|
| 10 |
+
- Our dataset stores sequences in DNA notation: A, T, G, C (where T replaces U)
|
| 11 |
+
- During training, DNA notation T must be converted to RNA notation U to match the model's expected input format
|
| 12 |
+
|
| 13 |
+
[Architecture Position]
|
| 14 |
+
- This module is called by Dataset.__getitem__() during training
|
| 15 |
+
- The conversion happens at the data loading stage without modifying the original CSV files
|
| 16 |
+
- Reference: finalize_dataset.py:86-93 performs the reverse operation (UβT) for data export
|
| 17 |
+
|
| 18 |
+
[Design Decisions]
|
| 19 |
+
- Conversion is performed online (in the Dataset) rather than preprocessing the CSV, to preserve original data integrity
|
| 20 |
+
- All sequences are converted to uppercase to ensure format consistency
|
| 21 |
+
- The character N (representing ambiguous bases) is allowed; RNA-FM can handle ambiguous bases
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def dna_to_rna(seq: str) -> str:
|
| 28 |
+
"""
|
| 29 |
+
Convert a DNA-notation sequence to an RNA-notation sequence.
|
| 30 |
+
|
| 31 |
+
[Description]
|
| 32 |
+
- Converts T (thymine, DNA) to U (uridine, RNA)
|
| 33 |
+
- Converts to uppercase
|
| 34 |
+
- Removes all whitespace characters
|
| 35 |
+
- Idempotent: sequences already in RNA format remain unchanged
|
| 36 |
+
|
| 37 |
+
[Design Decisions]
|
| 38 |
+
- Why convert online? To keep the original CSV data intact for auditing and reproducibility
|
| 39 |
+
- Why uppercase? To ensure consistency with the RNA-FM model's expected input format
|
| 40 |
+
- Why allow N? RNA-FM's tokenizer can handle ambiguous bases
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
seq (str): DNA-notation sequence string, may contain A, T, G, C, N and whitespace
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
str: RNA-notation sequence string, containing A, U, G, C, N (uppercase, no whitespace)
|
| 47 |
+
|
| 48 |
+
Example:
|
| 49 |
+
>>> dna_to_rna('ATCGATCG')
|
| 50 |
+
'AUCGAUCG'
|
| 51 |
+
>>> dna_to_rna('atcg') # mixed case
|
| 52 |
+
'AUCG'
|
| 53 |
+
>>> dna_to_rna('AUCGAUCG') # already RNA format (idempotent)
|
| 54 |
+
'AUCGAUCG'
|
| 55 |
+
>>> dna_to_rna('ATC NGATCG') # contains N and whitespace
|
| 56 |
+
'AUCNGAUCG'
|
| 57 |
+
>>> dna_to_rna(' ATC G ') # leading/trailing whitespace
|
| 58 |
+
'AUCG'
|
| 59 |
+
"""
|
| 60 |
+
# Step 1: Convert to uppercase
|
| 61 |
+
seq = str(seq).upper()
|
| 62 |
+
|
| 63 |
+
# Step 2: Remove all whitespace characters (spaces, tabs, newlines)
|
| 64 |
+
seq = seq.replace(" ", "").replace("\t", "").replace("\n", "").replace("\r", "")
|
| 65 |
+
|
| 66 |
+
# Step 3: Convert T (DNA) to U (RNA)
|
| 67 |
+
seq = seq.replace("T", "U")
|
| 68 |
+
|
| 69 |
+
return seq
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def validate_rna_sequence(seq: str, min_len: int = 5, max_len: int = 100) -> bool:
|
| 73 |
+
"""
|
| 74 |
+
Validate whether a sequence is in valid RNA format.
|
| 75 |
+
|
| 76 |
+
[Description]
|
| 77 |
+
- Checks that the sequence contains only valid RNA characters: A, U, G, C, N
|
| 78 |
+
- Checks that the sequence length is within the specified range
|
| 79 |
+
- If it contains T, the DNA-to-RNA conversion was not performed; returns False
|
| 80 |
+
|
| 81 |
+
[Design Decisions]
|
| 82 |
+
- Why check for T? It serves as an indicator of conversion failure, aiding data flow debugging
|
| 83 |
+
- Why allow N? RNA-FM's tokenizer supports ambiguous bases
|
| 84 |
+
- Why impose length limits? To prevent abnormally long sequences from causing memory overflow
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
seq (str): the sequence string to validate
|
| 88 |
+
min_len (int): minimum length (inclusive), default 5
|
| 89 |
+
max_len (int): maximum length (inclusive), default 100
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
bool: True if the sequence is valid, False otherwise
|
| 93 |
+
|
| 94 |
+
Example:
|
| 95 |
+
>>> validate_rna_sequence('AUCGAUCG', 5, 30)
|
| 96 |
+
True
|
| 97 |
+
>>> validate_rna_sequence('ATCG', 5, 30) # contains T (DNA notation)
|
| 98 |
+
False
|
| 99 |
+
>>> validate_rna_sequence('AU', 5, 30) # too short
|
| 100 |
+
False
|
| 101 |
+
>>> validate_rna_sequence('A' * 31, 5, 30) # too long
|
| 102 |
+
False
|
| 103 |
+
>>> validate_rna_sequence('AUCNGAUCG', 5, 30) # contains N (valid)
|
| 104 |
+
True
|
| 105 |
+
"""
|
| 106 |
+
# Check length
|
| 107 |
+
if len(seq) < min_len or len(seq) > max_len:
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
# Define valid RNA character set
|
| 111 |
+
valid_chars = {"A", "U", "G", "C", "N"}
|
| 112 |
+
|
| 113 |
+
# Check if all characters are valid
|
| 114 |
+
for char in seq:
|
| 115 |
+
if char not in valid_chars:
|
| 116 |
+
# Specifically check for T, indicating conversion failure
|
| 117 |
+
if char == "T":
|
| 118 |
+
return False
|
| 119 |
+
# Other invalid characters also return False
|
| 120 |
+
return False
|
| 121 |
+
|
| 122 |
+
return True
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def prepare_rnafm_input(mirna_seq: str, target_seq: str) -> tuple[str, str]:
|
| 126 |
+
"""
|
| 127 |
+
Prepare an input sequence pair for the RNA-FM model.
|
| 128 |
+
|
| 129 |
+
[Description]
|
| 130 |
+
- Converts both miRNA and target sequences to RNA notation
|
| 131 |
+
- Returns two separate strings (not concatenated)
|
| 132 |
+
- RNA-FM uses a shared encoder architecture that processes each sequence independently
|
| 133 |
+
|
| 134 |
+
[Design Decisions]
|
| 135 |
+
- Why not concatenate? The dual-encoder processes each sequence in separate forward passes
|
| 136 |
+
- Concatenation would break the model's architectural design and degrade performance
|
| 137 |
+
- Returning a tuple is convenient for use in Dataset.__getitem__()
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
mirna_seq (str): miRNA sequence (DNA notation)
|
| 141 |
+
target_seq (str): target sequence (DNA notation)
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
tuple[str, str]: (mirna_rna, target_rna) tuple, both in RNA notation
|
| 145 |
+
|
| 146 |
+
Example:
|
| 147 |
+
>>> mirna_rna, target_rna = prepare_rnafm_input('ATCG', 'TAGC')
|
| 148 |
+
>>> mirna_rna
|
| 149 |
+
'AUCG'
|
| 150 |
+
>>> target_rna
|
| 151 |
+
'UAGC'
|
| 152 |
+
"""
|
| 153 |
+
# Convert the two sequences separately
|
| 154 |
+
mirna_rna = dna_to_rna(mirna_seq)
|
| 155 |
+
target_rna = dna_to_rna(target_seq)
|
| 156 |
+
|
| 157 |
+
return mirna_rna, target_rna
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def compute_sequence_stats(csv_path: str, sample_n: int = 10000) -> dict:
|
| 161 |
+
"""
|
| 162 |
+
Compute statistics for sequences in a CSV file.
|
| 163 |
+
|
| 164 |
+
[Description]
|
| 165 |
+
- Samples a specified number of rows from the CSV file
|
| 166 |
+
- Computes sequence length distributions, character frequencies, DNA notation detection, etc.
|
| 167 |
+
- Used for data quality checks and analysis
|
| 168 |
+
|
| 169 |
+
[Design Decisions]
|
| 170 |
+
- Why lazy-import pandas? To avoid introducing a heavy dependency at module load time
|
| 171 |
+
- Import only when needed, reducing startup time
|
| 172 |
+
- Sampling instead of full processing speeds up statistics computation
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
csv_path (str): path to the CSV file
|
| 176 |
+
sample_n (int): number of rows to sample, default 10000. If the file has fewer rows, all rows are used
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
dict: statistics dictionary containing the following keys:
|
| 180 |
+
- 'total_rows': total number of rows in the file (excluding header)
|
| 181 |
+
- 'sample_rows': actual number of sampled rows
|
| 182 |
+
- 'mirna_length_min': minimum miRNA length
|
| 183 |
+
- 'mirna_length_max': maximum miRNA length
|
| 184 |
+
- 'mirna_length_mean': mean miRNA length
|
| 185 |
+
- 'target_length_min': minimum target sequence length
|
| 186 |
+
- 'target_length_max': maximum target sequence length
|
| 187 |
+
- 'target_length_mean': mean target sequence length
|
| 188 |
+
- 'mirna_char_freq': miRNA character frequency dictionary
|
| 189 |
+
- 'target_char_freq': target sequence character frequency dictionary
|
| 190 |
+
- 'mirna_with_t_count': number of miRNA sequences containing T
|
| 191 |
+
- 'target_with_t_count': number of target sequences containing T
|
| 192 |
+
|
| 193 |
+
Example:
|
| 194 |
+
>>> stats = compute_sequence_stats('deepmirt/data/training/train.csv', sample_n=100)
|
| 195 |
+
>>> print(f"Total rows: {stats['total_rows']}")
|
| 196 |
+
>>> print(f"miRNA length range: {stats['mirna_length_min']}-{stats['mirna_length_max']}")
|
| 197 |
+
"""
|
| 198 |
+
# Lazy-import pandas to avoid introducing a heavy dependency at module load time
|
| 199 |
+
import pandas as pd
|
| 200 |
+
|
| 201 |
+
# Read the CSV file
|
| 202 |
+
df = pd.read_csv(csv_path)
|
| 203 |
+
|
| 204 |
+
# Compute total number of rows
|
| 205 |
+
total_rows = len(df)
|
| 206 |
+
|
| 207 |
+
# Determine sample size (capped at total number of rows)
|
| 208 |
+
actual_sample_n = min(sample_n, total_rows)
|
| 209 |
+
|
| 210 |
+
# Sample data
|
| 211 |
+
if actual_sample_n < total_rows:
|
| 212 |
+
sample_df = df.sample(n=actual_sample_n, random_state=42)
|
| 213 |
+
else:
|
| 214 |
+
sample_df = df
|
| 215 |
+
|
| 216 |
+
# Initialize statistics dictionary
|
| 217 |
+
stats = {
|
| 218 |
+
'total_rows': total_rows,
|
| 219 |
+
'sample_rows': len(sample_df),
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
# Compute miRNA sequence statistics
|
| 223 |
+
mirna_lengths = sample_df['mirna_seq'].str.len()
|
| 224 |
+
stats['mirna_length_min'] = int(mirna_lengths.min())
|
| 225 |
+
stats['mirna_length_max'] = int(mirna_lengths.max())
|
| 226 |
+
stats['mirna_length_mean'] = float(mirna_lengths.mean())
|
| 227 |
+
|
| 228 |
+
# Compute target sequence statistics
|
| 229 |
+
target_lengths = sample_df['target_fragment_40nt'].str.len()
|
| 230 |
+
stats['target_length_min'] = int(target_lengths.min())
|
| 231 |
+
stats['target_length_max'] = int(target_lengths.max())
|
| 232 |
+
stats['target_length_mean'] = float(target_lengths.mean())
|
| 233 |
+
|
| 234 |
+
# Compute character frequencies
|
| 235 |
+
def compute_char_freq(seq_series):
|
| 236 |
+
"""Compute the frequency of each character in the sequences"""
|
| 237 |
+
freq = {}
|
| 238 |
+
for seq in seq_series:
|
| 239 |
+
seq = str(seq).upper()
|
| 240 |
+
for char in seq:
|
| 241 |
+
freq[char] = freq.get(char, 0) + 1
|
| 242 |
+
return freq
|
| 243 |
+
|
| 244 |
+
stats['mirna_char_freq'] = compute_char_freq(sample_df['mirna_seq'])
|
| 245 |
+
stats['target_char_freq'] = compute_char_freq(sample_df['target_fragment_40nt'])
|
| 246 |
+
|
| 247 |
+
# Count sequences containing T (DNA notation)
|
| 248 |
+
stats['mirna_with_t_count'] = (sample_df['mirna_seq'].str.contains('T', case=False, na=False)).sum()
|
| 249 |
+
stats['target_with_t_count'] = (sample_df['target_fragment_40nt'].str.contains('T', case=False, na=False)).sum()
|
| 250 |
+
|
| 251 |
+
return stats
|
deepmirt/evaluation/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""miRNA target prediction model β comprehensive evaluation framework."""
|
deepmirt/evaluation/predict.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Inference engine: load checkpoint and generate prediction DataFrame on the test set.
|
| 4 |
+
|
| 5 |
+
Independent of Lightning trainer.test(), performs batch inference directly and
|
| 6 |
+
retains all metadata. Prediction results are cached as parquet to avoid repeated inference.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import torch
|
| 17 |
+
import yaml
|
| 18 |
+
from torch.utils.data import DataLoader
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_model_from_checkpoint(
|
| 24 |
+
ckpt_path: str,
|
| 25 |
+
config_path: str,
|
| 26 |
+
device: str = "cuda",
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Load a trained model from checkpoint.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
ckpt_path: path to the checkpoint file
|
| 33 |
+
config_path: path to the training config YAML
|
| 34 |
+
device: inference device
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
(model, config) tuple
|
| 38 |
+
"""
|
| 39 |
+
from deepmirt.training.lightning_module import MiRNATargetLitModule
|
| 40 |
+
|
| 41 |
+
with open(config_path) as f:
|
| 42 |
+
config = yaml.safe_load(f)
|
| 43 |
+
|
| 44 |
+
lit_model = MiRNATargetLitModule.load_from_checkpoint(
|
| 45 |
+
ckpt_path, config=config, map_location=device
|
| 46 |
+
)
|
| 47 |
+
lit_model.eval()
|
| 48 |
+
lit_model.to(device)
|
| 49 |
+
return lit_model, config
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def run_inference(
|
| 53 |
+
ckpt_path: str,
|
| 54 |
+
config_path: str,
|
| 55 |
+
test_csv_path: str,
|
| 56 |
+
batch_size: int = 256,
|
| 57 |
+
num_workers: int = 8,
|
| 58 |
+
device: str = "cuda",
|
| 59 |
+
cache_path: str | None = None,
|
| 60 |
+
) -> pd.DataFrame:
|
| 61 |
+
"""
|
| 62 |
+
Run model inference on the test set, returning a DataFrame with predictions and metadata.
|
| 63 |
+
|
| 64 |
+
If cache_path exists and is non-empty, loads cached results directly.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
ckpt_path: path to the checkpoint
|
| 68 |
+
config_path: path to the config YAML
|
| 69 |
+
test_csv_path: path to test.csv
|
| 70 |
+
batch_size: inference batch size
|
| 71 |
+
num_workers: number of DataLoader worker threads
|
| 72 |
+
device: inference device
|
| 73 |
+
cache_path: cache file path (parquet), None to disable caching
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
DataFrame with columns:
|
| 77 |
+
mirna_seq, target_fragment_40nt, label, prob, pred, logit,
|
| 78 |
+
species, mirna_name, target_gene_name, evidence_type, source_database
|
| 79 |
+
"""
|
| 80 |
+
# Check cache (supports both parquet and csv formats)
|
| 81 |
+
if cache_path and Path(cache_path).exists():
|
| 82 |
+
logger.info(f"Loading cached predictions from {cache_path}")
|
| 83 |
+
if cache_path.endswith(".parquet"):
|
| 84 |
+
return pd.read_parquet(cache_path)
|
| 85 |
+
else:
|
| 86 |
+
return pd.read_csv(cache_path)
|
| 87 |
+
|
| 88 |
+
logger.info(f"Loading model from {ckpt_path}")
|
| 89 |
+
lit_model, config = load_model_from_checkpoint(ckpt_path, config_path, device)
|
| 90 |
+
|
| 91 |
+
# Load data (using DataModule approach for consistency)
|
| 92 |
+
import fm
|
| 93 |
+
|
| 94 |
+
from deepmirt.data_module.datamodule import MiRNATargetDataModule
|
| 95 |
+
from deepmirt.data_module.dataset import MiRNATargetDataset
|
| 96 |
+
|
| 97 |
+
_, alphabet = fm.pretrained.rna_fm_t12()
|
| 98 |
+
del _
|
| 99 |
+
padding_idx = alphabet.padding_idx
|
| 100 |
+
|
| 101 |
+
dataset = MiRNATargetDataset(test_csv_path, alphabet)
|
| 102 |
+
|
| 103 |
+
# Use the DataModule's collate_fn logic
|
| 104 |
+
dm = MiRNATargetDataModule.__new__(MiRNATargetDataModule)
|
| 105 |
+
dm._padding_idx = padding_idx
|
| 106 |
+
|
| 107 |
+
dataloader = DataLoader(
|
| 108 |
+
dataset,
|
| 109 |
+
batch_size=batch_size,
|
| 110 |
+
shuffle=False,
|
| 111 |
+
num_workers=num_workers,
|
| 112 |
+
pin_memory=True,
|
| 113 |
+
collate_fn=dm._collate_fn,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Inference
|
| 117 |
+
all_logits = []
|
| 118 |
+
all_labels = []
|
| 119 |
+
all_metadata = {
|
| 120 |
+
"species": [],
|
| 121 |
+
"mirna_name": [],
|
| 122 |
+
"target_gene_name": [],
|
| 123 |
+
"evidence_type": [],
|
| 124 |
+
"source_database": [],
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
logger.info(f"Running inference on {len(dataset)} samples...")
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 130 |
+
mirna_tokens = batch["mirna_tokens"].to(device)
|
| 131 |
+
target_tokens = batch["target_tokens"].to(device)
|
| 132 |
+
labels = batch["labels"]
|
| 133 |
+
attn_mask_mirna = batch["attention_mask_mirna"].to(device)
|
| 134 |
+
attn_mask_target = batch["attention_mask_target"].to(device)
|
| 135 |
+
|
| 136 |
+
logits = lit_model.model(
|
| 137 |
+
mirna_tokens, target_tokens, attn_mask_mirna, attn_mask_target
|
| 138 |
+
)
|
| 139 |
+
logits = logits.squeeze(-1).cpu()
|
| 140 |
+
|
| 141 |
+
all_logits.append(logits)
|
| 142 |
+
all_labels.append(labels)
|
| 143 |
+
|
| 144 |
+
metadata = batch.get("metadata", {})
|
| 145 |
+
for key in all_metadata:
|
| 146 |
+
if key in metadata:
|
| 147 |
+
all_metadata[key].extend(metadata[key])
|
| 148 |
+
else:
|
| 149 |
+
all_metadata[key].extend([""] * len(labels))
|
| 150 |
+
|
| 151 |
+
if (batch_idx + 1) % 500 == 0:
|
| 152 |
+
logger.info(
|
| 153 |
+
f" Processed {(batch_idx + 1) * batch_size} / {len(dataset)}"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
all_logits = torch.cat(all_logits).numpy()
|
| 157 |
+
all_labels = torch.cat(all_labels).numpy()
|
| 158 |
+
all_probs = 1.0 / (1.0 + np.exp(-all_logits)) # sigmoid
|
| 159 |
+
all_preds = (all_probs >= 0.5).astype(int)
|
| 160 |
+
|
| 161 |
+
# Build raw sequence columns (read directly from CSV)
|
| 162 |
+
raw_df = pd.read_csv(
|
| 163 |
+
test_csv_path,
|
| 164 |
+
usecols=["mirna_seq", "target_fragment_40nt"],
|
| 165 |
+
dtype=str,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
result_df = pd.DataFrame(
|
| 169 |
+
{
|
| 170 |
+
"mirna_seq": raw_df["mirna_seq"].values,
|
| 171 |
+
"target_fragment_40nt": raw_df["target_fragment_40nt"].values,
|
| 172 |
+
"label": all_labels.astype(int),
|
| 173 |
+
"prob": all_probs,
|
| 174 |
+
"pred": all_preds,
|
| 175 |
+
"logit": all_logits,
|
| 176 |
+
"species": all_metadata["species"],
|
| 177 |
+
"mirna_name": all_metadata["mirna_name"],
|
| 178 |
+
"target_gene_name": all_metadata["target_gene_name"],
|
| 179 |
+
"evidence_type": all_metadata["evidence_type"],
|
| 180 |
+
"source_database": all_metadata["source_database"],
|
| 181 |
+
}
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Cache results (prefer parquet, fallback to csv)
|
| 185 |
+
if cache_path:
|
| 186 |
+
Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
|
| 187 |
+
try:
|
| 188 |
+
if cache_path.endswith(".parquet"):
|
| 189 |
+
result_df.to_parquet(cache_path, index=False)
|
| 190 |
+
else:
|
| 191 |
+
result_df.to_csv(cache_path, index=False)
|
| 192 |
+
except ImportError:
|
| 193 |
+
# pyarrow not installed, fallback to csv
|
| 194 |
+
csv_path = cache_path.replace(".parquet", ".csv")
|
| 195 |
+
result_df.to_csv(csv_path, index=False)
|
| 196 |
+
logger.info(f"pyarrow not available, saved as CSV: {csv_path}")
|
| 197 |
+
cache_path = csv_path
|
| 198 |
+
logger.info(f"Predictions cached to {cache_path}")
|
| 199 |
+
|
| 200 |
+
logger.info(
|
| 201 |
+
f"Inference complete: {len(result_df)} samples, "
|
| 202 |
+
f"pos={result_df['label'].sum()}, neg={(result_df['label'] == 0).sum()}"
|
| 203 |
+
)
|
| 204 |
+
return result_df
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def predict_on_sequences(
|
| 208 |
+
ckpt_path: str,
|
| 209 |
+
config_path: str,
|
| 210 |
+
mirna_seqs: list[str],
|
| 211 |
+
target_seqs: list[str],
|
| 212 |
+
batch_size: int = 256,
|
| 213 |
+
device: str = "cuda",
|
| 214 |
+
_lit_model=None,
|
| 215 |
+
_alphabet=None,
|
| 216 |
+
) -> np.ndarray:
|
| 217 |
+
"""
|
| 218 |
+
Run inference on arbitrary miRNA + target sequence pairs.
|
| 219 |
+
|
| 220 |
+
Used to run our model on external data such as miRBench standard benchmark datasets.
|
| 221 |
+
Sequences are automatically converted to RNA format (T->U).
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
ckpt_path: path to the checkpoint
|
| 225 |
+
config_path: path to the config YAML
|
| 226 |
+
mirna_seqs: list of miRNA sequences (DNA or RNA format accepted)
|
| 227 |
+
target_seqs: list of target sequences (DNA or RNA format, should be 40nt)
|
| 228 |
+
batch_size: inference batch size
|
| 229 |
+
device: inference device
|
| 230 |
+
_lit_model: pre-loaded model (internal use, for caching)
|
| 231 |
+
_alphabet: pre-loaded alphabet (internal use, for caching)
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
numpy array of predicted probabilities, shape (n_samples,)
|
| 235 |
+
"""
|
| 236 |
+
import fm
|
| 237 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 238 |
+
|
| 239 |
+
if _lit_model is not None:
|
| 240 |
+
lit_model = _lit_model
|
| 241 |
+
else:
|
| 242 |
+
logger.info(f"Loading model from {ckpt_path}")
|
| 243 |
+
lit_model, config = load_model_from_checkpoint(ckpt_path, config_path, device)
|
| 244 |
+
|
| 245 |
+
if _alphabet is not None:
|
| 246 |
+
alphabet = _alphabet
|
| 247 |
+
else:
|
| 248 |
+
_, alphabet = fm.pretrained.rna_fm_t12()
|
| 249 |
+
del _
|
| 250 |
+
batch_converter = alphabet.get_batch_converter()
|
| 251 |
+
padding_idx = alphabet.padding_idx
|
| 252 |
+
|
| 253 |
+
def _to_rna(seq: str) -> str:
|
| 254 |
+
return seq.upper().replace("T", "U")
|
| 255 |
+
|
| 256 |
+
all_probs = []
|
| 257 |
+
n_samples = len(mirna_seqs)
|
| 258 |
+
logger.info(f"Running inference on {n_samples} sequences...")
|
| 259 |
+
|
| 260 |
+
with torch.no_grad():
|
| 261 |
+
for i in range(0, n_samples, batch_size):
|
| 262 |
+
batch_mirna = mirna_seqs[i : i + batch_size]
|
| 263 |
+
batch_target = target_seqs[i : i + batch_size]
|
| 264 |
+
|
| 265 |
+
mirna_tokens_list = []
|
| 266 |
+
target_tokens_list = []
|
| 267 |
+
for m_seq, t_seq in zip(batch_mirna, batch_target):
|
| 268 |
+
m_rna = _to_rna(str(m_seq))
|
| 269 |
+
t_rna = _to_rna(str(t_seq))
|
| 270 |
+
_, _, m_tok = batch_converter([("m", m_rna)])
|
| 271 |
+
_, _, t_tok = batch_converter([("t", t_rna)])
|
| 272 |
+
mirna_tokens_list.append(m_tok[0])
|
| 273 |
+
target_tokens_list.append(t_tok[0])
|
| 274 |
+
|
| 275 |
+
mirna_padded = pad_sequence(
|
| 276 |
+
mirna_tokens_list, batch_first=True, padding_value=padding_idx
|
| 277 |
+
)
|
| 278 |
+
target_stacked = torch.stack(target_tokens_list)
|
| 279 |
+
|
| 280 |
+
attn_mask_mirna = (mirna_padded != padding_idx).long()
|
| 281 |
+
attn_mask_target = torch.ones_like(target_stacked, dtype=torch.long)
|
| 282 |
+
|
| 283 |
+
mirna_padded = mirna_padded.to(device)
|
| 284 |
+
target_stacked = target_stacked.to(device)
|
| 285 |
+
attn_mask_mirna = attn_mask_mirna.to(device)
|
| 286 |
+
attn_mask_target = attn_mask_target.to(device)
|
| 287 |
+
|
| 288 |
+
logits = lit_model.model(
|
| 289 |
+
mirna_padded, target_stacked, attn_mask_mirna, attn_mask_target
|
| 290 |
+
)
|
| 291 |
+
probs = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
|
| 292 |
+
all_probs.append(probs)
|
| 293 |
+
|
| 294 |
+
if (i // batch_size + 1) % 100 == 0:
|
| 295 |
+
logger.info(f" Processed {min(i + batch_size, n_samples)} / {n_samples}")
|
| 296 |
+
|
| 297 |
+
return np.concatenate(all_probs)
|
deepmirt/model/__init__.py
ADDED
|
File without changes
|
deepmirt/model/classifier.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# pyright: basic, reportMissingImports=false
|
| 3 |
+
"""
|
| 4 |
+
MLP classifier head (maps sequence representations to binary classification logits).
|
| 5 |
+
|
| 6 |
+
Architecture diagram:
|
| 7 |
+
|
| 8 |
+
pooled_feature (B, 640)
|
| 9 |
+
|
|
| 10 |
+
v
|
| 11 |
+
Linear(640 -> 256)
|
| 12 |
+
|
|
| 13 |
+
v
|
| 14 |
+
BatchNorm + ReLU + Dropout(0.3)
|
| 15 |
+
|
|
| 16 |
+
v
|
| 17 |
+
Linear(256 -> 64) + ReLU + Dropout(0.2)
|
| 18 |
+
|
|
| 19 |
+
v
|
| 20 |
+
Linear(64 -> 1)
|
| 21 |
+
|
|
| 22 |
+
v
|
| 23 |
+
logits (B, 1)
|
| 24 |
+
|
| 25 |
+
Note:
|
| 26 |
+
- The output is logits (raw scores); do not apply sigmoid inside the model.
|
| 27 |
+
- During training, use BCEWithLogitsLoss which applies sigmoid internally for numerical stability.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
from collections.abc import Sequence
|
| 33 |
+
|
| 34 |
+
from torch import Tensor, nn
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MLPClassifier(nn.Module):
|
| 38 |
+
"""MLP head for binary classification, outputting a single logit."""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
input_dim: int = 640,
|
| 43 |
+
hidden_dims: Sequence[int] | None = None,
|
| 44 |
+
dropout: float = 0.3,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
dims = list(hidden_dims) if hidden_dims is not None else [256, 64]
|
| 48 |
+
if len(dims) != 2:
|
| 49 |
+
raise ValueError("hidden_dims must contain exactly two elements, e.g. [256, 64].")
|
| 50 |
+
|
| 51 |
+
hidden1, hidden2 = int(dims[0]), int(dims[1])
|
| 52 |
+
in_dim = int(input_dim)
|
| 53 |
+
|
| 54 |
+
# Design decision: [256, 64] balances expressiveness and overfitting risk,
|
| 55 |
+
# suitable for small-to-medium scale biological data.
|
| 56 |
+
# Design decision: first layer uses BatchNorm + Dropout; second layer retains
|
| 57 |
+
# a smaller Dropout for lightweight regularization.
|
| 58 |
+
self.layers = nn.Sequential(
|
| 59 |
+
nn.Linear(in_dim, hidden1),
|
| 60 |
+
nn.BatchNorm1d(hidden1),
|
| 61 |
+
nn.ReLU(),
|
| 62 |
+
nn.Dropout(dropout),
|
| 63 |
+
nn.Linear(hidden1, hidden2),
|
| 64 |
+
nn.ReLU(),
|
| 65 |
+
nn.Dropout(0.2),
|
| 66 |
+
nn.Linear(hidden2, 1),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 70 |
+
"""
|
| 71 |
+
Args:
|
| 72 |
+
x: Pooled sequence representation, shape `(batch, input_dim)`.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Logits, shape `(batch, 1)`.
|
| 76 |
+
"""
|
| 77 |
+
return self.layers(x)
|
deepmirt/model/cross_attention.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# pyright: basic, reportMissingImports=false
|
| 3 |
+
"""
|
| 4 |
+
Cross-Attention interaction module.
|
| 5 |
+
|
| 6 |
+
Data flow diagram (target as Query, miRNA as Key/Value)::
|
| 7 |
+
|
| 8 |
+
target_emb (B, T, D) -------------------------------> Q
|
| 9 |
+
|
|
| 10 |
+
| Multi-Head Cross Attention
|
| 11 |
+
| (batch_first=True)
|
| 12 |
+
|
|
| 13 |
+
miRNA_emb (B, M, D) ---> K, V -------------------->
|
| 14 |
+
|
| 15 |
+
Output: context_target (B, T, D)
|
| 16 |
+
|
| 17 |
+
Why target=Q and miRNA=K/V:
|
| 18 |
+
- Our task is to determine whether a target is regulated by a given miRNA.
|
| 19 |
+
- Having each target position query miRNA information aligns with the semantics
|
| 20 |
+
of locating potential binding sites on the target.
|
| 21 |
+
|
| 22 |
+
Mask convention:
|
| 23 |
+
- key_padding_mask=True indicates a padding position that should be ignored.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from torch import Tensor, nn
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class CrossAttentionBlock(nn.Module):
|
| 33 |
+
"""Interaction module composed of stacked Cross-Attention + FFN layers."""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
embed_dim: int = 640,
|
| 38 |
+
num_heads: int = 8,
|
| 39 |
+
dropout: float = 0.1,
|
| 40 |
+
num_layers: int = 2,
|
| 41 |
+
) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.embed_dim = int(embed_dim)
|
| 44 |
+
self.num_heads = int(num_heads)
|
| 45 |
+
self.num_layers = int(num_layers)
|
| 46 |
+
|
| 47 |
+
self.layers = nn.ModuleList()
|
| 48 |
+
for _ in range(self.num_layers):
|
| 49 |
+
layer = nn.ModuleDict(
|
| 50 |
+
{
|
| 51 |
+
"cross_attn": nn.MultiheadAttention(
|
| 52 |
+
embed_dim=self.embed_dim,
|
| 53 |
+
num_heads=self.num_heads,
|
| 54 |
+
dropout=dropout,
|
| 55 |
+
batch_first=True,
|
| 56 |
+
),
|
| 57 |
+
"dropout_attn": nn.Dropout(dropout),
|
| 58 |
+
"norm1": nn.LayerNorm(self.embed_dim),
|
| 59 |
+
"ffn": nn.Sequential(
|
| 60 |
+
nn.Linear(self.embed_dim, self.embed_dim * 4),
|
| 61 |
+
nn.ReLU(),
|
| 62 |
+
nn.Dropout(dropout),
|
| 63 |
+
nn.Linear(self.embed_dim * 4, self.embed_dim),
|
| 64 |
+
),
|
| 65 |
+
"norm2": nn.LayerNorm(self.embed_dim),
|
| 66 |
+
}
|
| 67 |
+
)
|
| 68 |
+
self.layers.append(layer)
|
| 69 |
+
|
| 70 |
+
# Design decision: 2 layers by default is a lightweight yet effective trade-off;
|
| 71 |
+
# establish a trainable baseline first, then deepen based on data scale.
|
| 72 |
+
# Design decision: 8 attention heads by default improves interaction modeling across
|
| 73 |
+
# different subspaces while keeping GPU memory overhead manageable.
|
| 74 |
+
|
| 75 |
+
def forward(
|
| 76 |
+
self,
|
| 77 |
+
query: Tensor,
|
| 78 |
+
key_value: Tensor,
|
| 79 |
+
key_padding_mask: Tensor | None = None,
|
| 80 |
+
) -> Tensor:
|
| 81 |
+
"""
|
| 82 |
+
Args:
|
| 83 |
+
query: Target representation, shape `(batch, target_len, embed_dim)`.
|
| 84 |
+
key_value: miRNA representation, shape `(batch, mirna_len, embed_dim)`.
|
| 85 |
+
key_padding_mask: miRNA padding mask, shape `(batch, mirna_len)`,
|
| 86 |
+
where True indicates positions to ignore.
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Updated target representation, shape `(batch, target_len, embed_dim)`.
|
| 90 |
+
"""
|
| 91 |
+
hidden = query
|
| 92 |
+
attn_mask = key_padding_mask
|
| 93 |
+
if attn_mask is not None and attn_mask.dtype is not torch.bool:
|
| 94 |
+
attn_mask = attn_mask.to(dtype=torch.bool)
|
| 95 |
+
|
| 96 |
+
for layer in self.layers:
|
| 97 |
+
# Step 1: Cross-Attention (target queries miRNA)
|
| 98 |
+
attn_out, _ = layer["cross_attn"](
|
| 99 |
+
query=hidden,
|
| 100 |
+
key=key_value,
|
| 101 |
+
value=key_value,
|
| 102 |
+
key_padding_mask=attn_mask,
|
| 103 |
+
need_weights=False,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# Step 2: Residual + LayerNorm to stabilize deep training and mitigate vanishing gradients
|
| 107 |
+
hidden = layer["norm1"](hidden + layer["dropout_attn"](attn_out))
|
| 108 |
+
|
| 109 |
+
# Step 3: Feed-forward network refines channel-wise features
|
| 110 |
+
ffn_out = layer["ffn"](hidden)
|
| 111 |
+
|
| 112 |
+
# Step 4: Residual + LayerNorm
|
| 113 |
+
hidden = layer["norm2"](hidden + ffn_out)
|
| 114 |
+
|
| 115 |
+
return hidden
|
deepmirt/model/mirna_target_model.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# pyright: basic, reportMissingImports=false
|
| 3 |
+
"""
|
| 4 |
+
Full miRNA-target model: shared RNA-FM encoder + Cross-Attention + MLP classifier head.
|
| 5 |
+
|
| 6 |
+
Complete data flow (with tensor shapes):
|
| 7 |
+
|
| 8 |
+
miRNA tokens (B, M_tok) ---> [RNA-FM Encoder] ---> miRNA_emb (B, M, D) ---β
|
| 9 |
+
|
|
| 10 |
+
v
|
| 11 |
+
target tokens (B, T_tok) ---> [RNA-FM Encoder] ---> target_emb (B, T, D) --> [Cross-Attention]
|
| 12 |
+
|
|
| 13 |
+
v
|
| 14 |
+
cross_out (B, T, D)
|
| 15 |
+
|
|
| 16 |
+
v
|
| 17 |
+
masked mean pool
|
| 18 |
+
|
|
| 19 |
+
v
|
| 20 |
+
(B, D)
|
| 21 |
+
|
|
| 22 |
+
v
|
| 23 |
+
[MLP Head]
|
| 24 |
+
|
|
| 25 |
+
v
|
| 26 |
+
logits
|
| 27 |
+
(B, 1)
|
| 28 |
+
|
| 29 |
+
Where D is automatically inferred from RNA-FM (typically 640) to avoid hard-coding.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
from collections.abc import Sequence
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor, nn
|
| 38 |
+
|
| 39 |
+
from .classifier import MLPClassifier
|
| 40 |
+
from .cross_attention import CrossAttentionBlock
|
| 41 |
+
from .rnafm_encoder import RNAFMEncoder
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MiRNATargetModel(nn.Module):
|
| 45 |
+
"""End-to-end model for miRNA-target binary classification."""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
freeze_backbone: bool = True,
|
| 50 |
+
cross_attn_heads: int = 8,
|
| 51 |
+
cross_attn_layers: int = 2,
|
| 52 |
+
classifier_hidden: Sequence[int] | None = None,
|
| 53 |
+
dropout: float = 0.3,
|
| 54 |
+
) -> None:
|
| 55 |
+
super().__init__()
|
| 56 |
+
hidden_dims = list(classifier_hidden) if classifier_hidden is not None else [256, 64]
|
| 57 |
+
|
| 58 |
+
self.encoder = RNAFMEncoder(freeze_backbone=freeze_backbone)
|
| 59 |
+
embed_dim = self.encoder.embed_dim
|
| 60 |
+
|
| 61 |
+
# Design decision: the interaction layer uses a smaller dropout (~1/3 of main dropout)
|
| 62 |
+
# to preserve attention signals while still providing basic regularization.
|
| 63 |
+
self.cross_attention = CrossAttentionBlock(
|
| 64 |
+
embed_dim=embed_dim,
|
| 65 |
+
num_heads=cross_attn_heads,
|
| 66 |
+
dropout=dropout * 0.33,
|
| 67 |
+
num_layers=cross_attn_layers,
|
| 68 |
+
)
|
| 69 |
+
self.classifier = MLPClassifier(
|
| 70 |
+
input_dim=embed_dim,
|
| 71 |
+
hidden_dims=hidden_dims,
|
| 72 |
+
dropout=dropout,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def forward(
|
| 76 |
+
self,
|
| 77 |
+
mirna_tokens: Tensor,
|
| 78 |
+
target_tokens: Tensor,
|
| 79 |
+
attention_mask_mirna: Tensor | None = None,
|
| 80 |
+
attention_mask_target: Tensor | None = None,
|
| 81 |
+
) -> Tensor:
|
| 82 |
+
"""
|
| 83 |
+
Forward pass (step by step):
|
| 84 |
+
1) miRNA encoding: `(B, M_tok)` -> `(B, M, D)`
|
| 85 |
+
2) target encoding: `(B, T_tok)` -> `(B, T, D)`
|
| 86 |
+
3) Build key_padding_mask: attention_mask(1=real, 0=padding) -> (==0)
|
| 87 |
+
4) Cross-Attention: target(Q) queries miRNA(K/V) -> `(B, T, D)`
|
| 88 |
+
5) Masked mean pooling over target sequence -> `(B, D)`
|
| 89 |
+
6) Classifier head outputs logits -> `(B, 1)`
|
| 90 |
+
"""
|
| 91 |
+
# Step 1: Shared encoder processes miRNA (shared weights)
|
| 92 |
+
mirna_emb = self.encoder(mirna_tokens)
|
| 93 |
+
|
| 94 |
+
# Step 2: Same encoder processes target to ensure consistent representation space
|
| 95 |
+
target_emb = self.encoder(target_tokens)
|
| 96 |
+
|
| 97 |
+
# Step 3: PyTorch MHA key_padding_mask convention: True=ignore.
|
| 98 |
+
key_padding_mask = None
|
| 99 |
+
if attention_mask_mirna is not None:
|
| 100 |
+
key_padding_mask = attention_mask_mirna == 0
|
| 101 |
+
|
| 102 |
+
# Step 4: target as Query, miRNA as Key/Value.
|
| 103 |
+
cross_out = self.cross_attention(
|
| 104 |
+
query=target_emb,
|
| 105 |
+
key_value=mirna_emb,
|
| 106 |
+
key_padding_mask=key_padding_mask,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Step 5: Masked mean pooling over target sequence to obtain a fixed-length representation.
|
| 110 |
+
if attention_mask_target is None:
|
| 111 |
+
pooling_mask = torch.ones(
|
| 112 |
+
cross_out.size(0),
|
| 113 |
+
cross_out.size(1),
|
| 114 |
+
1,
|
| 115 |
+
device=cross_out.device,
|
| 116 |
+
dtype=cross_out.dtype,
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
pooling_mask = attention_mask_target.to(dtype=cross_out.dtype).unsqueeze(-1)
|
| 120 |
+
|
| 121 |
+
summed = (cross_out * pooling_mask).sum(dim=1)
|
| 122 |
+
denom = pooling_mask.sum(dim=1).clamp_min(1e-6)
|
| 123 |
+
pooled = summed / denom
|
| 124 |
+
|
| 125 |
+
# Step 6: Output raw logits without applying sigmoid.
|
| 126 |
+
logits = self.classifier(pooled)
|
| 127 |
+
return logits
|
deepmirt/model/rnafm_encoder.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# pyright: basic, reportMissingImports=false
|
| 3 |
+
"""
|
| 4 |
+
RNA-FM encoder wrapper (Shared Encoder).
|
| 5 |
+
|
| 6 |
+
Architecture diagram (single-path encoding):
|
| 7 |
+
|
| 8 |
+
Input tokens (B, L)
|
| 9 |
+
|
|
| 10 |
+
v
|
| 11 |
+
[RNA-FM: 12-layer Transformer]
|
| 12 |
+
|
|
| 13 |
+
v
|
| 14 |
+
representations[12] (B, L, D)
|
| 15 |
+
D is typically 640
|
| 16 |
+
|
| 17 |
+
Training strategy diagram (freeze / staged unfreezing):
|
| 18 |
+
|
| 19 |
+
Frozen phase: [L1][L2][L3]...[L12] all requires_grad=False
|
| 20 |
+
Unfrozen phase: [L1]...[L9][L10][L11][L12]
|
| 21 |
+
^^^^^^^^
|
| 22 |
+
only unfreeze top N layers (e.g., N=3)
|
| 23 |
+
|
| 24 |
+
Notes:
|
| 25 |
+
- Both miRNA and target are RNA sequences, so sharing a single RNA-FM encoder is the most natural approach.
|
| 26 |
+
- `repr_layers=[12]` extracts the 12th (final) layer output as the contextualized representation.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
from collections.abc import Sequence
|
| 32 |
+
|
| 33 |
+
import fm
|
| 34 |
+
from torch import Tensor, nn
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RNAFMEncoder(nn.Module):
|
| 38 |
+
"""Lightweight wrapper around RNA-FM providing forward encoding, freezing, and staged unfreezing."""
|
| 39 |
+
|
| 40 |
+
def __init__(self, freeze_backbone: bool = True) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.model, self.alphabet = fm.pretrained.rna_fm_t12()
|
| 43 |
+
self.num_layers = len(self.model.layers)
|
| 44 |
+
self.embed_dim = self._infer_embed_dim(default=640)
|
| 45 |
+
|
| 46 |
+
# Design decision: freeze backbone by default to first stabilize training of the
|
| 47 |
+
# upper interaction module and classifier head, avoiding catastrophic forgetting
|
| 48 |
+
# from full fine-tuning on small datasets.
|
| 49 |
+
if freeze_backbone:
|
| 50 |
+
self.freeze()
|
| 51 |
+
|
| 52 |
+
def _infer_embed_dim(self, default: int = 640) -> int:
|
| 53 |
+
"""Try to infer the embedding dimension from the RNA-FM model; fall back to default on failure."""
|
| 54 |
+
model_embed_dim = getattr(self.model, "embed_dim", None)
|
| 55 |
+
if model_embed_dim is not None:
|
| 56 |
+
return int(model_embed_dim)
|
| 57 |
+
|
| 58 |
+
model_args = getattr(self.model, "args", None)
|
| 59 |
+
if model_args is not None and hasattr(model_args, "embed_dim"):
|
| 60 |
+
return int(model_args.embed_dim)
|
| 61 |
+
|
| 62 |
+
embed_tokens = getattr(self.model, "embed_tokens", None)
|
| 63 |
+
if embed_tokens is not None and hasattr(embed_tokens, "embedding_dim"):
|
| 64 |
+
return int(embed_tokens.embedding_dim)
|
| 65 |
+
|
| 66 |
+
return int(default)
|
| 67 |
+
|
| 68 |
+
def forward(self, tokens: Tensor, repr_layers: Sequence[int] | None = None) -> Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Encode an RNA token sequence.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
tokens: Token tensor of shape `(batch, seq_len)`.
|
| 74 |
+
repr_layers: List of layer indices to extract. Defaults to `[12]` (final layer).
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Contextualized representations of shape `(batch, seq_len, embed_dim)`.
|
| 78 |
+
"""
|
| 79 |
+
if repr_layers is None:
|
| 80 |
+
# Design decision: use the final layer representation by default (most semantically
|
| 81 |
+
# complete), consistent with common pre-trained model usage.
|
| 82 |
+
repr_layers = [self.num_layers]
|
| 83 |
+
|
| 84 |
+
layer_ids = list(repr_layers)
|
| 85 |
+
if not layer_ids:
|
| 86 |
+
raise ValueError("repr_layers must not be empty; provide at least one layer index.")
|
| 87 |
+
|
| 88 |
+
outputs = self.model(tokens, repr_layers=layer_ids)
|
| 89 |
+
# Note: typically repr_layers=[12] is passed, so this retrieves representations[12].
|
| 90 |
+
final_layer_id = max(layer_ids)
|
| 91 |
+
return outputs["representations"][final_layer_id]
|
| 92 |
+
|
| 93 |
+
def freeze(self) -> None:
|
| 94 |
+
"""Freeze all RNA-FM backbone parameters (requires_grad=False)."""
|
| 95 |
+
for param in self.model.parameters():
|
| 96 |
+
param.requires_grad = False
|
| 97 |
+
|
| 98 |
+
def unfreeze(self, num_layers: int = 3) -> None:
|
| 99 |
+
"""
|
| 100 |
+
Unfreeze only the per-layer parameters of the top N Transformer layers.
|
| 101 |
+
|
| 102 |
+
Example: when `num_layers=3`, unfreezes layer[9], layer[10], layer[11].
|
| 103 |
+
|
| 104 |
+
Note: global LayerNorm (e.g., emb_layer_norm_after) is NOT unfrozen,
|
| 105 |
+
because unfreezing it would shift the output distribution of all layers at once,
|
| 106 |
+
leading to training instability.
|
| 107 |
+
"""
|
| 108 |
+
# Design decision: always freeze all first, then selectively unfreeze, ensuring the
|
| 109 |
+
# set of trainable parameters is controllable and reproducible.
|
| 110 |
+
self.freeze()
|
| 111 |
+
|
| 112 |
+
n = max(0, min(int(num_layers), self.num_layers))
|
| 113 |
+
if n > 0:
|
| 114 |
+
start = self.num_layers - n
|
| 115 |
+
for layer in self.model.layers[start:]:
|
| 116 |
+
for param in layer.parameters():
|
| 117 |
+
param.requires_grad = True
|
deepmirt/predict.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Public prediction API for DeepMiRT.
|
| 4 |
+
|
| 5 |
+
Provides simple interfaces for miRNA-target interaction prediction:
|
| 6 |
+
- predict(): Python API for sequence pairs
|
| 7 |
+
- predict_from_csv(): Batch prediction from CSV files
|
| 8 |
+
- cli_main(): Command-line entry point
|
| 9 |
+
|
| 10 |
+
Model weights are automatically downloaded from Hugging Face Hub on first use.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import logging
|
| 17 |
+
import re
|
| 18 |
+
import sys
|
| 19 |
+
import warnings
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import pandas as pd
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
# Hugging Face Hub model repository
|
| 28 |
+
HF_REPO_ID = "liuliu2333/deepmirt"
|
| 29 |
+
HF_CKPT_FILENAME = "epoch=27-val_auroc=0.9612.ckpt"
|
| 30 |
+
HF_CONFIG_FILENAME = "config.yaml"
|
| 31 |
+
|
| 32 |
+
# Valid nucleotide characters (before TβU conversion)
|
| 33 |
+
_VALID_BASES = re.compile(r"^[AUGCTaugct]+$")
|
| 34 |
+
|
| 35 |
+
# Module-level model cache (avoids reloading 495 MB on every call)
|
| 36 |
+
_model_cache: dict = {}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _get_model_files() -> tuple[str, str]:
|
| 40 |
+
"""Download model checkpoint and config from Hugging Face Hub (cached locally)."""
|
| 41 |
+
from huggingface_hub import hf_hub_download
|
| 42 |
+
|
| 43 |
+
ckpt_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_CKPT_FILENAME)
|
| 44 |
+
config_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_CONFIG_FILENAME)
|
| 45 |
+
return ckpt_path, config_path
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _get_cached_model(device: str):
|
| 49 |
+
"""Load model and alphabet, caching for subsequent calls."""
|
| 50 |
+
if device not in _model_cache:
|
| 51 |
+
import fm
|
| 52 |
+
|
| 53 |
+
from deepmirt.evaluation.predict import load_model_from_checkpoint
|
| 54 |
+
|
| 55 |
+
ckpt_path, config_path = _get_model_files()
|
| 56 |
+
logger.info("Loading DeepMiRT model (first call, will be cached)...")
|
| 57 |
+
lit_model, config = load_model_from_checkpoint(ckpt_path, config_path, device)
|
| 58 |
+
_, alphabet = fm.pretrained.rna_fm_t12()
|
| 59 |
+
_model_cache[device] = (lit_model, alphabet, ckpt_path, config_path)
|
| 60 |
+
logger.info("Model loaded and cached.")
|
| 61 |
+
|
| 62 |
+
return _model_cache[device]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _validate_sequences(
|
| 66 |
+
mirna_seqs: list[str], target_seqs: list[str]
|
| 67 |
+
) -> tuple[list[str], list[str]]:
|
| 68 |
+
"""Validate and clean input sequences."""
|
| 69 |
+
cleaned_mirna = []
|
| 70 |
+
cleaned_target = []
|
| 71 |
+
|
| 72 |
+
for i, (m, t) in enumerate(zip(mirna_seqs, target_seqs)):
|
| 73 |
+
m = str(m).strip().upper()
|
| 74 |
+
t = str(t).strip().upper()
|
| 75 |
+
|
| 76 |
+
if not m:
|
| 77 |
+
raise ValueError(f"Empty miRNA sequence at index {i}")
|
| 78 |
+
if not t:
|
| 79 |
+
raise ValueError(f"Empty target sequence at index {i}")
|
| 80 |
+
|
| 81 |
+
if not _VALID_BASES.match(m):
|
| 82 |
+
invalid = set(m) - set("AUGCT")
|
| 83 |
+
raise ValueError(
|
| 84 |
+
f"miRNA at index {i} contains invalid characters: {invalid}. "
|
| 85 |
+
f"Only A/U/G/C/T are allowed."
|
| 86 |
+
)
|
| 87 |
+
if not _VALID_BASES.match(t):
|
| 88 |
+
invalid = set(t) - set("AUGCT")
|
| 89 |
+
raise ValueError(
|
| 90 |
+
f"Target at index {i} contains invalid characters: {invalid}. "
|
| 91 |
+
f"Only A/U/G/C/T are allowed."
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
cleaned_mirna.append(m)
|
| 95 |
+
cleaned_target.append(t)
|
| 96 |
+
|
| 97 |
+
# Warn about unusual lengths (non-blocking)
|
| 98 |
+
mirna_lens = [len(s) for s in cleaned_mirna]
|
| 99 |
+
target_lens = [len(s) for s in cleaned_target]
|
| 100 |
+
if any(n < 15 or n > 30 for n in mirna_lens):
|
| 101 |
+
warnings.warn(
|
| 102 |
+
"Some miRNA sequences have unusual length (expected 18-25 nt). "
|
| 103 |
+
"Results may be less reliable.",
|
| 104 |
+
stacklevel=3,
|
| 105 |
+
)
|
| 106 |
+
if any(n != 40 for n in target_lens):
|
| 107 |
+
warnings.warn(
|
| 108 |
+
"Some target sequences are not 40 nt. The model was trained on 40-nt "
|
| 109 |
+
"target fragments. Results may be less reliable for other lengths.",
|
| 110 |
+
stacklevel=3,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
return cleaned_mirna, cleaned_target
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def predict(
|
| 117 |
+
mirna_seqs: list[str],
|
| 118 |
+
target_seqs: list[str],
|
| 119 |
+
device: str = "cpu",
|
| 120 |
+
batch_size: int = 256,
|
| 121 |
+
) -> np.ndarray:
|
| 122 |
+
"""
|
| 123 |
+
Predict miRNA-target interaction probabilities.
|
| 124 |
+
|
| 125 |
+
Automatically downloads model weights from Hugging Face Hub on first call.
|
| 126 |
+
The model is cached in memory for subsequent calls.
|
| 127 |
+
Sequences can be in DNA (T) or RNA (U) format -- conversion is handled internally.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
mirna_seqs: List of miRNA sequences (typically 18-25 nt).
|
| 131 |
+
target_seqs: List of target site sequences (40 nt recommended).
|
| 132 |
+
device: Inference device ("cpu" or "cuda").
|
| 133 |
+
batch_size: Batch size for inference.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
Numpy array of interaction probabilities, shape (n_samples,).
|
| 137 |
+
Values range from 0 (no interaction) to 1 (strong interaction).
|
| 138 |
+
|
| 139 |
+
Example:
|
| 140 |
+
>>> from deepmirt import predict
|
| 141 |
+
>>> probs = predict(
|
| 142 |
+
... mirna_seqs=["UGAGGUAGUAGGUUGUAUAGUU"],
|
| 143 |
+
... target_seqs=["ACUGCAGCAUAUCUACUAUUUGCUACUGUAACCAUUGAUCU"],
|
| 144 |
+
... )
|
| 145 |
+
>>> print(f"Interaction probability: {probs[0]:.4f}")
|
| 146 |
+
"""
|
| 147 |
+
if len(mirna_seqs) != len(target_seqs):
|
| 148 |
+
raise ValueError(
|
| 149 |
+
f"mirna_seqs and target_seqs must have the same length, "
|
| 150 |
+
f"got {len(mirna_seqs)} and {len(target_seqs)}"
|
| 151 |
+
)
|
| 152 |
+
if len(mirna_seqs) == 0:
|
| 153 |
+
return np.array([])
|
| 154 |
+
|
| 155 |
+
mirna_seqs, target_seqs = _validate_sequences(mirna_seqs, target_seqs)
|
| 156 |
+
|
| 157 |
+
from deepmirt.evaluation.predict import predict_on_sequences
|
| 158 |
+
|
| 159 |
+
lit_model, alphabet, ckpt_path, config_path = _get_cached_model(device)
|
| 160 |
+
|
| 161 |
+
return predict_on_sequences(
|
| 162 |
+
ckpt_path=ckpt_path,
|
| 163 |
+
config_path=config_path,
|
| 164 |
+
mirna_seqs=mirna_seqs,
|
| 165 |
+
target_seqs=target_seqs,
|
| 166 |
+
batch_size=batch_size,
|
| 167 |
+
device=device,
|
| 168 |
+
_lit_model=lit_model,
|
| 169 |
+
_alphabet=alphabet,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def predict_from_csv(
|
| 174 |
+
csv_path: str,
|
| 175 |
+
output_path: str | None = None,
|
| 176 |
+
device: str = "cpu",
|
| 177 |
+
batch_size: int = 256,
|
| 178 |
+
mirna_col: str = "mirna_seq",
|
| 179 |
+
target_col: str = "target_seq",
|
| 180 |
+
) -> pd.DataFrame:
|
| 181 |
+
"""
|
| 182 |
+
Batch prediction from a CSV file.
|
| 183 |
+
|
| 184 |
+
The CSV must contain columns for miRNA and target sequences.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
csv_path: Path to input CSV file.
|
| 188 |
+
output_path: Path to save results CSV. If None, results are only returned.
|
| 189 |
+
device: Inference device ("cpu" or "cuda").
|
| 190 |
+
batch_size: Batch size for inference.
|
| 191 |
+
mirna_col: Column name for miRNA sequences.
|
| 192 |
+
target_col: Column name for target sequences.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
DataFrame with original columns plus 'probability' and 'prediction'.
|
| 196 |
+
"""
|
| 197 |
+
df = pd.read_csv(csv_path)
|
| 198 |
+
|
| 199 |
+
if mirna_col not in df.columns or target_col not in df.columns:
|
| 200 |
+
raise ValueError(
|
| 201 |
+
f"CSV must contain columns '{mirna_col}' and '{target_col}'. "
|
| 202 |
+
f"Found columns: {list(df.columns)}"
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
mirna_seqs = df[mirna_col].astype(str).tolist()
|
| 206 |
+
target_seqs = df[target_col].astype(str).tolist()
|
| 207 |
+
|
| 208 |
+
probs = predict(mirna_seqs, target_seqs, device=device, batch_size=batch_size)
|
| 209 |
+
|
| 210 |
+
df["probability"] = probs
|
| 211 |
+
df["prediction"] = (probs >= 0.5).astype(int)
|
| 212 |
+
|
| 213 |
+
if output_path:
|
| 214 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 215 |
+
df.to_csv(output_path, index=False)
|
| 216 |
+
logger.info(f"Results saved to {output_path}")
|
| 217 |
+
|
| 218 |
+
return df
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def scan_targets(
|
| 222 |
+
mirna_fasta: str | dict[str, str],
|
| 223 |
+
target_fasta: str,
|
| 224 |
+
output_prefix: str | None = None,
|
| 225 |
+
device: str = "cpu",
|
| 226 |
+
batch_size: int = 512,
|
| 227 |
+
prob_threshold: float = 0.5,
|
| 228 |
+
scan_mode: str = "hybrid",
|
| 229 |
+
stride: int = 20,
|
| 230 |
+
top_k: int | None = None,
|
| 231 |
+
) -> list:
|
| 232 |
+
"""
|
| 233 |
+
Scan target sequences for miRNA binding sites genome-wide.
|
| 234 |
+
|
| 235 |
+
Identifies candidate binding positions using seed matching and/or sliding
|
| 236 |
+
windows, then scores each position with the DeepMiRT model.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
mirna_fasta: Path to miRNA FASTA file, or dict of {id: sequence}.
|
| 240 |
+
target_fasta: Path to target FASTA file (e.g. 3'UTRs or transcripts).
|
| 241 |
+
output_prefix: If given, write {prefix}_details.txt, {prefix}_hits.tsv,
|
| 242 |
+
and {prefix}_summary.tsv.
|
| 243 |
+
device: Inference device ("cpu" or "cuda").
|
| 244 |
+
batch_size: Batch size for GPU inference.
|
| 245 |
+
prob_threshold: Minimum probability to report a hit (default 0.5).
|
| 246 |
+
scan_mode: Scanning strategy -- "seed" (fastest), "hybrid" (default),
|
| 247 |
+
or "exhaustive" (slowest, stride-1).
|
| 248 |
+
stride: Window stride for hybrid/exhaustive modes (default 20).
|
| 249 |
+
top_k: If set, keep only the top-K hits per miRNA-target pair.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
List of TargetScanResult objects, one per miRNA-target pair with hits.
|
| 253 |
+
Each result contains a list of ScanHit objects with position, probability,
|
| 254 |
+
seed type, and the 40nt window sequence.
|
| 255 |
+
|
| 256 |
+
Example:
|
| 257 |
+
>>> from deepmirt import scan_targets
|
| 258 |
+
>>> results = scan_targets(
|
| 259 |
+
... mirna_fasta={"let-7": "UGAGGUAGUAGGUUGUAUAGUU"},
|
| 260 |
+
... target_fasta="3utrs.fa",
|
| 261 |
+
... output_prefix="results/scan",
|
| 262 |
+
... device="cuda",
|
| 263 |
+
... )
|
| 264 |
+
>>> for r in results:
|
| 265 |
+
... for hit in r.hits:
|
| 266 |
+
... print(f"{r.target_id} pos={hit.position} prob={hit.probability:.3f}")
|
| 267 |
+
"""
|
| 268 |
+
from deepmirt.scanning.scanner import TargetScanner
|
| 269 |
+
|
| 270 |
+
scanner = TargetScanner(
|
| 271 |
+
device=device,
|
| 272 |
+
batch_size=batch_size,
|
| 273 |
+
prob_threshold=prob_threshold,
|
| 274 |
+
scan_mode=scan_mode,
|
| 275 |
+
stride=stride,
|
| 276 |
+
top_k=top_k,
|
| 277 |
+
)
|
| 278 |
+
return scanner.scan(mirna_fasta, target_fasta, output_prefix)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def cli_main() -> None:
|
| 282 |
+
"""Command-line entry point for deepmirt-predict."""
|
| 283 |
+
parser = argparse.ArgumentParser(
|
| 284 |
+
prog="deepmirt-predict",
|
| 285 |
+
description="DeepMiRT: Predict miRNA-target interactions",
|
| 286 |
+
)
|
| 287 |
+
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
| 288 |
+
|
| 289 |
+
# Single prediction
|
| 290 |
+
single = subparsers.add_parser("single", help="Predict a single miRNA-target pair")
|
| 291 |
+
single.add_argument("--mirna", required=True, help="miRNA sequence")
|
| 292 |
+
single.add_argument("--target", required=True, help="Target sequence (40 nt)")
|
| 293 |
+
single.add_argument("--device", default="cpu", help="Device (cpu or cuda)")
|
| 294 |
+
|
| 295 |
+
# Batch prediction
|
| 296 |
+
batch = subparsers.add_parser("batch", help="Batch prediction from CSV")
|
| 297 |
+
batch.add_argument("--input", required=True, help="Input CSV path")
|
| 298 |
+
batch.add_argument("--output", required=True, help="Output CSV path")
|
| 299 |
+
batch.add_argument("--device", default="cpu", help="Device (cpu or cuda)")
|
| 300 |
+
batch.add_argument("--batch-size", type=int, default=256, help="Batch size")
|
| 301 |
+
batch.add_argument("--mirna-col", default="mirna_seq", help="miRNA column name")
|
| 302 |
+
batch.add_argument("--target-col", default="target_seq", help="Target column name")
|
| 303 |
+
|
| 304 |
+
# Genome-wide scanning
|
| 305 |
+
scan = subparsers.add_parser(
|
| 306 |
+
"scan", help="Scan target sequences for miRNA binding sites"
|
| 307 |
+
)
|
| 308 |
+
scan_input = scan.add_mutually_exclusive_group(required=True)
|
| 309 |
+
scan_input.add_argument("--mirna-fasta", help="miRNA FASTA file")
|
| 310 |
+
scan_input.add_argument("--mirna", help="Single miRNA sequence (use with --mirna-id)")
|
| 311 |
+
scan.add_argument("--mirna-id", default="query_mirna", help="miRNA ID (with --mirna)")
|
| 312 |
+
scan.add_argument("--target-fasta", required=True, help="Target FASTA file")
|
| 313 |
+
scan.add_argument("--output", required=True, help="Output prefix")
|
| 314 |
+
scan.add_argument("--device", default="cpu", help="Device (cpu or cuda)")
|
| 315 |
+
scan.add_argument("--batch-size", type=int, default=512, help="Batch size")
|
| 316 |
+
scan.add_argument("--threshold", type=float, default=0.5, help="Probability threshold")
|
| 317 |
+
scan.add_argument(
|
| 318 |
+
"--scan-mode", default="hybrid", choices=["seed", "hybrid", "exhaustive"],
|
| 319 |
+
help="Scanning mode (default: hybrid)",
|
| 320 |
+
)
|
| 321 |
+
scan.add_argument("--stride", type=int, default=20, help="Window stride (default: 20)")
|
| 322 |
+
scan.add_argument("--top-k", type=int, default=None, help="Keep top-K hits per target")
|
| 323 |
+
|
| 324 |
+
args = parser.parse_args()
|
| 325 |
+
|
| 326 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
| 327 |
+
|
| 328 |
+
if args.command == "single":
|
| 329 |
+
probs = predict([args.mirna], [args.target], device=args.device)
|
| 330 |
+
prob = probs[0]
|
| 331 |
+
label = "INTERACTION" if prob >= 0.5 else "NO INTERACTION"
|
| 332 |
+
print(f"Probability: {prob:.4f}")
|
| 333 |
+
print(f"Prediction: {label}")
|
| 334 |
+
elif args.command == "batch":
|
| 335 |
+
df = predict_from_csv(
|
| 336 |
+
csv_path=args.input,
|
| 337 |
+
output_path=args.output,
|
| 338 |
+
device=args.device,
|
| 339 |
+
batch_size=args.batch_size,
|
| 340 |
+
mirna_col=args.mirna_col,
|
| 341 |
+
target_col=args.target_col,
|
| 342 |
+
)
|
| 343 |
+
print(f"Processed {len(df)} samples. Results saved to {args.output}")
|
| 344 |
+
elif args.command == "scan":
|
| 345 |
+
if args.mirna:
|
| 346 |
+
mirna_input: str | dict[str, str] = {args.mirna_id: args.mirna}
|
| 347 |
+
else:
|
| 348 |
+
mirna_input = args.mirna_fasta
|
| 349 |
+
|
| 350 |
+
results = scan_targets(
|
| 351 |
+
mirna_fasta=mirna_input,
|
| 352 |
+
target_fasta=args.target_fasta,
|
| 353 |
+
output_prefix=args.output,
|
| 354 |
+
device=args.device,
|
| 355 |
+
batch_size=args.batch_size,
|
| 356 |
+
prob_threshold=args.threshold,
|
| 357 |
+
scan_mode=args.scan_mode,
|
| 358 |
+
stride=args.stride,
|
| 359 |
+
top_k=args.top_k,
|
| 360 |
+
)
|
| 361 |
+
total_hits = sum(len(r.hits) for r in results)
|
| 362 |
+
print(
|
| 363 |
+
f"Scan complete: {len(results)} miRNA-target pairs, "
|
| 364 |
+
f"{total_hits} hits above threshold {args.threshold}"
|
| 365 |
+
)
|
| 366 |
+
print(f"Results: {args.output}_details.txt, _hits.tsv, _summary.tsv")
|
| 367 |
+
else:
|
| 368 |
+
parser.print_help()
|
| 369 |
+
sys.exit(1)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
if __name__ == "__main__":
|
| 373 |
+
cli_main()
|
deepmirt/training/__init__.py
ADDED
|
File without changes
|
deepmirt/training/lightning_module.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# pyright: basic, reportMissingImports=false
|
| 3 |
+
"""
|
| 4 |
+
PyTorch Lightning training module for miRNA-target prediction.
|
| 5 |
+
|
| 6 |
+
[Lightning Training Loop Overview -- Full Lifecycle of One Epoch]
|
| 7 |
+
|
| 8 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 9 |
+
β Lifecycle of One Epoch β
|
| 10 |
+
β β
|
| 11 |
+
β on_train_epoch_start() β
|
| 12 |
+
β β β
|
| 13 |
+
β v β
|
| 14 |
+
β ββββββββββββββββββββββββββββββββββββββββββββ β
|
| 15 |
+
β β for batch in train_dataloader: β β
|
| 16 |
+
β β training_step(batch) β β forward + loss β
|
| 17 |
+
β β backward() [automatic] β β backpropagation β
|
| 18 |
+
β β optimizer.step() [automatic] β β update params β
|
| 19 |
+
β ββββββββββββββββββββββββββββββββββββββββββββ β
|
| 20 |
+
β β β
|
| 21 |
+
β v β
|
| 22 |
+
β on_train_epoch_end() β
|
| 23 |
+
β β β
|
| 24 |
+
β v β
|
| 25 |
+
β ββββββββββββββββββββββββββββββββββββββββββββ β
|
| 26 |
+
β β for batch in val_dataloader: β β
|
| 27 |
+
β β validation_step(batch) β β forward only, no β
|
| 28 |
+
β β β param updates β
|
| 29 |
+
β ββββββββββββββββββββββββββββββββββββββββββββ β
|
| 30 |
+
β β β
|
| 31 |
+
β v β
|
| 32 |
+
β on_validation_epoch_end() β
|
| 33 |
+
β β β
|
| 34 |
+
β v β
|
| 35 |
+
β lr_scheduler.step() [automatic] β
|
| 36 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
|
| 38 |
+
Things Lightning handles automatically (no manual code needed):
|
| 39 |
+
- loss.backward()
|
| 40 |
+
- optimizer.zero_grad()
|
| 41 |
+
- optimizer.step()
|
| 42 |
+
- Switching to model.eval() and torch.no_grad() during validation
|
| 43 |
+
- Gradient accumulation (if accumulate_grad_batches is configured)
|
| 44 |
+
- Multi-GPU distributed synchronization (if using DDP)
|
| 45 |
+
|
| 46 |
+
You only need to focus on:
|
| 47 |
+
- training_step(): return the loss
|
| 48 |
+
- validation_step(): compute validation metrics
|
| 49 |
+
- configure_optimizers(): define the optimizer and learning rate scheduler
|
| 50 |
+
|
| 51 |
+
[Key Design Decisions]
|
| 52 |
+
|
| 53 |
+
1. BCEWithLogitsLoss vs BCELoss:
|
| 54 |
+
- BCEWithLogitsLoss = Sigmoid + BCELoss, using the log-sum-exp trick internally
|
| 55 |
+
- Numerical stability: directly computing log(sigmoid(x)) can produce log(0) at
|
| 56 |
+
extreme values. BCEWithLogitsLoss uses the equivalent formula
|
| 57 |
+
max(x,0) - x*y + log(1+exp(-|x|)) to avoid overflow
|
| 58 |
+
- Therefore the model outputs raw logits (no sigmoid); the loss function handles it
|
| 59 |
+
|
| 60 |
+
2. Differential Learning Rate:
|
| 61 |
+
- Backbone (RNA-FM): base_lr x 0.01 -- pretrained weights encode rich RNA knowledge;
|
| 62 |
+
a large learning rate would cause catastrophic forgetting of this knowledge
|
| 63 |
+
- Cross-attention layers: base_lr x 0.1 -- new module but needs stable attention
|
| 64 |
+
pattern learning
|
| 65 |
+
- Classifier head: base_lr x 1.0 -- learning from scratch, needs the highest
|
| 66 |
+
learning rate for fast convergence
|
| 67 |
+
|
| 68 |
+
3. Evaluation Metric Selection:
|
| 69 |
+
- AUROC (Area Under ROC Curve): measures the model's ranking ability, i.e., the
|
| 70 |
+
probability of ranking a positive sample above a negative one. Threshold-independent.
|
| 71 |
+
- AUPRC (Average Precision / PR-AUC): measures the precision-recall tradeoff;
|
| 72 |
+
more sensitive than AUROC on class-imbalanced data (biological data often has
|
| 73 |
+
positive:negative ratios of 1:10+)
|
| 74 |
+
- Accuracy: intuitive but can be misleading on imbalanced data (predicting all
|
| 75 |
+
negatives still yields 90% accuracy)
|
| 76 |
+
- F1: harmonic mean of precision and recall, balancing both
|
| 77 |
+
|
| 78 |
+
4. Logging Strategy -- on_step=False, on_epoch=True:
|
| 79 |
+
- Training loss: fluctuates heavily per step; step-level logging aids debugging
|
| 80 |
+
- Evaluation metrics: require full epoch data to be statistically meaningful,
|
| 81 |
+
hence on_epoch=True
|
| 82 |
+
- prog_bar=True: displays key metrics in the training progress bar for real-time
|
| 83 |
+
monitoring
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
from __future__ import annotations
|
| 87 |
+
|
| 88 |
+
import pytorch_lightning as pl
|
| 89 |
+
import torch
|
| 90 |
+
import torchmetrics
|
| 91 |
+
from torch import nn
|
| 92 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 93 |
+
|
| 94 |
+
from deepmirt.model.mirna_target_model import MiRNATargetModel
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class MiRNATargetLitModule(pl.LightningModule):
|
| 98 |
+
"""
|
| 99 |
+
Lightning training module for miRNA-target binary classification prediction.
|
| 100 |
+
|
| 101 |
+
Responsibilities:
|
| 102 |
+
- Wraps MiRNATargetModel, managing forward pass / loss / metric computation
|
| 103 |
+
- Configures optimizer with differential learning rates and LR scheduler
|
| 104 |
+
- Provides training_step / validation_step / test_step
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
config: Nested dictionary with the following structure:
|
| 108 |
+
{
|
| 109 |
+
'model': {
|
| 110 |
+
'freeze_backbone': bool,
|
| 111 |
+
'cross_attn_heads': int,
|
| 112 |
+
'cross_attn_layers': int,
|
| 113 |
+
'classifier_hidden': list[int],
|
| 114 |
+
'dropout': float,
|
| 115 |
+
},
|
| 116 |
+
'training': {
|
| 117 |
+
'lr': float, # base learning rate (used by classifier head)
|
| 118 |
+
'weight_decay': float, # L2 regularization coefficient
|
| 119 |
+
'scheduler': str, # 'cosine' or 'onecycle'
|
| 120 |
+
'max_epochs': int, # total training epochs (needed by scheduler)
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
def __init__(self, config: dict) -> None:
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
# Save hyperparameters to the checkpoint for restoring the full config on reload
|
| 129 |
+
# Design decision: save_hyperparameters ensures reproducibility -- checkpoint carries the full config
|
| 130 |
+
self.save_hyperparameters(config)
|
| 131 |
+
self.config = config
|
| 132 |
+
|
| 133 |
+
# ββ Extract model parameters from config and instantiate ββ
|
| 134 |
+
model_cfg = config["model"]
|
| 135 |
+
self.model = MiRNATargetModel(
|
| 136 |
+
freeze_backbone=model_cfg.get("freeze_backbone", True),
|
| 137 |
+
cross_attn_heads=model_cfg.get("cross_attn_heads", 8),
|
| 138 |
+
cross_attn_layers=model_cfg.get("cross_attn_layers", 2),
|
| 139 |
+
classifier_hidden=model_cfg.get("classifier_hidden", [256, 64]),
|
| 140 |
+
dropout=model_cfg.get("dropout", 0.3),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# ββ Loss function ββ
|
| 144 |
+
# Design decision: BCEWithLogitsLoss is more numerically stable than sigmoid + BCELoss.
|
| 145 |
+
# Internal formula: loss = max(logit, 0) - logit * label + log(1 + exp(-|logit|))
|
| 146 |
+
# This formula avoids numerical overflow from log(sigmoid(x)) at extreme values of x.
|
| 147 |
+
self.loss_fn = nn.BCEWithLogitsLoss()
|
| 148 |
+
|
| 149 |
+
# ββ Training metrics ββ
|
| 150 |
+
# torchmetrics automatically handles metric aggregation in distributed settings (DDP sync)
|
| 151 |
+
self.train_auroc = torchmetrics.AUROC(task="binary")
|
| 152 |
+
|
| 153 |
+
# ββ Validation metrics ββ
|
| 154 |
+
self.val_auroc = torchmetrics.AUROC(task="binary")
|
| 155 |
+
self.val_auprc = torchmetrics.AveragePrecision(task="binary")
|
| 156 |
+
self.val_acc = torchmetrics.Accuracy(task="binary")
|
| 157 |
+
self.val_f1 = torchmetrics.F1Score(task="binary")
|
| 158 |
+
|
| 159 |
+
# ββ Test metrics (same as validation, but separate instances to avoid state contamination) ββ
|
| 160 |
+
self.test_auroc = torchmetrics.AUROC(task="binary")
|
| 161 |
+
self.test_auprc = torchmetrics.AveragePrecision(task="binary")
|
| 162 |
+
self.test_acc = torchmetrics.Accuracy(task="binary")
|
| 163 |
+
self.test_f1 = torchmetrics.F1Score(task="binary")
|
| 164 |
+
|
| 165 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 166 |
+
# Training step
|
| 167 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
+
|
| 169 |
+
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
| 170 |
+
"""
|
| 171 |
+
Single training step: forward pass -> compute loss -> update metrics.
|
| 172 |
+
|
| 173 |
+
Lightning automatically calls backward() and optimizer.step() on the returned loss.
|
| 174 |
+
There is no need to manually call loss.backward() or optimizer.zero_grad().
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
batch: Dictionary output from the DataModule collate_fn, containing:
|
| 178 |
+
- mirna_tokens: (B, max_mirna_len)
|
| 179 |
+
- target_tokens: (B, 42)
|
| 180 |
+
- labels: (B,) float32
|
| 181 |
+
- attention_mask_mirna: (B, max_mirna_len)
|
| 182 |
+
- attention_mask_target: (B, 42)
|
| 183 |
+
batch_idx: Index of the current batch (automatically passed by Lightning)
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
loss: Scalar tensor; Lightning automatically backpropagates through it
|
| 187 |
+
"""
|
| 188 |
+
# Step 1: Extract inputs from the batch dictionary
|
| 189 |
+
mirna_tokens = batch["mirna_tokens"]
|
| 190 |
+
target_tokens = batch["target_tokens"]
|
| 191 |
+
labels = batch["labels"]
|
| 192 |
+
attention_mask_mirna = batch["attention_mask_mirna"]
|
| 193 |
+
attention_mask_target = batch["attention_mask_target"]
|
| 194 |
+
|
| 195 |
+
# Step 2: Forward pass -> logits shape (B, 1)
|
| 196 |
+
logits = self.model(
|
| 197 |
+
mirna_tokens, target_tokens, attention_mask_mirna, attention_mask_target
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Step 3: Compute loss
|
| 201 |
+
# squeeze(-1) reduces logits from (B, 1) to (B,), aligning with labels (B,)
|
| 202 |
+
loss = self.loss_fn(logits.squeeze(-1), labels)
|
| 203 |
+
|
| 204 |
+
# Step 4: Compute prediction probabilities and update metrics
|
| 205 |
+
# Note: sigmoid is only used for metric computation, not for the loss (BCEWithLogitsLoss includes sigmoid internally)
|
| 206 |
+
probs = torch.sigmoid(logits.squeeze(-1))
|
| 207 |
+
self.train_auroc(probs, labels.long())
|
| 208 |
+
|
| 209 |
+
# Step 5: Logging
|
| 210 |
+
# Design decision: train_loss uses on_step=True to monitor convergence trends,
|
| 211 |
+
# train_auroc uses on_epoch=True because per-step AUROC has little statistical significance.
|
| 212 |
+
self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
|
| 213 |
+
self.log(
|
| 214 |
+
"train_auroc",
|
| 215 |
+
self.train_auroc,
|
| 216 |
+
on_step=False,
|
| 217 |
+
on_epoch=True,
|
| 218 |
+
prog_bar=True,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
return loss
|
| 222 |
+
|
| 223 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 224 |
+
# Validation step
|
| 225 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 226 |
+
|
| 227 |
+
def validation_step(self, batch: dict, batch_idx: int) -> None:
|
| 228 |
+
"""
|
| 229 |
+
Single validation step: forward pass -> compute loss and full metric suite.
|
| 230 |
+
|
| 231 |
+
Lightning automatically handles the following during validation:
|
| 232 |
+
- Switches to model.eval() mode (disables Dropout, uses running mean for BatchNorm)
|
| 233 |
+
- Wraps in torch.no_grad(), skipping gradient computation to save memory
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
batch: Same as training_step
|
| 237 |
+
batch_idx: Index of the current batch
|
| 238 |
+
"""
|
| 239 |
+
mirna_tokens = batch["mirna_tokens"]
|
| 240 |
+
target_tokens = batch["target_tokens"]
|
| 241 |
+
labels = batch["labels"]
|
| 242 |
+
attention_mask_mirna = batch["attention_mask_mirna"]
|
| 243 |
+
attention_mask_target = batch["attention_mask_target"]
|
| 244 |
+
|
| 245 |
+
logits = self.model(
|
| 246 |
+
mirna_tokens, target_tokens, attention_mask_mirna, attention_mask_target
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
loss = self.loss_fn(logits.squeeze(-1), labels)
|
| 250 |
+
probs = torch.sigmoid(logits.squeeze(-1))
|
| 251 |
+
|
| 252 |
+
# Update all validation metrics
|
| 253 |
+
self.val_auroc(probs, labels.long())
|
| 254 |
+
self.val_auprc(probs, labels.long())
|
| 255 |
+
self.val_acc(probs, labels.long())
|
| 256 |
+
self.val_f1(probs, labels.long())
|
| 257 |
+
|
| 258 |
+
# Design decision: all validation metrics use on_epoch=True, as they need full data to be statistically meaningful
|
| 259 |
+
# sync_dist=True automatically aggregates metrics across GPUs in multi-GPU settings
|
| 260 |
+
self.log("val_loss", loss, prog_bar=True, on_epoch=True, sync_dist=True)
|
| 261 |
+
self.log("val_auroc", self.val_auroc, on_epoch=True, prog_bar=True)
|
| 262 |
+
self.log("val_auprc", self.val_auprc, on_epoch=True)
|
| 263 |
+
self.log("val_acc", self.val_acc, on_epoch=True)
|
| 264 |
+
self.log("val_f1", self.val_f1, on_epoch=True)
|
| 265 |
+
|
| 266 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 267 |
+
# Test step
|
| 268 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 269 |
+
|
| 270 |
+
def test_step(self, batch: dict, batch_idx: int) -> None:
|
| 271 |
+
"""
|
| 272 |
+
Single test step: same logic as validation_step, using separate test metric instances.
|
| 273 |
+
|
| 274 |
+
Test metrics are instantiated separately from validation metrics to avoid state
|
| 275 |
+
contamination. For example, val_auroc resets at the end of each validation epoch,
|
| 276 |
+
while test_auroc is only used when trainer.test() is called.
|
| 277 |
+
"""
|
| 278 |
+
mirna_tokens = batch["mirna_tokens"]
|
| 279 |
+
target_tokens = batch["target_tokens"]
|
| 280 |
+
labels = batch["labels"]
|
| 281 |
+
attention_mask_mirna = batch["attention_mask_mirna"]
|
| 282 |
+
attention_mask_target = batch["attention_mask_target"]
|
| 283 |
+
|
| 284 |
+
logits = self.model(
|
| 285 |
+
mirna_tokens, target_tokens, attention_mask_mirna, attention_mask_target
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
loss = self.loss_fn(logits.squeeze(-1), labels)
|
| 289 |
+
probs = torch.sigmoid(logits.squeeze(-1))
|
| 290 |
+
|
| 291 |
+
# Update test metrics
|
| 292 |
+
self.test_auroc(probs, labels.long())
|
| 293 |
+
self.test_auprc(probs, labels.long())
|
| 294 |
+
self.test_acc(probs, labels.long())
|
| 295 |
+
self.test_f1(probs, labels.long())
|
| 296 |
+
|
| 297 |
+
self.log("test_loss", loss, on_epoch=True, sync_dist=True)
|
| 298 |
+
self.log("test_auroc", self.test_auroc, on_epoch=True)
|
| 299 |
+
self.log("test_auprc", self.test_auprc, on_epoch=True)
|
| 300 |
+
self.log("test_acc", self.test_acc, on_epoch=True)
|
| 301 |
+
self.log("test_f1", self.test_f1, on_epoch=True)
|
| 302 |
+
|
| 303 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 304 |
+
# Optimizer and learning rate scheduling
|
| 305 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 306 |
+
|
| 307 |
+
def configure_optimizers(self) -> dict:
|
| 308 |
+
"""
|
| 309 |
+
Configure AdamW optimizer with differential learning rates and cosine annealing scheduler.
|
| 310 |
+
|
| 311 |
+
[Differential Learning Rates -- Why use different learning rates for different modules?]
|
| 312 |
+
|
| 313 |
+
Module Learning Rate Reason
|
| 314 |
+
βββββββββββββ βββββββββββββ ββββββββββββββββββββββββββββββββββ
|
| 315 |
+
RNA-FM backbone base_lrΓ0.01 Pretrained weights contain rich RNA structure/sequence
|
| 316 |
+
knowledge; a large LR would destroy this knowledge
|
| 317 |
+
(catastrophic forgetting)
|
| 318 |
+
Cross-attention base_lrΓ0.1 Newly initialized module, but needs to stably learn
|
| 319 |
+
miRNA-target attention patterns
|
| 320 |
+
Classifier head base_lrΓ1.0 Learns the binary classification decision boundary
|
| 321 |
+
from scratch; needs the highest LR for fast convergence
|
| 322 |
+
|
| 323 |
+
Design decision: The LR ratios [0.01, 0.1, 1.0] follow common transfer learning practice;
|
| 324 |
+
the paper "Universal Language Model Fine-tuning" (Howard & Ruder, 2018)
|
| 325 |
+
calls this "discriminative fine-tuning".
|
| 326 |
+
|
| 327 |
+
[CosineAnnealingLR Scheduler]
|
| 328 |
+
The learning rate decays from its initial value toward 0 following a cosine curve:
|
| 329 |
+
lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T_max))
|
| 330 |
+
Advantage: fast learning early on, fine-grained adjustment later, avoiding instability
|
| 331 |
+
from sudden LR drops.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Dictionary containing the optimizer and lr_scheduler
|
| 335 |
+
"""
|
| 336 |
+
training_cfg = self.config["training"]
|
| 337 |
+
base_lr = training_cfg["lr"]
|
| 338 |
+
weight_decay = training_cfg.get("weight_decay", 1e-5)
|
| 339 |
+
scheduler_type = training_cfg.get("scheduler", "cosine")
|
| 340 |
+
max_epochs = training_cfg.get("max_epochs", 30)
|
| 341 |
+
|
| 342 |
+
# Design decision: 3 parameter groups correspond to the model's 3 semantic modules;
|
| 343 |
+
# learning rates decrease from downstream to upstream (farther from the task = smaller LR).
|
| 344 |
+
param_groups = [
|
| 345 |
+
{
|
| 346 |
+
"params": list(self.model.encoder.parameters()),
|
| 347 |
+
"lr": base_lr * 0.01,
|
| 348 |
+
"name": "backbone",
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"params": list(self.model.cross_attention.parameters()),
|
| 352 |
+
"lr": base_lr * 0.1,
|
| 353 |
+
"name": "cross_attention",
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"params": list(self.model.classifier.parameters()),
|
| 357 |
+
"lr": base_lr,
|
| 358 |
+
"name": "classifier",
|
| 359 |
+
},
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
optimizer = torch.optim.AdamW(param_groups, weight_decay=weight_decay)
|
| 363 |
+
|
| 364 |
+
# Design decision: CosineAnnealingLR is a safe default choice --
|
| 365 |
+
# it does not require knowing total steps (unlike OneCycleLR), and provides smooth decay.
|
| 366 |
+
if scheduler_type == "cosine":
|
| 367 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
|
| 368 |
+
elif scheduler_type == "onecycle":
|
| 369 |
+
# OneCycleLR requires total_steps = steps_per_epoch * max_epochs,
|
| 370 |
+
# but at the configure_optimizers stage the DataLoader has not been created yet,
|
| 371 |
+
# so steps_per_epoch is unavailable. Therefore, fall back to CosineAnnealingLR.
|
| 372 |
+
# If OneCycleLR is needed, it should be configured in train.py via the Trainer's
|
| 373 |
+
# estimated_stepping_batches.
|
| 374 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
|
| 375 |
+
else:
|
| 376 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
|
| 377 |
+
|
| 378 |
+
return {
|
| 379 |
+
"optimizer": optimizer,
|
| 380 |
+
"lr_scheduler": {
|
| 381 |
+
"scheduler": scheduler,
|
| 382 |
+
# Design decision: interval='epoch' adjusts the learning rate once per epoch,
|
| 383 |
+
# which is more stable than 'step' (adjusting after every batch), suitable for small to medium datasets.
|
| 384 |
+
"interval": "epoch",
|
| 385 |
+
},
|
| 386 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
rna-fm
|
| 3 |
+
pytorch-lightning>=2.0
|
| 4 |
+
torchmetrics
|
| 5 |
+
pyyaml
|
| 6 |
+
scikit-learn
|
| 7 |
+
numpy
|
| 8 |
+
pandas
|
| 9 |
+
huggingface-hub
|