"""
MutationEncoder for DNA mutation encoding
"""
import torch


class MutationEncoder:
    """Encode DNA mutations for neural network input"""

    # Index of each single-base substitution (ref_base, alt_base) in the
    # mutation-type one-hot vector.
    MUTATION_TYPES = {
        ('A', 'T'): 0, ('A', 'C'): 1, ('A', 'G'): 2,
        ('T', 'A'): 3, ('T', 'C'): 4, ('T', 'G'): 5,
        ('C', 'A'): 6, ('C', 'T'): 7, ('C', 'G'): 8,
        ('G', 'A'): 9, ('G', 'T'): 10, ('G', 'C'): 11,
    }

    def __init__(self):
        # Column of each nucleotide in the per-position one-hot encoding;
        # 'N' (index 4) doubles as padding / unknown-character bucket.
        self.nucl_to_idx = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode DNA sequence.

        The sequence is upper-cased, truncated to ``max_len`` and right-padded
        with ``'N'``. Characters not in ``nucl_to_idx`` are encoded as ``'N'``.

        Args:
            sequence: DNA string (any case).
            max_len: Number of positions in the output.

        Returns:
            torch.Tensor of shape (max_len, 5); columns follow ``nucl_to_idx``.
        """
        sequence = sequence.upper()[:max_len]
        sequence = sequence + 'N' * (max_len - len(sequence))
        unknown_idx = self.nucl_to_idx['N']
        encoded = torch.zeros(max_len, 5)
        for i, nucl in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(nucl, unknown_idx)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, max_len=99):
        """
        Encode mutation as 1101-dimensional tensor (for the default max_len).

        Per position the features are ``[ref one-hot (5), mut one-hot (5),
        diff flag (1)]``. The positions are flattened and followed by a
        one-hot of the first differing base pair (see ``MUTATION_TYPES``).
        Comparison is case-insensitive, matching ``encode_sequence``.

        Args:
            ref_seq: Reference DNA sequence (99bp)
            mut_seq: Mutated DNA sequence (99bp)
            max_len: Number of positions to encode (default 99).

        Returns:
            torch.Tensor of shape (max_len * 11 + 12,), i.e. (1101,) by default.
            The mutation-type part is all zeros if there is no difference or
            the first differing pair is not a single-base substitution
            listed in ``MUTATION_TYPES`` (e.g. involves 'N').
        """
        ref_encoded = self.encode_sequence(ref_seq, max_len)
        mut_encoded = self.encode_sequence(mut_seq, max_len)

        # Compare the same normalised strings that were one-hot encoded, so
        # case-only differences (soft-masked bases) are not flagged as
        # mutations and cannot hide the real substitution.
        ref_norm = ref_seq.upper()[:max_len]
        mut_norm = mut_seq.upper()[:max_len]

        # Difference mask
        diff_mask = torch.zeros(max_len, 1)
        ref_base = None
        mut_base = None
        for i, (r, m) in enumerate(zip(ref_norm, mut_norm)):
            if r != m:
                diff_mask[i, 0] = 1.0
                if ref_base is None:
                    # Only the first differing position defines the type.
                    ref_base, mut_base = r, m

        # Mutation type one-hot
        mutation_onehot = torch.zeros(len(self.MUTATION_TYPES))
        if ref_base is not None:
            idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if idx is not None:
                mutation_onehot[idx] = 1.0

        # Concatenate all components
        seq_part = torch.cat([ref_encoded, mut_encoded, diff_mask], dim=1)
        seq_flat = seq_part.flatten()
        return torch.cat([seq_flat, mutation_onehot])