"""
core_engine.py
==============
DeepCRISPR Mega Model — PyTorch Architecture & Feature Extraction.

Contains the exact neural network architecture used to train mega_model_best.pth:
  - CFG: Hyperparameters and nucleotide vocabulary
  - encode_pair(): Tokenizes sgRNA + off-target pair
  - MultiScaleCNN: Multi-kernel 1D convolutions
  - PositionalEncoding: Sinusoidal position embeddings
  - CRISPRTransformer: Transformer encoder with cross-attention
  - BiLSTMEncoder: Bidirectional LSTM
  - CRISPRMegaModel: Fusion of all three → 256-dim embeddings
  - extract_bio_features(): Hand-crafted biological features

Architected by Mujahid
"""

from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─────────────────────────── CONFIGURATION ──────────────────────────────────

class CFG:
    """Hyperparameters, nucleotide vocabulary and mismatch classification table."""

    SEQ_LEN = 23
    VOCAB_SIZE = 7                    # A C G T N - [PAD]
    EMBED_DIM = 128                   # nucleotide embedding size
    CNN_FILTERS = [128, 256, 256, 512]
    CNN_KERNELS = [3, 5, 7, 9]        # multi-scale kernels
    CNN_DROPOUT = 0.2
    TF_HEADS = 8
    TF_LAYERS = 4
    TF_DIM = 256
    TF_FF_DIM = 512
    TF_DROPOUT = 0.1
    LSTM_HIDDEN = 128
    LSTM_LAYERS = 2

    # Character → token id. Id 0 is reserved for padding (padding_idx of the embeddings).
    NT = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5, '-': 6, '[PAD]': 0}

    # (sgRNA base, off-target base) → mismatch class. Pairs that are absent
    # (identical bases, or anything involving 'N') are treated as 'match' by callers.
    MISMATCH_MATRIX = {
        ('A', 'G'): 'transition',   ('G', 'A'): 'transition',
        ('C', 'T'): 'transition',   ('T', 'C'): 'transition',
        ('A', 'C'): 'transversion', ('C', 'A'): 'transversion',
        ('A', 'T'): 'transversion', ('T', 'A'): 'transversion',
        ('G', 'C'): 'transversion', ('C', 'G'): 'transversion',
        ('G', 'T'): 'transversion', ('T', 'G'): 'transversion',
        ('-', 'A'): 'dna_bulge',    ('-', 'C'): 'dna_bulge',
        ('-', 'G'): 'dna_bulge',    ('-', 'T'): 'dna_bulge',
        ('A', '-'): 'rna_bulge',    ('C', '-'): 'rna_bulge',
        ('G', '-'): 'rna_bulge',    ('T', '-'): 'rna_bulge',
    }


cfg = CFG()

# Mismatch class → integer channel value fed to the networks (both bulge kinds share 3).
# Built once at import time instead of on every encode_pair() call.
_MISMATCH_CHANNEL = {
    'match': 0,
    'transition': 1,
    'transversion': 2,
    'dna_bulge': 3,
    'rna_bulge': 3,
}


# ─────────────────────────── TOKENIZATION ───────────────────────────────────

def tokenize(seq: str, max_len: int = cfg.SEQ_LEN) -> List[int]:
    """Convert a nucleotide sequence to an integer token list.

    The sequence is upper-cased, truncated to ``max_len`` and right-padded
    with 'N'. Characters missing from ``cfg.NT`` map to the 'N' token.
    """
    seq = seq.upper()[:max_len].ljust(max_len, 'N')
    unknown = cfg.NT['N']
    return [cfg.NT.get(c, unknown) for c in seq]


