Spaces:
Runtime error
Runtime error
| """ | |
| Sequence Encoder for CRISPR sgRNA-DNA Pairs | |
| Converts sgRNA and DNA sequences into paired encodings for CNN and BERT models | |
| """ | |
| import numpy as np | |
| # ========== ENCODING DICTIONARIES ========== | |
# CNN: 7-dimensional binary (multi-hot, not strictly one-hot) vector for each
# paired token. Layout of the 7 positions, as used by the table below:
#   0-3  base indicators in the order A, T, G, C; a bit is set when that base
#        appears on either side of the pair
#   4    set when exactly one side of the pair is a gap ("_")
#   5-6  both 0 for matching pairs; mismatches and one-sided gaps set exactly
#        one of them, and swapping the two characters of a token swaps which
#        one is set
# '--' and the three special tokens all share the all-zero vector.
CNN_ENCODING = {
    # sgRNA base A
    'AA': [1, 0, 0, 0, 0, 0, 0],
    'AT': [1, 1, 0, 0, 0, 1, 0],
    'AG': [1, 0, 1, 0, 0, 1, 0],
    'AC': [1, 0, 0, 1, 0, 1, 0],
    # sgRNA base T
    'TA': [1, 1, 0, 0, 0, 0, 1],
    'TT': [0, 1, 0, 0, 0, 0, 0],
    'TG': [0, 1, 1, 0, 0, 1, 0],
    'TC': [0, 1, 0, 1, 0, 1, 0],
    # sgRNA base G
    'GA': [1, 0, 1, 0, 0, 0, 1],
    'GT': [0, 1, 1, 0, 0, 0, 1],
    'GG': [0, 0, 1, 0, 0, 0, 0],
    'GC': [0, 0, 1, 1, 0, 1, 0],
    # sgRNA base C
    'CA': [1, 0, 0, 1, 0, 0, 1],
    'CT': [0, 1, 0, 1, 0, 0, 1],
    'CG': [0, 0, 1, 1, 0, 0, 1],
    'CC': [0, 0, 0, 1, 0, 0, 0],
    # Gap on the DNA side
    'A_': [1, 0, 0, 0, 1, 1, 0],
    'T_': [0, 1, 0, 0, 1, 1, 0],
    'G_': [0, 0, 1, 0, 1, 1, 0],
    'C_': [0, 0, 0, 1, 1, 1, 0],
    # Gap on the sgRNA side
    '_A': [1, 0, 0, 0, 1, 0, 1],
    '_T': [0, 1, 0, 0, 1, 0, 1],
    '_G': [0, 0, 1, 0, 1, 0, 1],
    '_C': [0, 0, 0, 1, 1, 0, 1],
    # Both positions deleted
    '--': [0, 0, 0, 0, 0, 0, 0],
    # Special tokens
    '[CLS]': [0, 0, 0, 0, 0, 0, 0],  # Start of sequence
    '[SEP]': [0, 0, 0, 0, 0, 0, 0],  # End of sequence
    '[PAD]': [0, 0, 0, 0, 0, 0, 0],  # Padding
}
# Number of paired positions kept per sequence (special tokens not counted).
FIXED_SEQ_LENGTH = 24
# Full encoded length: the paired positions plus the [CLS] and [SEP] tokens.
TOTAL_LENGTH = FIXED_SEQ_LENGTH + 2
# BERT: integer token ID for each paired token.
# IDs 0 and 1 are [CLS] and [SEP], 2-17 are the 16 base/base pairs,
# 18-25 are the one-sided gap tokens, 26 is the double deletion and
# 27 is [PAD].
BERT_TOKEN_DICT = {
    # Base/base pairs
    "AA": 2,
    "AC": 3,
    "AG": 4,
    "AT": 5,
    "CA": 6,
    "CC": 7,
    "CG": 8,
    "CT": 9,
    "GA": 10,
    "GC": 11,
    "GG": 12,
    "GT": 13,
    "TA": 14,
    "TC": 15,
    "TG": 16,
    "TT": 17,
    # One-sided gaps
    "A_": 18,
    "_A": 19,
    "C_": 20,
    "_C": 21,
    "G_": 22,
    "_G": 23,
    "T_": 24,
    "_T": 25,
    # Both positions deleted
    "--": 26,
    # Special tokens
    "[CLS]": 0,
    "[SEP]": 1,
    "[PAD]": 27,
}
| # ========== SEQUENCE PAIRING ========== | |
def pair_sequences(sgrna, dna):
    """
    Combine an sgRNA and a DNA sequence into position-wise two-character tokens.

    Both inputs are upper-cased first. The token for position i is the sgRNA
    character followed by the DNA character; a position where both characters
    are "_" is emitted as "--" instead of "__".

    Args:
        sgrna (str): sgRNA sequence (e.g., "GAGTCCGAGCAG")
        dna (str): DNA sequence (e.g., "GGAGTCCGTGCA")

    Returns:
        list: Paired tokens (e.g., ["GG", "AG", "GA", ...])

    Raises:
        ValueError: If the two sequences do not have the same length.
    """
    guide = sgrna.upper()
    target = dna.upper()

    # Pairing is strictly positional, so unequal lengths cannot be aligned.
    if len(guide) != len(target):
        raise ValueError(f"Sequences must be same length: sgRNA={len(guide)}, DNA={len(target)}")

    joined = (guide_char + target_char for guide_char, target_char in zip(guide, target))
    # A double gap has its own dictionary key, spelled "--".
    return ['--' if token == '__' else token for token in joined]
