"""
Sequence Encoder for CRISPR sgRNA-DNA Pairs
Converts sgRNA and DNA sequences into paired encodings for CNN and BERT models
"""
import numpy as np
# ========== ENCODING DICTIONARIES ==========
# CNN: 7-dimensional binary (multi-hot, not one-hot) vector per paired token.
# Each key is the sgRNA base followed by the DNA base; "_" marks a gap on
# that strand.
#   columns 0-3: set for each base (A, T, G, C) present in the pair
#   column 4:    set only for single-strand gaps ("A_", "_A", ...)
#   columns 5/6: set only for non-matching pairs -- column 5 when the sgRNA
#                base precedes the DNA base in A, T, G, C order or the DNA
#                side is a gap, column 6 otherwise (presumably a mismatch
#                "direction" flag; confirm against the model's training data)
# "--" (both strands gapped) and the special tokens are all-zero vectors, so
# they are indistinguishable in this encoding.
CNN_ENCODING = {
    'AA': [1, 0, 0, 0, 0, 0, 0], 'AT': [1, 1, 0, 0, 0, 1, 0], 'AG': [1, 0, 1, 0, 0, 1, 0], 'AC': [1, 0, 0, 1, 0, 1, 0],
    'TA': [1, 1, 0, 0, 0, 0, 1], 'TT': [0, 1, 0, 0, 0, 0, 0], 'TG': [0, 1, 1, 0, 0, 1, 0], 'TC': [0, 1, 0, 1, 0, 1, 0],
    'GA': [1, 0, 1, 0, 0, 0, 1], 'GT': [0, 1, 1, 0, 0, 0, 1], 'GG': [0, 0, 1, 0, 0, 0, 0], 'GC': [0, 0, 1, 1, 0, 1, 0],
    'CA': [1, 0, 0, 1, 0, 0, 1], 'CT': [0, 1, 0, 1, 0, 0, 1], 'CG': [0, 0, 1, 1, 0, 0, 1], 'CC': [0, 0, 0, 1, 0, 0, 0],
    'A_': [1, 0, 0, 0, 1, 1, 0], 'T_': [0, 1, 0, 0, 1, 1, 0], 'G_': [0, 0, 1, 0, 1, 1, 0], 'C_': [0, 0, 0, 1, 1, 1, 0],
    '_A': [1, 0, 0, 0, 1, 0, 1], '_T': [0, 1, 0, 0, 1, 0, 1], '_G': [0, 0, 1, 0, 1, 0, 1], '_C': [0, 0, 0, 1, 1, 0, 1],
    '--': [0, 0, 0, 0, 0, 0, 0],
    '[CLS]': [0, 0, 0, 0, 0, 0, 0],  # Special token for start
    '[SEP]': [0, 0, 0, 0, 0, 0, 0],  # Special token for end
    '[PAD]': [0, 0, 0, 0, 0, 0, 0],  # Padding token
}

# Fixed sequence length (excluding special tokens)
FIXED_SEQ_LENGTH = 24  # 24 base pairs
# Encoded length: the paired positions plus one [CLS] and one [SEP].
# Derived (not hard-coded) so the two constants cannot drift apart.
TOTAL_LENGTH = FIXED_SEQ_LENGTH + 2  # 24 pairs + [CLS] + [SEP]

# BERT: token ID for each paired token. Covers the same keys as
# CNN_ENCODING; IDs 0-27 are all distinct, and [PAD] is 27 (not 0).
BERT_TOKEN_DICT = {
    "AA": 2, "AC": 3, "AG": 4, "AT": 5,
    "CA": 6, "CC": 7, "CG": 8, "CT": 9,
    "GA": 10, "GC": 11, "GG": 12, "GT": 13,
    "TA": 14, "TC": 15, "TG": 16, "TT": 17,
    "A_": 18, "_A": 19, "C_": 20, "_C": 21,
    "G_": 22, "_G": 23, "T_": 24, "_T": 25,
    "--": 26,  # Both positions deleted
    "[CLS]": 0, "[SEP]": 1, "[PAD]": 27,
}
# ========== SEQUENCE PAIRING ==========
def pair_sequences(sgrna, dna):
    """
    Pair sgRNA and DNA sequences position by position.

    Both inputs are upper-cased first, so pairing is case-insensitive.
    A position where both sequences hold "_" is emitted as "--".

    Args:
        sgrna (str): sgRNA sequence (e.g., "GAGTCCGAGCAG")
        dna (str): DNA sequence (e.g., "GGAGTCCGTGCA")

    Returns:
        list: One token per position, sgRNA[i] + DNA[i]
            (for the examples above: ["GG", "AG", "GA", ...])

    Raises:
        ValueError: If the two sequences differ in length.
    """
    sgrna = sgrna.upper()
    dna = dna.upper()
    if len(sgrna) != len(dna):
        raise ValueError(f"Sequences must be same length: sgRNA={len(sgrna)}, DNA={len(dna)}")
    # A gap on both strands ("__") is spelled "--" in the encoding tables.
    return ['--' if s == '_' and d == '_' else s + d for s, d in zip(sgrna, dna)]
