|
|
""" |
|
|
MutationEncoder for DNA mutation encoding |
|
|
""" |
|
|
import torch |
|
|
|
|
|
class MutationEncoder:
    """Encode DNA mutations for neural network input.

    A reference/mutated sequence pair is turned into a flat float tensor made
    of, for every position, the one-hot reference base, the one-hot mutated
    base and a 0/1 "differs" flag, followed by a 12-way one-hot describing
    the substitution type of the first differing position.
    """

    # Substitution (reference base, mutated base) -> slot in the 12-way
    # mutation-type one-hot vector.
    MUTATION_TYPES = {
        ('A', 'T'): 0, ('A', 'C'): 1, ('A', 'G'): 2,
        ('T', 'A'): 3, ('T', 'C'): 4, ('T', 'G'): 5,
        ('C', 'A'): 6, ('C', 'T'): 7, ('C', 'G'): 8,
        ('G', 'A'): 9, ('G', 'T'): 10, ('G', 'C'): 11,
    }

    # Width of the per-position one-hot vector (A, T, G, C, N).
    _NUM_CHANNELS = 5
    # Channel used for padding and for any character not in ``nucl_to_idx``.
    _UNKNOWN_IDX = 4

    def __init__(self):
        """Set up the nucleotide -> one-hot channel lookup table."""
        self.nucl_to_idx = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode a DNA sequence.

        The sequence is upper-cased, truncated to ``max_len`` characters and
        right-padded with ``'N'`` up to ``max_len``. Characters that are not
        in ``nucl_to_idx`` are encoded in the same channel as ``'N'``.

        Args:
            sequence: DNA sequence string (any case).
            max_len: Number of positions in the output.

        Returns:
            torch.Tensor of shape (max_len, 5) with exactly one 1.0 per row.
        """
        sequence = sequence.upper()[:max_len]
        sequence = sequence + 'N' * (max_len - len(sequence))

        channels = torch.tensor(
            [self.nucl_to_idx.get(nucl, self._UNKNOWN_IDX) for nucl in sequence],
            dtype=torch.long,
        )
        encoded = torch.zeros(max_len, self._NUM_CHANNELS)
        # Set one channel per row in a single indexed assignment.
        encoded[torch.arange(max_len), channels] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, max_len=99):
        """Encode a mutation as a flat tensor (1101-dimensional by default).

        Layout of the result: for each of the ``max_len`` positions, 11
        consecutive values (5 reference one-hot, 5 mutated one-hot, 1 diff
        flag), followed by the 12-way mutation-type one-hot.

        Comparison is case-insensitive, matching ``encode_sequence``: a base
        that differs only in letter case is not treated as a mutation.
        Positions beyond the shorter of the two input sequences are never
        flagged in the diff mask. Only the first differing position
        determines the mutation type; if that pair is not in
        ``MUTATION_TYPES`` (e.g. it involves ``'N'``), the type vector stays
        all zeros.

        Args:
            ref_seq: Reference DNA sequence (99bp by default).
            mut_seq: Mutated DNA sequence (99bp by default).
            max_len: Number of positions to encode; defaults to 99.

        Returns:
            torch.Tensor of shape (max_len * 11 + 12,), i.e. (1101,) for the
            default ``max_len``.
        """
        # Normalize case once so the diff mask and mutation type agree with
        # the (upper-cased) one-hot encodings.
        ref_seq = ref_seq.upper()
        mut_seq = mut_seq.upper()

        ref_encoded = self.encode_sequence(ref_seq, max_len)
        mut_encoded = self.encode_sequence(mut_seq, max_len)

        diff_mask = torch.zeros(max_len, 1)
        mutation_onehot = torch.zeros(len(self.MUTATION_TYPES))
        first_diff_seen = False

        # zip() stops at the shorter sequence; slicing caps it at max_len.
        pairs = zip(ref_seq[:max_len], mut_seq[:max_len])
        for pos, (ref_base, mut_base) in enumerate(pairs):
            if ref_base == mut_base:
                continue
            diff_mask[pos, 0] = 1.0
            if first_diff_seen:
                continue
            # Only the first differing position defines the mutation type.
            first_diff_seen = True
            type_idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if type_idx is not None:
                mutation_onehot[type_idx] = 1.0

        seq_part = torch.cat([ref_encoded, mut_encoded, diff_mask], dim=1)
        return torch.cat([seq_part.flatten(), mutation_onehot])
|
|
|