| # ========== CNN ENCODING ========== | |
def encode_for_cnn(sgrna, dna, fixed_length=FIXED_SEQ_LENGTH):
    """
    Encode an sgRNA-DNA pair as a fixed-size binary matrix for the CNN model.

    The sequences are paired position by position, truncated or right-padded
    with [PAD] to exactly ``fixed_length`` tokens, wrapped in [CLS] ... [SEP],
    and each token is replaced by its 7-element vector from CNN_ENCODING.
    Tokens that are not in CNN_ENCODING are encoded as an all-zero vector.

    Args:
        sgrna (str): sgRNA sequence
        dna (str): DNA sequence (must have the same length as ``sgrna``)
        fixed_length (int): Number of paired positions to keep
            (default: FIXED_SEQ_LENGTH)

    Returns:
        numpy.ndarray: float32 matrix of shape (fixed_length + 2, 7);
        (26, 7) with the default ``fixed_length``.

    Raises:
        ValueError: If ``fixed_length`` is negative, or if the sequences
            differ in length (raised by pair_sequences).
    """
    if fixed_length < 0:
        raise ValueError(f"fixed_length must be non-negative, got {fixed_length}")

    # Step 1: Pair the sequences
    paired_tokens = pair_sequences(sgrna, dna)

    # Step 2: Truncate to fixed_length, then right-pad with [PAD] if short.
    # The slice makes a copy, so the list from pair_sequences is not mutated.
    body = paired_tokens[:fixed_length]
    body += ['[PAD]'] * (fixed_length - len(body))

    # Step 3: Add [CLS] at start and [SEP] at end
    tokens = ['[CLS]'] + body + ['[SEP]']

    # Step 4: Convert each token to its 7-dim vector (unknown token -> zeros)
    zero_vector = [0, 0, 0, 0, 0, 0, 0]
    encoded_sequence = [CNN_ENCODING.get(token, zero_vector) for token in tokens]

    # Step 5: Convert to numpy array. The expected shape follows the requested
    # fixed_length rather than the module-wide default, so non-default lengths
    # are usable.
    result = np.array(encoded_sequence, dtype=np.float32)
    expected_shape = (fixed_length + 2, 7)
    assert result.shape == expected_shape, f"Expected shape {expected_shape}, got {result.shape}"
    return result
| # ========== BERT ENCODING ========== | |
def encode_for_bert(sgrna, dna, fixed_length=FIXED_SEQ_LENGTH):
    """
    Encode an sgRNA-DNA pair as a fixed-length vector of token IDs for BERT.

    The sequences are paired position by position, truncated or right-padded
    with [PAD] to exactly ``fixed_length`` tokens, wrapped in [CLS] ... [SEP],
    and each token is replaced by its ID from BERT_TOKEN_DICT. Tokens that are
    not in BERT_TOKEN_DICT are encoded with the [PAD] ID.

    Args:
        sgrna (str): sgRNA sequence
        dna (str): DNA sequence (must have the same length as ``sgrna``)
        fixed_length (int): Number of paired positions to keep
            (default: FIXED_SEQ_LENGTH)

    Returns:
        numpy.ndarray: int32 vector of shape (fixed_length + 2,);
        (26,) with the default ``fixed_length``.

    Raises:
        ValueError: If ``fixed_length`` is negative, or if the sequences
            differ in length (raised by pair_sequences).
    """
    if fixed_length < 0:
        raise ValueError(f"fixed_length must be non-negative, got {fixed_length}")

    # Step 1: Pair the sequences
    paired_tokens = pair_sequences(sgrna, dna)

    # Step 2: Truncate to fixed_length, then right-pad with [PAD] if short.
    # The slice makes a copy, so the list from pair_sequences is not mutated.
    body = paired_tokens[:fixed_length]
    body += ['[PAD]'] * (fixed_length - len(body))

    # Step 3: Add [CLS] at start and [SEP] at end
    tokens = ['[CLS]'] + body + ['[SEP]']

    # Step 4: Convert each token to its ID (unknown token -> [PAD] ID)
    pad_id = BERT_TOKEN_DICT["[PAD]"]
    token_ids = [BERT_TOKEN_DICT.get(token, pad_id) for token in tokens]

    # Step 5: Convert to numpy array. The expected shape follows the requested
    # fixed_length rather than the module-wide default, so non-default lengths
    # are usable.
    result = np.array(token_ids, dtype=np.int32)
    expected_shape = (fixed_length + 2,)
    assert result.shape == expected_shape, f"Expected shape {expected_shape}, got {result.shape}"
    return result
