crispr-bert-api / sequence_encoder.py
santu0032's picture
Upload 9 files
f6e15bf verified
Raw
History Blame Contribute Delete
8.55 kB
"""
Sequence Encoder for CRISPR sgRNA-DNA Pairs
Converts sgRNA and DNA sequences into paired encodings for CNN and BERT models
"""
import numpy as np
# ========== ENCODING DICTIONARIES ==========
# CNN: 7-dimensional binary (multi-hot) vector for each paired token.
#   bits 0-3 : which of A, T, G, C occur in the pair (a match sets one bit,
#              a mismatch sets two)
#   bit 4    : set when one side of the pair is a gap ("_")
#   bits 5-6 : tell apart the two orderings of an otherwise identical pair
#              (e.g. "AT" vs "TA", "A_" vs "_A"); both are 0 for a match
# "--" and the special tokens all share the all-zero vector.
CNN_ENCODING = {
    # sgRNA base A
    "AA": [1, 0, 0, 0, 0, 0, 0],
    "AT": [1, 1, 0, 0, 0, 1, 0],
    "AG": [1, 0, 1, 0, 0, 1, 0],
    "AC": [1, 0, 0, 1, 0, 1, 0],
    # sgRNA base T
    "TA": [1, 1, 0, 0, 0, 0, 1],
    "TT": [0, 1, 0, 0, 0, 0, 0],
    "TG": [0, 1, 1, 0, 0, 1, 0],
    "TC": [0, 1, 0, 1, 0, 1, 0],
    # sgRNA base G
    "GA": [1, 0, 1, 0, 0, 0, 1],
    "GT": [0, 1, 1, 0, 0, 0, 1],
    "GG": [0, 0, 1, 0, 0, 0, 0],
    "GC": [0, 0, 1, 1, 0, 1, 0],
    # sgRNA base C
    "CA": [1, 0, 0, 1, 0, 0, 1],
    "CT": [0, 1, 0, 1, 0, 0, 1],
    "CG": [0, 0, 1, 1, 0, 0, 1],
    "CC": [0, 0, 0, 1, 0, 0, 0],
    # Gap on the DNA side
    "A_": [1, 0, 0, 0, 1, 1, 0],
    "T_": [0, 1, 0, 0, 1, 1, 0],
    "G_": [0, 0, 1, 0, 1, 1, 0],
    "C_": [0, 0, 0, 1, 1, 1, 0],
    # Gap on the sgRNA side
    "_A": [1, 0, 0, 0, 1, 0, 1],
    "_T": [0, 1, 0, 0, 1, 0, 1],
    "_G": [0, 0, 1, 0, 1, 0, 1],
    "_C": [0, 0, 0, 1, 1, 0, 1],
    # Both positions deleted
    "--": [0, 0, 0, 0, 0, 0, 0],
    # Special tokens
    "[CLS]": [0, 0, 0, 0, 0, 0, 0],  # start of sequence
    "[SEP]": [0, 0, 0, 0, 0, 0, 0],  # end of sequence
    "[PAD]": [0, 0, 0, 0, 0, 0, 0],  # padding
}
# Number of paired positions every encoded sequence is padded/truncated to
# (special tokens not counted).
FIXED_SEQ_LENGTH = 24
# Full encoded length: the paired positions plus the [CLS] and [SEP] tokens.
TOTAL_LENGTH = FIXED_SEQ_LENGTH + 2
# BERT: integer token ID for each paired token.
# IDs 0 and 1 are [CLS] and [SEP], 2-17 are base/base pairs, 18-25 are pairs
# with a gap on one side, 26 is the double gap and 27 is [PAD].
BERT_TOKEN_DICT = {
    # sgRNA base A
    "AA": 2,
    "AC": 3,
    "AG": 4,
    "AT": 5,
    # sgRNA base C
    "CA": 6,
    "CC": 7,
    "CG": 8,
    "CT": 9,
    # sgRNA base G
    "GA": 10,
    "GC": 11,
    "GG": 12,
    "GT": 13,
    # sgRNA base T
    "TA": 14,
    "TC": 15,
    "TG": 16,
    "TT": 17,
    # One side is a gap
    "A_": 18,
    "_A": 19,
    "C_": 20,
    "_C": 21,
    "G_": 22,
    "_G": 23,
    "T_": 24,
    "_T": 25,
    # Both positions deleted
    "--": 26,
    # Special tokens
    "[CLS]": 0,
    "[SEP]": 1,
    "[PAD]": 27,
}
# ========== SEQUENCE PAIRING ==========
def pair_sequences(sgrna, dna):
    """
    Pair sgRNA and DNA sequences character by character.

    Both inputs are upper-cased, then the characters at each index are
    joined as ``sgRNA[i] + DNA[i]``. A pair made of two gap characters
    ("__") is rewritten to "--", the token the encoding tables use for
    "both positions deleted".

    Args:
        sgrna (str): sgRNA sequence (e.g., "GAGTCCGAGCAG")
        dna (str): DNA sequence (e.g., "GGAGTCCGTGCA")

    Returns:
        list: Paired tokens (e.g., ["GG", "AG", "GA", ...])

    Raises:
        ValueError: If the two sequences do not have the same length.
    """
    # Convert to uppercase for consistency
    sgrna = sgrna.upper()
    dna = dna.upper()

    # Ensure sequences are same length
    if len(sgrna) != len(dna):
        raise ValueError(f"Sequences must be same length: sgRNA={len(sgrna)}, DNA={len(dna)}")

    # Walk both sequences in lockstep; the sgRNA base always comes first.
    paired_tokens = []
    for guide_base, target_base in zip(sgrna, dna):
        pair = guide_base + target_base
        # Convert __ to -- (both positions deleted)
        paired_tokens.append('--' if pair == '__' else pair)
    return paired_tokens