# ========== CNN ENCODING ==========
def encode_for_cnn(sgrna, dna, fixed_length=FIXED_SEQ_LENGTH):
    """
    Encode an sgRNA-DNA pair as a fixed-size binary matrix for the CNN model.

    The paired tokens are truncated, or padded with "[PAD]", to exactly
    ``fixed_length`` positions. They are then wrapped as
    [CLS] + pairs + [SEP], and each token is mapped through CNN_ENCODING.
    With the default length the result is 26x7.

    Args:
        sgrna (str): sgRNA sequence
        dna (str): DNA sequence (same length as ``sgrna``)
        fixed_length (int): Number of paired positions to keep (default: 24)

    Returns:
        numpy.ndarray: float32 array of shape (fixed_length + 2, 7).
            Tokens absent from CNN_ENCODING (e.g. pairs containing "N")
            are encoded as all zeros.

    Raises:
        ValueError: If the sequences differ in length or fixed_length < 0.
    """
    if fixed_length < 0:
        raise ValueError(f"fixed_length must be non-negative, got {fixed_length}")

    # Truncate to fixed_length, then pad back up to exactly fixed_length.
    paired_tokens = pair_sequences(sgrna, dna)[:fixed_length]
    paired_tokens += ['[PAD]'] * (fixed_length - len(paired_tokens))
    tokens = ['[CLS]'] + paired_tokens + ['[SEP]']

    unknown = [0, 0, 0, 0, 0, 0, 0]
    result = np.array([CNN_ENCODING.get(token, unknown) for token in tokens], dtype=np.float32)

    # Invariant check against the *requested* length. Comparing with the
    # TOTAL_LENGTH constant made every non-default fixed_length fail.
    expected_shape = (fixed_length + 2, 7)
    assert result.shape == expected_shape, f"Expected shape {expected_shape}, got {result.shape}"
    return result
# ========== BERT ENCODING ==========
def encode_for_bert(sgrna, dna, fixed_length=FIXED_SEQ_LENGTH):
    """
    Encode an sgRNA-DNA pair as a fixed-length token-ID vector for BERT.

    The paired tokens are truncated, or padded with "[PAD]", to exactly
    ``fixed_length`` positions. They are then wrapped as
    [CLS] + pairs + [SEP], and each token is mapped through BERT_TOKEN_DICT.
    With the default length the result has 26 entries.

    Args:
        sgrna (str): sgRNA sequence
        dna (str): DNA sequence (same length as ``sgrna``)
        fixed_length (int): Number of paired positions to keep (default: 24)

    Returns:
        numpy.ndarray: int32 array of shape (fixed_length + 2,).
            Tokens absent from BERT_TOKEN_DICT are mapped to the [PAD] ID.

    Raises:
        ValueError: If the sequences differ in length or fixed_length < 0.
    """
    if fixed_length < 0:
        raise ValueError(f"fixed_length must be non-negative, got {fixed_length}")

    # Truncate to fixed_length, then pad back up to exactly fixed_length.
    paired_tokens = pair_sequences(sgrna, dna)[:fixed_length]
    paired_tokens += ['[PAD]'] * (fixed_length - len(paired_tokens))
    tokens = ['[CLS]'] + paired_tokens + ['[SEP]']

    unknown_id = BERT_TOKEN_DICT["[PAD]"]
    result = np.array([BERT_TOKEN_DICT.get(token, unknown_id) for token in tokens], dtype=np.int32)

    # Invariant check against the *requested* length. Comparing with the
    # TOTAL_LENGTH constant made every non-default fixed_length fail.
    expected_shape = (fixed_length + 2,)
    assert result.shape == expected_shape, f"Expected shape {expected_shape}, got {result.shape}"
    return result