def encode_pair(sgrna: str, off: str) -> Tuple[List[int], List[int], List[int]]:
    """
    Encode an sgRNA + off-target pair into three integer lists:
      1. sgRNA tokens
      2. Off-target tokens
      3. Mismatch channel (0=match, 1=transition, 2=transversion, 3=bulge)

    Both sequences are upper-cased, truncated to ``cfg.SEQ_LEN`` and
    right-padded with 'N', so all three lists have length ``cfg.SEQ_LEN``.
    """
    sg = sgrna.upper()[:cfg.SEQ_LEN].ljust(cfg.SEQ_LEN, 'N')
    of = off.upper()[:cfg.SEQ_LEN].ljust(cfg.SEQ_LEN, 'N')
    mm_ch = [
        _MISMATCH_CHANNEL[cfg.MISMATCH_MATRIX.get((a, b), 'match')]
        for a, b in zip(sg, of)
    ]
    return tokenize(sg), tokenize(of), mm_ch


# ─────────────────────────── MULTI-SCALE CNN ────────────────────────────────

class MultiScaleCNN(nn.Module):
    """Multi-kernel 1D CNN operating on concatenated sgRNA + off-target embeddings."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.EMBED_DIM, padding_idx=0)
        # One two-layer conv branch per (kernel size, filter count) pair.
        # Input channels: two embeddings plus the 4-way one-hot mismatch channel.
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(cfg.EMBED_DIM * 2 + 4, n_filters, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(n_filters),
                nn.GELU(),
                nn.Conv1d(n_filters, n_filters, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(n_filters),
                nn.GELU(),
                nn.Dropout(cfg.CNN_DROPOUT),
            )
            for k, n_filters in zip(cfg.CNN_KERNELS, cfg.CNN_FILTERS)
        ])
        self.out_dim = sum(cfg.CNN_FILTERS)  # 128+256+256+512 = 1152

    def forward(self, sg, off, mm):
        """Return pooled branch features, shape (B, out_dim).

        ``sg``, ``off`` and ``mm`` are integer index tensors of matching shape;
        ``mm`` values must lie in [0, 4) for the one-hot encoding.
        """
        sg_e = self.embed(sg)                           # (B, 23, 128)
        off_e = self.embed(off)                         # (B, 23, 128)
        mm_oh = F.one_hot(mm, num_classes=4).float()    # (B, 23, 4)
        x = torch.cat([sg_e, off_e, mm_oh], dim=-1)     # (B, 23, 260)
        x = x.permute(0, 2, 1)                          # (B, 260, 23)
        outs = []
        for branch in self.branches:
            feat = branch(x)                            # (B, n_f, 23)
            # Sum of mean-pool and max-pool over the sequence axis.
            outs.append(feat.mean(dim=-1) + feat.max(dim=-1)[0])
        return torch.cat(outs, dim=-1)                  # (B, 1152)


# ─────────────────────────── POSITIONAL ENCODING ────────────────────────────

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, d_model, max_len=64):
        """Precompute a (1, max_len, d_model) table of sin/cos position codes.

        Even feature indices hold sines, odd indices hold cosines.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim ``div`` to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add position codes to ``x`` of shape (B, L, d_model).

        Only the first ``max_len`` positions are tabulated, so L must not
        exceed ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]


# ─────────────────────────── TRANSFORMER ────────────────────────────────────