# ========== CNN ENCODING ==========
def encode_for_cnn(sgrna, dna, fixed_length=FIXED_SEQ_LENGTH):
    """
    Encode sgRNA-DNA pair for CNN model with FIXED LENGTH.

    The paired tokens are truncated or padded with [PAD] to ``fixed_length``
    positions and then wrapped as [CLS] + tokens + [SEP]. With the default
    ``fixed_length`` of 24 the result is a 26x7 matrix.

    Args:
        sgrna (str): sgRNA sequence
        dna (str): DNA sequence
        fixed_length (int): Number of paired positions to keep (default: 24)

    Returns:
        numpy.ndarray: Shape (fixed_length + 2, 7), dtype float32 - binary
            encoded matrix. Tokens missing from CNN_ENCODING are encoded as
            all zeros.

    Raises:
        ValueError: If the sequences differ in length (raised by
            pair_sequences) or if ``fixed_length`` is negative.
    """
    if fixed_length < 0:
        raise ValueError(f"fixed_length must be non-negative, got {fixed_length}")

    # Step 1: Pair the sequences
    paired_tokens = pair_sequences(sgrna, dna)

    # Step 2: Truncate, then pad with [PAD], so exactly fixed_length remain
    paired_tokens = paired_tokens[:fixed_length]
    paired_tokens += ['[PAD]'] * (fixed_length - len(paired_tokens))

    # Step 3: Add [CLS] at start and [SEP] at end
    paired_tokens = ['[CLS]'] + paired_tokens + ['[SEP]']

    # Step 4: Convert each token to 7-dim vector (unknown token: all zeros)
    unknown_vector = [0, 0, 0, 0, 0, 0, 0]
    encoded_sequence = [CNN_ENCODING.get(token, unknown_vector) for token in paired_tokens]

    # Step 5: Convert to numpy array. The expected row count follows the
    # requested fixed_length, so non-default lengths no longer trip the check.
    result = np.array(encoded_sequence, dtype=np.float32)
    expected_shape = (fixed_length + 2, 7)
    assert result.shape == expected_shape, f"Expected shape {expected_shape}, got {result.shape}"
    return result
# ========== BERT ENCODING ==========
def encode_for_bert(sgrna, dna, fixed_length=FIXED_SEQ_LENGTH):
    """
    Encode sgRNA-DNA pair for BERT model with FIXED LENGTH.

    The paired tokens are truncated or padded with [PAD] to ``fixed_length``
    positions and then wrapped as [CLS] + tokens + [SEP]. With the default
    ``fixed_length`` of 24 the result has length 26.

    Args:
        sgrna (str): sgRNA sequence
        dna (str): DNA sequence
        fixed_length (int): Number of paired positions to keep (default: 24)

    Returns:
        numpy.ndarray: Token IDs, shape (fixed_length + 2,), dtype int32.
            Tokens missing from BERT_TOKEN_DICT get the [PAD] token ID.

    Raises:
        ValueError: If the sequences differ in length (raised by
            pair_sequences) or if ``fixed_length`` is negative.
    """
    if fixed_length < 0:
        raise ValueError(f"fixed_length must be non-negative, got {fixed_length}")

    # Step 1: Pair the sequences
    paired_tokens = pair_sequences(sgrna, dna)

    # Step 2: Truncate, then pad with [PAD], so exactly fixed_length remain
    paired_tokens = paired_tokens[:fixed_length]
    paired_tokens += ['[PAD]'] * (fixed_length - len(paired_tokens))

    # Step 3: Add [CLS] at start and [SEP] at end
    paired_tokens = ['[CLS]'] + paired_tokens + ['[SEP]']

    # Step 4: Convert each token to token ID (unknown token: [PAD] token ID)
    pad_id = BERT_TOKEN_DICT["[PAD]"]
    token_ids = [BERT_TOKEN_DICT.get(token, pad_id) for token in paired_tokens]

    # Step 5: Convert to numpy array. The expected length follows the
    # requested fixed_length, so non-default lengths no longer trip the check.
    result = np.array(token_ids, dtype=np.int32)
    expected_shape = (fixed_length + 2,)
    assert result.shape == expected_shape, f"Expected shape {expected_shape}, got {result.shape}"
    return result
