# DeepCRISPR / core_engine.py
# (Hugging Face upload metadata: user mk6783336, "Upload 12 files", commit c1d5fee verified)
"""
core_engine.py
==============
DeepCRISPR Mega Model β€” PyTorch Architecture & Feature Extraction.
Contains the exact neural network architecture used to train mega_model_best.pth:
- CFG: Hyperparameters and nucleotide vocabulary
- encode_pair(): Tokenizes sgRNA + off-target pair
- MultiScaleCNN: Multi-kernel 1D convolutions
- PositionalEncoding: Sinusoidal position embeddings
- CRISPRTransformer: Transformer encoder with cross-attention
- BiLSTMEncoder: Bidirectional LSTM
- CRISPRMegaModel: Fusion of all three β†’ 256-dim embeddings
- extract_bio_features(): Hand-crafted biological features
Architected by Mujahid
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple
# ─────────────────────────── CONFIGURATION ──────────────────────────────────
class CFG:
    """Model hyperparameters and nucleotide vocabulary shared by all components."""

    SEQ_LEN = 23        # aligned sgRNA / off-target length in nucleotides
    VOCAB_SIZE = 7      # A C G T N - [PAD]
    EMBED_DIM = 128     # nucleotide embedding size
    # Multi-scale CNN: one branch per (kernel, filter-count) pair below.
    CNN_FILTERS = [128, 256, 256, 512]
    CNN_KERNELS = [3, 5, 7, 9]  # multi-scale kernels
    CNN_DROPOUT = 0.2
    # Transformer encoder hyperparameters.
    TF_HEADS = 8
    TF_LAYERS = 4
    TF_DIM = 256
    TF_FF_DIM = 512
    TF_DROPOUT = 0.1
    # Bidirectional LSTM hyperparameters.
    LSTM_HIDDEN = 128
    LSTM_LAYERS = 2
    # Token ids; 0 is reserved for padding (matches Embedding padding_idx=0).
    NT = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5, '-': 6, '[PAD]': 0}
    # Classification of aligned base pairs. Lookups elsewhere fall back to
    # 'match' for any pair not listed here (including pairs involving 'N').
    MISMATCH_MATRIX = {
        ('A', 'G'): 'transition', ('G', 'A'): 'transition',
        ('C', 'T'): 'transition', ('T', 'C'): 'transition',
        ('A', 'C'): 'transversion', ('C', 'A'): 'transversion',
        ('A', 'T'): 'transversion', ('T', 'A'): 'transversion',
        ('G', 'C'): 'transversion', ('C', 'G'): 'transversion',
        ('G', 'T'): 'transversion', ('T', 'G'): 'transversion',
        ('-', 'A'): 'dna_bulge', ('-', 'C'): 'dna_bulge',
        ('-', 'G'): 'dna_bulge', ('-', 'T'): 'dna_bulge',
        ('A', '-'): 'rna_bulge', ('C', '-'): 'rna_bulge',
        ('G', '-'): 'rna_bulge', ('T', '-'): 'rna_bulge',
    }
cfg = CFG()  # module-level singleton; every component reads hyperparameters from here
# ─────────────────────────── TOKENIZATION ───────────────────────────────────
def tokenize(seq: str, max_len: int = cfg.SEQ_LEN) -> List[int]:
    """Map a nucleotide string to a fixed-length list of integer token ids.

    The sequence is upper-cased, truncated to ``max_len`` and right-padded
    with 'N'; any character outside the vocabulary also maps to the 'N' id.
    """
    unknown = cfg.NT['N']
    padded = seq.upper()[:max_len].ljust(max_len, 'N')
    ids: List[int] = []
    for ch in padded:
        ids.append(cfg.NT.get(ch, unknown))
    return ids
def encode_pair(sgrna: str, off: str) -> Tuple[List[int], List[int], List[int]]:
    """
    Encode an sgRNA + off-target pair into three integer lists:
      1. sgRNA tokens
      2. Off-target tokens
      3. Mismatch channel (0=match, 1=transition, 2=transversion, 3=bulge)
    """
    channel_of = {
        'match': 0, 'transition': 1, 'transversion': 2,
        'dna_bulge': 3, 'rna_bulge': 3,
    }
    sg_seq = sgrna.upper()[:cfg.SEQ_LEN].ljust(cfg.SEQ_LEN, 'N')
    off_seq = off.upper()[:cfg.SEQ_LEN].ljust(cfg.SEQ_LEN, 'N')
    mm_channel: List[int] = []
    for a, b in zip(sg_seq, off_seq):
        # Pairs absent from the matrix (e.g. identical bases, or 'N') count as matches.
        kind = cfg.MISMATCH_MATRIX.get((a, b), 'match')
        mm_channel.append(channel_of[kind])
    return tokenize(sg_seq), tokenize(off_seq), mm_channel
# ─────────────────────────── MULTI-SCALE CNN ────────────────────────────────
class MultiScaleCNN(nn.Module):
    """Parallel 1D convolution branches over the fused sgRNA/off-target encoding.

    Each position carries sgRNA embedding ++ off-target embedding ++ one-hot
    mismatch channel (2*EMBED_DIM + 4 channels). One branch per kernel size
    (conv-BN-GELU twice, then dropout); each branch is pooled over positions
    with mean + max and the pooled vectors are concatenated.
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.EMBED_DIM, padding_idx=0)
        in_channels = cfg.EMBED_DIM * 2 + 4
        branch_list = []
        for kernel, width in zip(cfg.CNN_KERNELS, cfg.CNN_FILTERS):
            branch_list.append(nn.Sequential(
                nn.Conv1d(in_channels, width, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(width),
                nn.GELU(),
                nn.Conv1d(width, width, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(width),
                nn.GELU(),
                nn.Dropout(cfg.CNN_DROPOUT),
            ))
        self.branches = nn.ModuleList(branch_list)
        self.out_dim = sum(cfg.CNN_FILTERS)  # 128 + 256 + 256 + 512 = 1152

    def forward(self, sg, off, mm):
        """sg/off: (B, SEQ_LEN) token ids; mm: (B, SEQ_LEN) mismatch-channel ids."""
        mm_onehot = F.one_hot(mm, num_classes=4).float()              # (B, L, 4)
        fused = torch.cat([self.embed(sg), self.embed(off), mm_onehot], dim=-1)
        fused = fused.permute(0, 2, 1)                                # (B, C, L)
        pooled = []
        for branch in self.branches:
            feat = branch(fused)                                      # (B, width, L)
            pooled.append(feat.mean(dim=-1) + feat.amax(dim=-1))
        return torch.cat(pooled, dim=-1)                              # (B, 1152)
# ─────────────────────────── POSITIONAL ENCODING ────────────────────────────
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (B, L, d_model) input.

    The table is precomputed for ``max_len`` positions and registered as a
    non-trainable buffer; ``forward`` adds its first ``L`` rows to the input.
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies: 10000^(-2i/d_model) for even dimensions i.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        self.register_buffer('pe', table.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """x: (B, L, d_model) with L <= max_len."""
        return x + self.pe[:, :x.size(1)]
# ─────────────────────────── TRANSFORMER ────────────────────────────────────
class CRISPRTransformer(nn.Module):
    """Shared transformer encoder with sgRNA -> off-target cross-attention.

    Both token streams pass through the same TransformerEncoder; the sgRNA
    stream additionally receives a linear projection of the one-hot mismatch
    channel. Cross-attention uses the encoded sgRNA as query against the
    encoded off-target; mean+max pooled features of the cross output and the
    off-target encoding are concatenated (256 * 4 = 1024 dims).
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.TF_DIM, padding_idx=0)
        self.pos_enc = PositionalEncoding(cfg.TF_DIM)
        self.mm_proj = nn.Linear(4, cfg.TF_DIM)
        layer = nn.TransformerEncoderLayer(
            d_model=cfg.TF_DIM,
            nhead=cfg.TF_HEADS,
            dim_feedforward=cfg.TF_FF_DIM,
            dropout=cfg.TF_DROPOUT,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=cfg.TF_LAYERS)
        self.cross_attn = nn.MultiheadAttention(
            cfg.TF_DIM, cfg.TF_HEADS, dropout=cfg.TF_DROPOUT, batch_first=True,
        )
        self.out_dim = cfg.TF_DIM * 4  # two pooled views x two streams = 1024

    def forward(self, sg, off, mm):
        mm_signal = self.mm_proj(F.one_hot(mm, num_classes=4).float())
        sg_enc = self.encoder(self.pos_enc(self.embed(sg)) + mm_signal)
        off_enc = self.encoder(self.pos_enc(self.embed(off)))
        cross, _ = self.cross_attn(sg_enc, off_enc, off_enc)

        def _pool(t):
            # Mean + max over the sequence dimension.
            return torch.cat([t.mean(dim=1), t.amax(dim=1)], dim=-1)

        return torch.cat([_pool(cross), _pool(off_enc)], dim=-1)  # (B, 1024)
# ─────────────────────────── BiLSTM ─────────────────────────────────────────
class BiLSTMEncoder(nn.Module):
    """Two-layer bidirectional LSTM over the fused per-position encoding.

    Output = mean-pooled hidden states ++ concatenated final hidden states
    of the top layer's forward and backward directions (128*2*2 = 512 dims).
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.EMBED_DIM, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=cfg.EMBED_DIM * 2 + 4,
            hidden_size=cfg.LSTM_HIDDEN,
            num_layers=cfg.LSTM_LAYERS,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )
        self.out_dim = cfg.LSTM_HIDDEN * 2 * 2  # 128 * 2 (bidir) * 2 (cat) = 512

    def forward(self, sg, off, mm):
        mm_onehot = F.one_hot(mm, num_classes=4).float()
        fused = torch.cat([self.embed(sg), self.embed(off), mm_onehot], dim=-1)
        states, (hidden, _) = self.lstm(fused)
        # hidden[-2] / hidden[-1]: forward / backward final states of the top layer.
        final = torch.cat([hidden[-2], hidden[-1]], dim=-1)
        return torch.cat([states.mean(dim=1), final], dim=-1)  # (B, 512)
# ─────────────────────────── MEGA MODEL (FUSION) ───────────────────────────
class CRISPRMegaModel(nn.Module):
    """
    Fusion of MultiScaleCNN + CRISPRTransformer + BiLSTMEncoder.

    Concatenated branch features (1152 + 1024 + 512 = 2688) pass through a
    2688 -> 1024 -> 512 -> 256 MLP. The 256-dim embedding feeds an
    off-target logit head and a sigmoid efficiency head; the embedding
    itself is also returned.
    """

    def __init__(self):
        super().__init__()
        self.cnn = MultiScaleCNN()
        self.transformer = CRISPRTransformer()
        self.bilstm = BiLSTMEncoder()
        in_dim = self.cnn.out_dim + self.transformer.out_dim + self.bilstm.out_dim
        layers = []
        for width, p_drop in ((1024, 0.3), (512, 0.2), (256, 0.1)):
            layers += [
                nn.Linear(in_dim, width),
                nn.LayerNorm(width),
                nn.GELU(),
                nn.Dropout(p_drop),
            ]
            in_dim = width
        self.fusion = nn.Sequential(*layers)
        self.off_head = nn.Linear(256, 1)
        self.eff_head = nn.Linear(256, 1)
        self.emb_head = nn.Identity()

    def forward(self, sg, off, mm):
        features = torch.cat(
            [self.cnn(sg, off, mm), self.transformer(sg, off, mm), self.bilstm(sg, off, mm)],
            dim=-1,
        )
        emb = self.fusion(features)  # (B, 256)
        logit = self.off_head(emb).squeeze(-1)
        return {
            'logit': logit,
            'off_prob': torch.sigmoid(logit),
            'eff': torch.sigmoid(self.eff_head(emb)).squeeze(-1),
            'embedding': emb,  # 256-dim
        }
# ─────────────────────────── BIOLOGICAL FEATURES ───────────────────────────
def _pad_seq(seq: str, length: int) -> str:
    """Upper-case *seq*, truncate to *length* and right-pad with 'N'."""
    return seq.upper()[:length].ljust(length, 'N')


def _mm_class(a: str, b: str) -> str:
    """Mismatch class of an aligned pair; pairs not in the matrix count as 'match'."""
    return cfg.MISMATCH_MATRIX.get((a, b), 'match')


def _positional_feats(mm_positions: List[int], seq_len: int) -> Dict:
    """Location and spread statistics of the mismatch positions (sorted ascending)."""
    feats: Dict = {
        'first_mm_pos': mm_positions[0] if mm_positions else -1,
        'last_mm_pos': mm_positions[-1] if mm_positions else -1,
        'mm_span': (mm_positions[-1] - mm_positions[0]) if len(mm_positions) > 1 else 0,
    }
    if len(mm_positions) >= 2:
        gaps = [nxt - cur for cur, nxt in zip(mm_positions, mm_positions[1:])]
        feats['mm_min_gap'] = min(gaps)
        feats['mm_mean_gap'] = float(np.mean(gaps))
        feats['mm_clustered'] = float(min(gaps) <= 2)  # two mismatches within 2 nt
    else:
        # Fewer than two mismatches: sentinel gap of the full sequence length.
        feats['mm_min_gap'] = seq_len
        feats['mm_mean_gap'] = float(seq_len)
        feats['mm_clustered'] = 0.0
    return feats


def extract_bio_features(sgrna: str, offtarget: str) -> Dict:
    """
    Compute hand-crafted biological features for an sgRNA / off-target pair.

    Parameters
    ----------
    sgrna, offtarget : str
        Nucleotide sequences; internally padded/truncated to cfg.SEQ_LEN.

    Returns
    -------
    Dict
        ~50 numeric features (mismatch counts, positional stats, nucleotide
        composition, thermodynamic proxies, per-position mismatch flags)
        used as extra columns for AutoGluon. Key insertion order is stable.
    """
    seq_len = cfg.SEQ_LEN
    sg = _pad_seq(sgrna, seq_len)
    off = _pad_seq(offtarget, seq_len)
    mismatches = [(i, sg[i], off[i]) for i in range(seq_len) if sg[i] != off[i]]
    mm_positions = [pos for pos, _, _ in mismatches]
    # "Seed" here = the PAM-proximal 12 positions of the 23-mer alignment.
    seed_mms = [m for m in mismatches if m[0] >= seq_len - 12]

    feats: Dict = {}
    # ── Mismatch counts ──
    feats['n_mismatches'] = len(mismatches)
    feats['n_transitions'] = sum(1 for _, a, b in mismatches if _mm_class(a, b) == 'transition')
    feats['n_transversions'] = sum(1 for _, a, b in mismatches if _mm_class(a, b) == 'transversion')
    feats['seed_mismatches'] = len(seed_mms)
    feats['seed_transitions'] = sum(1 for _, a, b in seed_mms if _mm_class(a, b) == 'transition')
    feats['seed_transversions'] = sum(1 for _, a, b in seed_mms if _mm_class(a, b) == 'transversion')
    feats['pam_proximal_mm'] = sum(1 for m in mismatches if m[0] >= seq_len - 5)
    # ── Positional features ──
    feats.update(_positional_feats(mm_positions, seq_len))
    # ── Nucleotide composition ──
    for nt in 'ACGT':
        feats[f'sg_{nt}_frac'] = sg.count(nt) / seq_len
        feats[f'off_{nt}_frac'] = off.count(nt) / seq_len
    feats['sg_gc'] = (sg.count('G') + sg.count('C')) / seq_len
    feats['off_gc'] = (off.count('G') + off.count('C')) / seq_len
    # ── Thermodynamic proxy & penalties ──
    at = sg.count('A') + sg.count('T')
    gc = sg.count('G') + sg.count('C')
    feats['sg_tm_proxy'] = 2 * at + 4 * gc  # simple 2*AT + 4*GC melting proxy
    # PAM-proximal mismatches weigh more (weight grows linearly with position).
    feats['weighted_mm_penalty'] = sum(1.0 + p / seq_len for p in mm_positions)
    feats['edit_dist_norm'] = len(mismatches) / seq_len
    # NOTE(review): this checks the sgRNA's last two bases for 'GG'; confirm the
    # off-target (genomic) side is not the intended PAM source.
    feats['pam_is_ngg'] = float(sg[-2:] == 'GG' if len(sg) >= 2 else False)
    # ── Per-position mismatch flags ──
    for i in range(seq_len):
        feats[f'mm_pos_{i}'] = float(sg[i] != off[i])
    return feats