class CRISPRTransformer(nn.Module):
    """Transformer encoder with cross-attention between sgRNA and off-target."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.TF_DIM, padding_idx=0)
        self.pos_enc = PositionalEncoding(cfg.TF_DIM)
        self.mm_proj = nn.Linear(4, cfg.TF_DIM)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.TF_DIM,
            nhead=cfg.TF_HEADS,
            dim_feedforward=cfg.TF_FF_DIM,
            dropout=cfg.TF_DROPOUT,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.TF_LAYERS)
        self.cross_attn = nn.MultiheadAttention(
            cfg.TF_DIM, cfg.TF_HEADS,
            dropout=cfg.TF_DROPOUT,
            batch_first=True,
        )
        self.out_dim = cfg.TF_DIM * 4  # 256 * 4 = 1024

    def forward(self, sg, off, mm):
        """Return concatenated pooled features, shape (B, out_dim)."""
        mm_oh = F.one_hot(mm, num_classes=4).float()
        # The mismatch projection is added to the sgRNA stream only.
        sg_e = self.pos_enc(self.embed(sg)) + self.mm_proj(mm_oh)
        off_e = self.pos_enc(self.embed(off))
        # Both streams share the same encoder weights.
        sg_enc = self.encoder(sg_e)
        off_enc = self.encoder(off_e)
        # sgRNA positions (queries) attend over off-target positions (keys/values).
        cross, _ = self.cross_attn(sg_enc, off_enc, off_enc)
        sg_feat = torch.cat([cross.mean(1), cross.max(1)[0]], dim=-1)
        off_feat = torch.cat([off_enc.mean(1), off_enc.max(1)[0]], dim=-1)
        return torch.cat([sg_feat, off_feat], dim=-1)  # (B, 1024)


# ─────────────────────────── BiLSTM ─────────────────────────────────────────

class BiLSTMEncoder(nn.Module):
    """Bidirectional LSTM encoder."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.EMBED_DIM, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=cfg.EMBED_DIM * 2 + 4,
            hidden_size=cfg.LSTM_HIDDEN,
            num_layers=cfg.LSTM_LAYERS,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )
        self.out_dim = cfg.LSTM_HIDDEN * 2 * 2  # 128 * 2 (bidir) * 2 (cat) = 512

    def forward(self, sg, off, mm):
        """Return mean-pooled outputs concatenated with final hidden states, (B, out_dim)."""
        sg_e = self.embed(sg)
        off_e = self.embed(off)
        mm_oh = F.one_hot(mm, num_classes=4).float()
        x = torch.cat([sg_e, off_e, mm_oh], dim=-1)
        out, (h, _) = self.lstm(x)
        mean_pool = out.mean(dim=1)
        # h[-2] / h[-1]: final hidden states of the last layer's two directions.
        last_hidden = torch.cat([h[-2], h[-1]], dim=-1)
        return torch.cat([mean_pool, last_hidden], dim=-1)  # (B, 512)


# ─────────────────────────── MEGA MODEL (FUSION) ───────────────────────────

class CRISPRMegaModel(nn.Module):
    """
    Fusion of MultiScaleCNN + CRISPRTransformer + BiLSTMEncoder.
    Outputs 256-dimensional embeddings + off-target / efficiency heads.

    Total input to fusion: 1152 (CNN) + 1024 (TF) + 512 (LSTM) = 2688
    Fusion layers: 2688 → 1024 → 512 → 256
    """

    def __init__(self):
        super().__init__()
        self.cnn = MultiScaleCNN()
        self.transformer = CRISPRTransformer()
        self.bilstm = BiLSTMEncoder()

        total_feats = self.cnn.out_dim + self.transformer.out_dim + self.bilstm.out_dim

        self.fusion = nn.Sequential(
            nn.Linear(total_feats, 1024), nn.LayerNorm(1024), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.1),
        )
        self.off_head = nn.Linear(256, 1)
        self.eff_head = nn.Linear(256, 1)
        self.emb_head = nn.Identity()

    def forward(self, sg, off, mm):
        """Run all three encoders, fuse them, and return a dict of outputs.

        Returns:
            dict with keys 'logit' (raw off-target logit), 'off_prob'
            (sigmoid of the logit), 'eff' (sigmoid of the efficiency head)
            and 'embedding' (the fused 256-dim vector).
        """
        cnn_out = self.cnn(sg, off, mm)
        tf_out = self.transformer(sg, off, mm)
        lstm_out = self.bilstm(sg, off, mm)

        combined = torch.cat([cnn_out, tf_out, lstm_out], dim=-1)
        emb = self.fusion(combined)  # (B, 256)

        off_logit = self.off_head(emb).squeeze(-1)
        eff = torch.sigmoid(self.eff_head(emb)).squeeze(-1)
        off_prob = torch.sigmoid(off_logit)

        return {
            'logit': off_logit,
            'off_prob': off_prob,
            'eff': eff,
            'embedding': emb,  # 256-dim
        }


# ─────────────────────────── BIOLOGICAL FEATURES ───────────────────────────

