| """ | |
| core_engine.py | |
| ============== | |
| DeepCRISPR Mega Model β PyTorch Architecture & Feature Extraction. | |
| Contains the exact neural network architecture used to train mega_model_best.pth: | |
| - CFG: Hyperparameters and nucleotide vocabulary | |
| - encode_pair(): Tokenizes sgRNA + off-target pair | |
| - MultiScaleCNN: Multi-kernel 1D convolutions | |
| - PositionalEncoding: Sinusoidal position embeddings | |
| - CRISPRTransformer: Transformer encoder with cross-attention | |
| - BiLSTMEncoder: Bidirectional LSTM | |
| - CRISPRMegaModel: Fusion of all three β 256-dim embeddings | |
| - extract_bio_features(): Hand-crafted biological features | |
| Architected by Mujahid | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Dict, Tuple | |

# ─────────────────────────── CONFIGURATION ──────────────────────────────────
class CFG:
    SEQ_LEN = 23
    VOCAB_SIZE = 7           # A C G T N - [PAD]
    EMBED_DIM = 128          # nucleotide embedding size
    CNN_FILTERS = [128, 256, 256, 512]
    CNN_KERNELS = [3, 5, 7, 9]   # multi-scale kernels
    CNN_DROPOUT = 0.2
    TF_HEADS = 8
    TF_LAYERS = 4
    TF_DIM = 256
    TF_FF_DIM = 512
    TF_DROPOUT = 0.1
    LSTM_HIDDEN = 128
    LSTM_LAYERS = 2

    NT = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5, '-': 6, '[PAD]': 0}

    MISMATCH_MATRIX = {
        ('A', 'G'): 'transition',   ('G', 'A'): 'transition',
        ('C', 'T'): 'transition',   ('T', 'C'): 'transition',
        ('A', 'C'): 'transversion', ('C', 'A'): 'transversion',
        ('A', 'T'): 'transversion', ('T', 'A'): 'transversion',
        ('G', 'C'): 'transversion', ('C', 'G'): 'transversion',
        ('G', 'T'): 'transversion', ('T', 'G'): 'transversion',
        ('-', 'A'): 'dna_bulge',    ('-', 'C'): 'dna_bulge',
        ('-', 'G'): 'dna_bulge',    ('-', 'T'): 'dna_bulge',
        ('A', '-'): 'rna_bulge',    ('C', '-'): 'rna_bulge',
        ('G', '-'): 'rna_bulge',    ('T', '-'): 'rna_bulge',
    }


cfg = CFG()

# ─────────────────────────── TOKENIZATION ───────────────────────────────────
def tokenize(seq: str, max_len: int = cfg.SEQ_LEN) -> List[int]:
    """Convert a nucleotide sequence to an integer token list."""
    seq = seq.upper()[:max_len].ljust(max_len, 'N')
    return [cfg.NT.get(c, cfg.NT['N']) for c in seq]


def encode_pair(sgrna: str, off: str) -> Tuple[List[int], List[int], List[int]]:
    """
    Encode an sgRNA + off-target pair into three integer lists:
        1. sgRNA tokens
        2. Off-target tokens
        3. Mismatch channel (0=match, 1=transition, 2=transversion, 3=bulge)
    """
    sg = sgrna.upper()[:cfg.SEQ_LEN].ljust(cfg.SEQ_LEN, 'N')
    of = off.upper()[:cfg.SEQ_LEN].ljust(cfg.SEQ_LEN, 'N')
    mm_map = {
        'match': 0, 'transition': 1, 'transversion': 2,
        'dna_bulge': 3, 'rna_bulge': 3,
    }
    mm_ch = [
        mm_map[cfg.MISMATCH_MATRIX.get((a, b), 'match')]
        for a, b in zip(sg, of)
    ]
    return tokenize(sg), tokenize(of), mm_ch
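
# Illustrative sketch (not part of the original pipeline): a quick self-check of
# encode_pair() on a made-up sgRNA / off-target pair. The sequences are
# hypothetical and chosen only to show the three output channels.
def _demo_encode_pair() -> None:
    sg_tok, off_tok, mm_ch = encode_pair(
        "GACGCATAAAGATGAGACGCTGG",   # hypothetical 23-nt sgRNA + PAM
        "GACGCATAAAGATGAGACGCAGG",   # hypothetical off-target, one transversion
    )
    # All three channels are padded/truncated to SEQ_LEN = 23.
    assert len(sg_tok) == len(off_tok) == len(mm_ch) == cfg.SEQ_LEN
    # Only position 20 differs (T vs A), so the mismatch channel flags a single transversion (2).
    assert mm_ch[20] == 2 and sum(mm_ch) == 2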

# ─────────────────────────── MULTI-SCALE CNN ────────────────────────────────
class MultiScaleCNN(nn.Module):
    """Multi-kernel 1D CNN operating on concatenated sgRNA + off-target embeddings."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.EMBED_DIM, padding_idx=0)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(cfg.EMBED_DIM * 2 + 4, n_filters, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(n_filters),
                nn.GELU(),
                nn.Conv1d(n_filters, n_filters, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(n_filters),
                nn.GELU(),
                nn.Dropout(cfg.CNN_DROPOUT),
            )
            for k, n_filters in zip(cfg.CNN_KERNELS, cfg.CNN_FILTERS)
        ])
        self.out_dim = sum(cfg.CNN_FILTERS)   # 128 + 256 + 256 + 512 = 1152

    def forward(self, sg, off, mm):
        sg_e = self.embed(sg)                            # (B, 23, 128)
        off_e = self.embed(off)                          # (B, 23, 128)
        mm_oh = F.one_hot(mm, num_classes=4).float()     # (B, 23, 4)
        x = torch.cat([sg_e, off_e, mm_oh], dim=-1)      # (B, 23, 260)
        x = x.permute(0, 2, 1)                           # (B, 260, 23)
        outs = []
        for branch in self.branches:
            feat = branch(x)                             # (B, n_f, 23)
            outs.append(feat.mean(dim=-1) + feat.max(dim=-1)[0])
        return torch.cat(outs, dim=-1)                   # (B, 1152)
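
# Illustrative shape check (sketch, not part of the training code): run the CNN
# branch on a random batch to confirm the pooled output width of 1152, i.e. one
# mean+max-pooled vector per kernel size (128 + 256 + 256 + 512).
def _demo_multiscale_cnn_shapes(batch_size: int = 2) -> None:
    sg = torch.randint(1, 5, (batch_size, cfg.SEQ_LEN))   # tokens 1-4 (A/C/G/T)
    off = torch.randint(1, 5, (batch_size, cfg.SEQ_LEN))
    mm = torch.randint(0, 4, (batch_size, cfg.SEQ_LEN))   # mismatch channel 0-3
    cnn = MultiScaleCNN().eval()
    with torch.no_grad():
        out = cnn(sg, off, mm)
    assert out.shape == (batch_size, cnn.out_dim) == (batch_size, 1152)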

# ─────────────────────────── POSITIONAL ENCODING ────────────────────────────
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, d_model, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# ─────────────────────────── TRANSFORMER ────────────────────────────────────
class CRISPRTransformer(nn.Module):
    """Transformer encoder with cross-attention between sgRNA and off-target."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.TF_DIM, padding_idx=0)
        self.pos_enc = PositionalEncoding(cfg.TF_DIM)
        self.mm_proj = nn.Linear(4, cfg.TF_DIM)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=cfg.TF_DIM, nhead=cfg.TF_HEADS,
            dim_feedforward=cfg.TF_FF_DIM, dropout=cfg.TF_DROPOUT,
            activation='gelu', batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.TF_LAYERS)
        self.cross_attn = nn.MultiheadAttention(
            cfg.TF_DIM, cfg.TF_HEADS, dropout=cfg.TF_DROPOUT, batch_first=True,
        )
        self.out_dim = cfg.TF_DIM * 4   # 256 * 4 = 1024

    def forward(self, sg, off, mm):
        mm_oh = F.one_hot(mm, num_classes=4).float()
        sg_e = self.pos_enc(self.embed(sg)) + self.mm_proj(mm_oh)
        off_e = self.pos_enc(self.embed(off))
        sg_enc = self.encoder(sg_e)
        off_enc = self.encoder(off_e)
        cross, _ = self.cross_attn(sg_enc, off_enc, off_enc)
        sg_feat = torch.cat([cross.mean(1), cross.max(1)[0]], dim=-1)
        off_feat = torch.cat([off_enc.mean(1), off_enc.max(1)[0]], dim=-1)
        return torch.cat([sg_feat, off_feat], dim=-1)   # (B, 1024)
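
# Illustrative sketch: the transformer head is 4 * TF_DIM = 1024 wide because
# both the cross-attended sgRNA stream and the off-target stream each contribute
# a mean-pooled and a max-pooled TF_DIM vector. The random batch below only
# verifies that arithmetic; it is not part of the original pipeline.
def _demo_transformer_shapes(batch_size: int = 2) -> None:
    sg = torch.randint(1, 5, (batch_size, cfg.SEQ_LEN))
    off = torch.randint(1, 5, (batch_size, cfg.SEQ_LEN))
    mm = torch.randint(0, 4, (batch_size, cfg.SEQ_LEN))
    tf = CRISPRTransformer().eval()
    with torch.no_grad():
        out = tf(sg, off, mm)
    assert out.shape == (batch_size, 4 * cfg.TF_DIM) == (batch_size, 1024)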

# ─────────────────────────── BiLSTM ─────────────────────────────────────────
class BiLSTMEncoder(nn.Module):
    """Bidirectional LSTM encoder."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(cfg.VOCAB_SIZE, cfg.EMBED_DIM, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=cfg.EMBED_DIM * 2 + 4,
            hidden_size=cfg.LSTM_HIDDEN,
            num_layers=cfg.LSTM_LAYERS,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )
        self.out_dim = cfg.LSTM_HIDDEN * 2 * 2   # 128 * 2 (bidir) * 2 (cat) = 512

    def forward(self, sg, off, mm):
        sg_e = self.embed(sg)
        off_e = self.embed(off)
        mm_oh = F.one_hot(mm, num_classes=4).float()
        x = torch.cat([sg_e, off_e, mm_oh], dim=-1)
        out, (h, _) = self.lstm(x)
        mean_pool = out.mean(dim=1)
        last_hidden = torch.cat([h[-2], h[-1]], dim=-1)
        return torch.cat([mean_pool, last_hidden], dim=-1)   # (B, 512)

# ─────────────────────────── MEGA MODEL (FUSION) ────────────────────────────
class CRISPRMegaModel(nn.Module):
    """
    Fusion of MultiScaleCNN + CRISPRTransformer + BiLSTMEncoder.
    Outputs 256-dimensional embeddings + off-target / efficiency heads.
    Total input to fusion: 1152 (CNN) + 1024 (TF) + 512 (LSTM) = 2688
    Fusion layers: 2688 → 1024 → 512 → 256
    """

    def __init__(self):
        super().__init__()
        self.cnn = MultiScaleCNN()
        self.transformer = CRISPRTransformer()
        self.bilstm = BiLSTMEncoder()
        total_feats = self.cnn.out_dim + self.transformer.out_dim + self.bilstm.out_dim
        self.fusion = nn.Sequential(
            nn.Linear(total_feats, 1024), nn.LayerNorm(1024), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(1024, 512), nn.LayerNorm(512), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(512, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(0.1),
        )
        self.off_head = nn.Linear(256, 1)
        self.eff_head = nn.Linear(256, 1)
        self.emb_head = nn.Identity()

    def forward(self, sg, off, mm):
        cnn_out = self.cnn(sg, off, mm)
        tf_out = self.transformer(sg, off, mm)
        lstm_out = self.bilstm(sg, off, mm)
        combined = torch.cat([cnn_out, tf_out, lstm_out], dim=-1)
        emb = self.fusion(combined)   # (B, 256)
        off_logit = self.off_head(emb).squeeze(-1)
        eff = torch.sigmoid(self.eff_head(emb)).squeeze(-1)
        off_prob = torch.sigmoid(off_logit)
        return {
            'logit': off_logit,
            'off_prob': off_prob,
            'eff': eff,
            'embedding': emb,   # 256-dim
        }
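
# Illustrative end-to-end sketch (hypothetical usage, made-up sequences): encode
# one pair, run a forward pass, and inspect the four outputs. Loading the real
# weights would look like model.load_state_dict(torch.load("mega_model_best.pth")),
# assuming that file holds a plain state_dict for this architecture.
def _demo_mega_model() -> None:
    sg_tok, off_tok, mm_ch = encode_pair(
        "GACGCATAAAGATGAGACGCTGG",   # hypothetical sgRNA
        "GACGCATAAAGATGAGACGCAGG",   # hypothetical off-target
    )
    sg = torch.tensor([sg_tok])      # (1, 23)
    off = torch.tensor([off_tok])    # (1, 23)
    mm = torch.tensor([mm_ch])       # (1, 23)
    model = CRISPRMegaModel().eval()
    with torch.no_grad():
        out = model(sg, off, mm)
    assert out['embedding'].shape == (1, 256)
    assert out['off_prob'].shape == out['eff'].shape == (1,)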

# ─────────────────────────── BIOLOGICAL FEATURES ────────────────────────────
def extract_bio_features(sgrna: str, offtarget: str) -> Dict:
    """
    Compute hand-crafted biological features for an sgRNA / off-target pair.
    Returns a dict with ~50 features used as extra columns for AutoGluon.
    """
    SEQ_LEN = cfg.SEQ_LEN
    sg = sgrna.upper()[:SEQ_LEN].ljust(SEQ_LEN, 'N')
    off = offtarget.upper()[:SEQ_LEN].ljust(SEQ_LEN, 'N')
    MM = cfg.MISMATCH_MATRIX

    mismatches = [(i, sg[i], off[i]) for i in range(SEQ_LEN) if sg[i] != off[i]]
    mm_positions = [m[0] for m in mismatches]
    seed_mms = [m for m in mismatches if m[0] >= SEQ_LEN - 12]

    def mm_type(a, b):
        return MM.get((a, b), 'match')

    feats = {}

    # ── Mismatch counts ──
    feats['n_mismatches'] = len(mismatches)
    feats['n_transitions'] = sum(1 for _, a, b in mismatches if mm_type(a, b) == 'transition')
    feats['n_transversions'] = sum(1 for _, a, b in mismatches if mm_type(a, b) == 'transversion')
    feats['seed_mismatches'] = len(seed_mms)
    feats['seed_transitions'] = sum(1 for _, a, b in seed_mms if mm_type(a, b) == 'transition')
    feats['seed_transversions'] = sum(1 for _, a, b in seed_mms if mm_type(a, b) == 'transversion')
    feats['pam_proximal_mm'] = sum(1 for m in mismatches if m[0] >= SEQ_LEN - 5)

    # ── Positional features ──
    feats['first_mm_pos'] = mm_positions[0] if mm_positions else -1
    feats['last_mm_pos'] = mm_positions[-1] if mm_positions else -1
    feats['mm_span'] = (mm_positions[-1] - mm_positions[0]) if len(mm_positions) > 1 else 0
    if len(mm_positions) >= 2:
        gaps = [mm_positions[i + 1] - mm_positions[i] for i in range(len(mm_positions) - 1)]
        feats['mm_min_gap'] = min(gaps)
        feats['mm_mean_gap'] = float(np.mean(gaps))
        feats['mm_clustered'] = float(min(gaps) <= 2)
    else:
        feats['mm_min_gap'] = SEQ_LEN
        feats['mm_mean_gap'] = float(SEQ_LEN)
        feats['mm_clustered'] = 0.0

    # ── Nucleotide composition ──
    for nt in 'ACGT':
        feats[f'sg_{nt}_frac'] = sg.count(nt) / SEQ_LEN
        feats[f'off_{nt}_frac'] = off.count(nt) / SEQ_LEN
    feats['sg_gc'] = (sg.count('G') + sg.count('C')) / SEQ_LEN
    feats['off_gc'] = (off.count('G') + off.count('C')) / SEQ_LEN

    # ── Thermodynamic proxy & penalties ──
    at = sg.count('A') + sg.count('T')
    gc = sg.count('G') + sg.count('C')
    feats['sg_tm_proxy'] = 2 * at + 4 * gc
    feats['weighted_mm_penalty'] = sum((1.0 + p / SEQ_LEN) for p in mm_positions)
    feats['edit_dist_norm'] = len(mismatches) / SEQ_LEN
    feats['pam_is_ngg'] = float(sg[-2:] == 'GG' if len(sg) >= 2 else False)

    # ── Per-position mismatch flags ──
    for i in range(SEQ_LEN):
        feats[f'mm_pos_{i}'] = float(sg[i] != off[i])

    return feats
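
# Illustrative sketch (made-up pair): the feature dict mixes scalar summaries
# (mismatch counts, GC content, Tm proxy) with 23 per-position mismatch flags,
# giving 50 columns that can be appended to a tabular model's input.
def _demo_bio_features() -> None:
    feats = extract_bio_features(
        "GACGCATAAAGATGAGACGCTGG",   # hypothetical sgRNA
        "GACGCATAAAGATGAGACGCAGG",   # hypothetical off-target, one transversion
    )
    assert feats['n_mismatches'] == 1 and feats['n_transversions'] == 1
    assert feats['seed_mismatches'] == 1     # position 20 falls in the code's last-12-nt seed window
    assert feats['mm_pos_20'] == 1.0
    assert len(feats) == 50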