| # ========== BATCH ENCODING ========== | |
def encode_batch_for_cnn(sgrna_list, dna_list):
    """
    Encode multiple sgRNA-DNA pairs for the CNN model.

    Each pair is encoded with encode_for_cnn using its default fixed length,
    so every item has shape (TOTAL_LENGTH, 7).

    Args:
        sgrna_list (list): sgRNA sequences
        dna_list (list): DNA sequences, one per sgRNA, in the same order

    Returns:
        numpy.ndarray: float32 array of shape (batch_size, 26, 7). An empty
        batch yields shape (0, 26, 7).

    Raises:
        ValueError: If the two lists contain a different number of sequences,
            or if any individual pair differs in length.
    """
    sgrnas = list(sgrna_list)
    dnas = list(dna_list)

    # zip() would silently drop the unmatched tail; fail loudly instead.
    if len(sgrnas) != len(dnas):
        raise ValueError(
            f"sgrna_list and dna_list must contain the same number of sequences: "
            f"{len(sgrnas)} != {len(dnas)}"
        )

    # np.array([]) would collapse to shape (0,); keep the documented rank.
    if not sgrnas:
        return np.empty((0, TOTAL_LENGTH, 7), dtype=np.float32)

    batch_encoded = [encode_for_cnn(sgrna, dna) for sgrna, dna in zip(sgrnas, dnas)]
    return np.array(batch_encoded, dtype=np.float32)
def encode_batch_for_bert(sgrna_list, dna_list):
    """
    Encode multiple sgRNA-DNA pairs for the BERT model.

    Each pair is encoded with encode_for_bert using its default fixed length,
    so every item has shape (TOTAL_LENGTH,).

    Args:
        sgrna_list (list): sgRNA sequences
        dna_list (list): DNA sequences, one per sgRNA, in the same order

    Returns:
        numpy.ndarray: int32 array of shape (batch_size, 26). An empty batch
        yields shape (0, 26).

    Raises:
        ValueError: If the two lists contain a different number of sequences,
            or if any individual pair differs in length.
    """
    sgrnas = list(sgrna_list)
    dnas = list(dna_list)

    # zip() would silently drop the unmatched tail; fail loudly instead.
    if len(sgrnas) != len(dnas):
        raise ValueError(
            f"sgrna_list and dna_list must contain the same number of sequences: "
            f"{len(sgrnas)} != {len(dnas)}"
        )

    # np.array([]) would collapse to shape (0,); keep the documented rank.
    if not sgrnas:
        return np.empty((0, TOTAL_LENGTH), dtype=np.int32)

    batch_encoded = [encode_for_bert(sgrna, dna) for sgrna, dna in zip(sgrnas, dnas)]
    return np.array(batch_encoded, dtype=np.int32)
| # ========== UTILITY FUNCTIONS ========== | |
def decode_cnn_encoding(encoded_matrix):
    """
    Map rows of a CNN-encoded matrix back to paired tokens (debugging aid).

    Decoding is lossy for the all-zero vector: '--', '[CLS]', '[SEP]' and
    '[PAD]' all share it in CNN_ENCODING, so every all-zero row decodes to
    whichever of those keys comes last in the dictionary. Rows that match no
    known vector decode to "??".

    Args:
        encoded_matrix (numpy.ndarray): Shape (seq_length, 7)

    Returns:
        list: One token string per row
    """
    # Invert the table; later keys overwrite earlier ones on duplicate vectors.
    vector_to_token = {}
    for token, bits in CNN_ENCODING.items():
        vector_to_token[tuple(bits)] = token

    return [
        vector_to_token.get(tuple(row.astype(int).tolist()), "??")
        for row in encoded_matrix
    ]
def decode_bert_encoding(token_ids):
    """
    Map BERT token IDs back to paired tokens (debugging aid).

    IDs that are not present in BERT_TOKEN_DICT decode to "??".

    Args:
        token_ids (numpy.ndarray): Token IDs

    Returns:
        list: One token string per ID
    """
    # Invert the token -> ID table once per call.
    id_to_token = {}
    for token, identifier in BERT_TOKEN_DICT.items():
        id_to_token[identifier] = token

    # int() normalises numpy integer scalars to plain ints before the lookup.
    return [id_to_token.get(int(raw_id), "??") for raw_id in token_ids]