# ========== BATCH ENCODING ==========
def encode_batch_for_cnn(sgrna_list, dna_list):
    """
    Encode multiple sgRNA-DNA pairs for the CNN model.

    Each pair is encoded with encode_for_cnn at its default fixed length,
    giving a (26, 7) matrix per pair.

    Args:
        sgrna_list (iterable of str): sgRNA sequences
        dna_list (iterable of str): DNA sequences, aligned with sgrna_list

    Returns:
        numpy.ndarray: float32 array of shape (batch_size, 26, 7). An empty
            batch yields shape (0, 26, 7).

    Raises:
        ValueError: If the inputs hold different numbers of sequences, or
            if any individual pair differs in length.
    """
    # Materialize so generators still work and the sizes can be compared;
    # zip() on its own would silently drop the unmatched tail.
    sgrna_list = list(sgrna_list)
    dna_list = list(dna_list)
    if len(sgrna_list) != len(dna_list):
        raise ValueError(
            f"Batch size mismatch: {len(sgrna_list)} sgRNAs vs {len(dna_list)} DNAs"
        )
    if not sgrna_list:
        # np.array([]) would be shape (0,); keep the result 3-D.
        return np.zeros((0, TOTAL_LENGTH, 7), dtype=np.float32)
    return np.array(
        [encode_for_cnn(sgrna, dna) for sgrna, dna in zip(sgrna_list, dna_list)],
        dtype=np.float32,
    )
def encode_batch_for_bert(sgrna_list, dna_list):
    """
    Encode multiple sgRNA-DNA pairs for the BERT model.

    Each pair is encoded with encode_for_bert at its default fixed length,
    giving a (26,) token-ID vector per pair.

    Args:
        sgrna_list (iterable of str): sgRNA sequences
        dna_list (iterable of str): DNA sequences, aligned with sgrna_list

    Returns:
        numpy.ndarray: int32 array of shape (batch_size, 26). An empty batch
            yields shape (0, 26).

    Raises:
        ValueError: If the inputs hold different numbers of sequences, or
            if any individual pair differs in length.
    """
    # Materialize so generators still work and the sizes can be compared;
    # zip() on its own would silently drop the unmatched tail.
    sgrna_list = list(sgrna_list)
    dna_list = list(dna_list)
    if len(sgrna_list) != len(dna_list):
        raise ValueError(
            f"Batch size mismatch: {len(sgrna_list)} sgRNAs vs {len(dna_list)} DNAs"
        )
    if not sgrna_list:
        # np.array([]) would be shape (0,); keep the result 2-D.
        return np.zeros((0, TOTAL_LENGTH), dtype=np.int32)
    return np.array(
        [encode_for_bert(sgrna, dna) for sgrna, dna in zip(sgrna_list, dna_list)],
        dtype=np.int32,
    )
# ========== UTILITY FUNCTIONS ==========
def decode_cnn_encoding(encoded_matrix):
    """
    Decode a CNN-encoded matrix back to paired tokens (for debugging).

    The decoding is lossy. CNN_ENCODING maps "--", "[CLS]", "[SEP]" and
    "[PAD]" to the same all-zero vector. The reverse map keeps the last
    such key in CNN_ENCODING's order, which is currently "[PAD]", so every
    all-zero row decodes to that token. Rows matching no encoding decode
    to "??". Values are converted with an integer cast (truncation), not
    rounding.

    Args:
        encoded_matrix (array-like): Shape (seq_length, 7), e.g. the output
            of encode_for_cnn. Nested Python lists are accepted as well.

    Returns:
        list: One paired token per row.
    """
    reverse_map = {tuple(v): k for k, v in CNN_ENCODING.items()}
    # np.asarray lets plain nested lists through; ndarray input is unchanged.
    rows = np.asarray(encoded_matrix).astype(int)
    return [reverse_map.get(tuple(row.tolist()), "??") for row in rows]
def decode_bert_encoding(token_ids):
    """
    Map BERT token IDs back to their paired-token strings (for debugging).

    Args:
        token_ids (iterable): Token IDs, e.g. the output of encode_for_bert.
            Each ID is passed through int() before lookup.

    Returns:
        list: One token string per ID. IDs absent from BERT_TOKEN_DICT
            become "??".
    """
    id_to_token = dict(zip(BERT_TOKEN_DICT.values(), BERT_TOKEN_DICT.keys()))
    return [id_to_token.get(int(tid), "??") for tid in token_ids]