| | """ |
| | MutationEncoder for DNA mutation encoding |
| | """ |
| | import torch |
| |
|
class MutationEncoder:
    """Encode DNA mutations as flat tensors for neural network input."""

    # Position of each single-nucleotide substitution (ref_base, mut_base) in
    # the 12-way mutation-type one-hot vector produced by encode_mutation.
    MUTATION_TYPES = {
        ('A', 'T'): 0, ('A', 'C'): 1, ('A', 'G'): 2,
        ('T', 'A'): 3, ('T', 'C'): 4, ('T', 'G'): 5,
        ('C', 'A'): 6, ('C', 'T'): 7, ('C', 'G'): 8,
        ('G', 'A'): 9, ('G', 'T'): 10, ('G', 'C'): 11,
    }

    def __init__(self):
        # Column of each nucleotide in the per-position one-hot encoding.
        # Characters missing from this map are encoded in column 4 ('N').
        self.nucl_to_idx = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode a DNA sequence.

        The sequence is upper-cased, truncated to ``max_len`` characters and
        right-padded with 'N' up to ``max_len``.

        Args:
            sequence: DNA sequence string.
            max_len: Number of positions in the output.

        Returns:
            torch.Tensor of shape (max_len, 5). Columns follow
            ``self.nucl_to_idx`` (A, T, G, C, N). Any character not in that
            map is encoded as 'N' (column 4).
        """
        sequence = sequence.upper()[:max_len]
        sequence = sequence + 'N' * (max_len - len(sequence))
        encoded = torch.zeros(max_len, 5)
        for i, nucl in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(nucl, 4)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, max_len=99):
        """Encode a mutation as a flat feature tensor.

        Layout, for each of the ``max_len`` positions:
        ``[ref one-hot (5), mut one-hot (5), diff flag (1)]``.
        These rows are flattened row-major and followed by a 12-way one-hot
        of the substitution type at the first differing position. With the
        default ``max_len=99`` the result has 99 * 11 + 12 = 1101 values.

        The two sequences are compared case-insensitively, consistent with
        ``encode_sequence`` (which upper-cases its input). Positions past the
        end of the shorter sequence are not flagged as differing. The
        type one-hot stays all-zero when:
        - there is no difference, or
        - the first difference is not one of the 12 substitutions in
          ``MUTATION_TYPES`` (e.g. involves 'N').

        Args:
            ref_seq: Reference DNA sequence (nominally ``max_len`` bp).
            mut_seq: Mutated DNA sequence (nominally ``max_len`` bp).
            max_len: Number of positions encoded. Longer sequences are
                truncated; shorter ones are padded with 'N'.

        Returns:
            torch.Tensor of shape (max_len * 11 + 12,), i.e. (1101,) for
            the default ``max_len``.
        """
        ref_encoded = self.encode_sequence(ref_seq, max_len)
        mut_encoded = self.encode_sequence(mut_seq, max_len)

        # Upper-case before comparing so that soft-masked (lower-case) bases
        # are not mistaken for mutations.
        ref_upper = ref_seq.upper()[:max_len]
        mut_upper = mut_seq.upper()[:max_len]

        diff_mask = torch.zeros(max_len, 1)
        first_diff = None
        for i, (ref_base, mut_base) in enumerate(zip(ref_upper, mut_upper)):
            if ref_base != mut_base:
                diff_mask[i, 0] = 1.0
                if first_diff is None:
                    first_diff = (ref_base, mut_base)

        mutation_onehot = torch.zeros(len(self.MUTATION_TYPES))
        if first_diff is not None:
            idx = self.MUTATION_TYPES.get(first_diff)
            if idx is not None:
                mutation_onehot[idx] = 1.0

        seq_part = torch.cat([ref_encoded, mut_encoded, diff_mask], dim=1)
        return torch.cat([seq_part.flatten(), mutation_onehot])
| |
|