# ========== BATCH ENCODING ==========
def encode_batch_for_cnn(sgrna_list, dna_list):
    """
    Encode multiple sgRNA-DNA pairs for CNN model.
    All sequences are encoded to fixed size (26, 7).

    Args:
        sgrna_list (list): List of sgRNA sequences
        dna_list (list): List of DNA sequences

    Returns:
        numpy.ndarray: Shape (batch_size, 26, 7), dtype float32. An empty
            batch yields shape (0, 26, 7).

    Raises:
        ValueError: If the two lists hold a different number of sequences,
            or if any individual pair differs in length.
    """
    sgrna_list = list(sgrna_list)
    dna_list = list(dna_list)

    # zip() would silently drop the unmatched tail, so reject it up front.
    if len(sgrna_list) != len(dna_list):
        raise ValueError(
            f"Batch inputs must be same length: sgRNA={len(sgrna_list)}, DNA={len(dna_list)}"
        )

    # Keep the documented rank even when there is nothing to encode.
    if not sgrna_list:
        return np.empty((0, TOTAL_LENGTH, 7), dtype=np.float32)

    # Encode all sequences - each will be (26, 7)
    batch_encoded = [encode_for_cnn(sgrna, dna) for sgrna, dna in zip(sgrna_list, dna_list)]

    # Stack into batch
    return np.array(batch_encoded, dtype=np.float32)
def encode_batch_for_bert(sgrna_list, dna_list):
    """
    Encode multiple sgRNA-DNA pairs for BERT model.
    All sequences are encoded to fixed size (26,).

    Args:
        sgrna_list (list): List of sgRNA sequences
        dna_list (list): List of DNA sequences

    Returns:
        numpy.ndarray: Shape (batch_size, 26), dtype int32. An empty batch
            yields shape (0, 26).

    Raises:
        ValueError: If the two lists hold a different number of sequences,
            or if any individual pair differs in length.
    """
    sgrna_list = list(sgrna_list)
    dna_list = list(dna_list)

    # zip() would silently drop the unmatched tail, so reject it up front.
    if len(sgrna_list) != len(dna_list):
        raise ValueError(
            f"Batch inputs must be same length: sgRNA={len(sgrna_list)}, DNA={len(dna_list)}"
        )

    # Keep the documented rank even when there is nothing to encode.
    if not sgrna_list:
        return np.empty((0, TOTAL_LENGTH), dtype=np.int32)

    # Encode all sequences - each will be (26,)
    batch_encoded = [encode_for_bert(sgrna, dna) for sgrna, dna in zip(sgrna_list, dna_list)]

    # Stack into batch
    return np.array(batch_encoded, dtype=np.int32)
# ========== UTILITY FUNCTIONS ==========
def decode_cnn_encoding(encoded_matrix):
    """
    Decode CNN encoded matrix back to paired tokens (for debugging).

    Decoding is lossy for the all-zero vector: "--", "[CLS]", "[SEP]" and
    "[PAD]" share it in CNN_ENCODING, and the reverse lookup keeps the last
    of them, so every all-zero row is reported as "[PAD]".

    Args:
        encoded_matrix (numpy.ndarray or nested list): Shape (seq_length, 7)

    Returns:
        list: Paired tokens; rows that match no known vector become "??"
    """
    # Create reverse mapping (later keys overwrite earlier ones that share
    # the same vector)
    reverse_map = {tuple(v): k for k, v in CNN_ENCODING.items()}

    # Convert once up front so plain nested lists work as well as arrays;
    # astype(int) truncates float entries exactly as the per-row cast did.
    rows = np.asarray(encoded_matrix).astype(int).tolist()

    return [reverse_map.get(tuple(row), "??") for row in rows]
def decode_bert_encoding(token_ids):
    """
    Turn BERT token IDs back into their paired tokens (for debugging).

    Args:
        token_ids (numpy.ndarray): Token IDs

    Returns:
        list: Paired tokens; IDs missing from BERT_TOKEN_DICT become "??"
    """
    # Invert the token -> ID table into ID -> token
    id_to_token = {}
    for token, token_id in BERT_TOKEN_DICT.items():
        id_to_token[token_id] = token

    # int() turns numpy integer scalars into plain ints before the lookup
    return [id_to_token.get(int(raw_id), "??") for raw_id in token_ids]