def extract_bio_features(sgrna: str, offtarget: str) -> Dict:
    """
    Compute hand-crafted biological features for an sgRNA / off-target pair.
    Returns a dict with ~50 features used as extra columns for AutoGluon.

    Both sequences are upper-cased, truncated to ``cfg.SEQ_LEN`` and
    right-padded with 'N' before comparison. Any position where the two
    characters differ counts as a mismatch; the transition / transversion
    split follows ``cfg.MISMATCH_MATRIX``. Key insertion order is stable.
    """
    SEQ_LEN = cfg.SEQ_LEN
    sg = sgrna.upper()[:SEQ_LEN].ljust(SEQ_LEN, 'N')
    off = offtarget.upper()[:SEQ_LEN].ljust(SEQ_LEN, 'N')
    MM = cfg.MISMATCH_MATRIX

    # (position, sgRNA base, off-target base) for every differing position, ascending.
    mismatches = [(i, a, b) for i, (a, b) in enumerate(zip(sg, off)) if a != b]
    mm_positions = [m[0] for m in mismatches]
    # "Seed" here means the last 12 positions of the padded sequence.
    seed_mms = [m for m in mismatches if m[0] >= SEQ_LEN - 12]

    def mm_type(a, b):
        return MM.get((a, b), 'match')

    feats = {}

    # ── Mismatch counts ──
    feats['n_mismatches'] = len(mismatches)
    feats['n_transitions'] = sum(1 for _, a, b in mismatches if mm_type(a, b) == 'transition')
    feats['n_transversions'] = sum(1 for _, a, b in mismatches if mm_type(a, b) == 'transversion')
    feats['seed_mismatches'] = len(seed_mms)
    feats['seed_transitions'] = sum(1 for _, a, b in seed_mms if mm_type(a, b) == 'transition')
    feats['seed_transversions'] = sum(1 for _, a, b in seed_mms if mm_type(a, b) == 'transversion')
    feats['pam_proximal_mm'] = sum(1 for m in mismatches if m[0] >= SEQ_LEN - 5)

    # ── Positional features ──
    feats['first_mm_pos'] = mm_positions[0] if mm_positions else -1
    feats['last_mm_pos'] = mm_positions[-1] if mm_positions else -1
    feats['mm_span'] = (mm_positions[-1] - mm_positions[0]) if len(mm_positions) > 1 else 0

    if len(mm_positions) >= 2:
        gaps = [nxt - cur for cur, nxt in zip(mm_positions, mm_positions[1:])]
        feats['mm_min_gap'] = min(gaps)
        feats['mm_mean_gap'] = float(np.mean(gaps))
        feats['mm_clustered'] = float(min(gaps) <= 2)
    else:
        # Fewer than two mismatches: no gap exists, use SEQ_LEN as a sentinel.
        feats['mm_min_gap'] = SEQ_LEN
        feats['mm_mean_gap'] = float(SEQ_LEN)
        feats['mm_clustered'] = 0.0

    # ── Nucleotide composition ──
    for nt in 'ACGT':
        feats[f'sg_{nt}_frac'] = sg.count(nt) / SEQ_LEN
        feats[f'off_{nt}_frac'] = off.count(nt) / SEQ_LEN

    feats['sg_gc'] = (sg.count('G') + sg.count('C')) / SEQ_LEN
    feats['off_gc'] = (off.count('G') + off.count('C')) / SEQ_LEN

    # ── Thermodynamic proxy & penalties ──
    at = sg.count('A') + sg.count('T')
    gc = sg.count('G') + sg.count('C')
    feats['sg_tm_proxy'] = 2 * at + 4 * gc
    feats['weighted_mm_penalty'] = sum((1.0 + p / SEQ_LEN) for p in mm_positions)
    feats['edit_dist_norm'] = len(mismatches) / SEQ_LEN
    # A string shorter than two characters can never equal 'GG', so no length guard is needed.
    feats['pam_is_ngg'] = float(sg[-2:] == 'GG')

    # ── Per-position mismatch flags ──
    for i in range(SEQ_LEN):
        feats[f'mm_pos_{i}'] = float(sg[i] != off[i])

    return feats