"""
encoder.py — MutationEncoder (1103-dim)
"""
import torch
class MutationEncoder:
    """Encode a reference/mutant sequence pair into a fixed 1103-dim vector.

    Output layout:
      - 99 positions x (5 ref one-hot + 5 mut one-hot + 1 diff flag) = 1089
      - 12-dim substitution-type one-hot (first SNV, ref -> alt)     = 12
      - exon flag + intron flag                                      = 2
    Total: 1103.
    """

    # All 12 possible single-nucleotide substitutions, (ref, alt) -> index.
    MUTATION_TYPES = {
        ("A", "T"): 0, ("A", "C"): 1, ("A", "G"): 2,
        ("T", "A"): 3, ("T", "C"): 4, ("T", "G"): 5,
        ("C", "A"): 6, ("C", "T"): 7, ("C", "G"): 8,
        ("G", "A"): 9, ("G", "T"): 10, ("G", "C"): 11,
    }

    def __init__(self):
        # One-hot channel per nucleotide; "N" (unknown/padding) is channel 4.
        self.nucl_to_idx = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode `sequence` into a (max_len, 5) float tensor.

        The sequence is uppercased, truncated to `max_len`, and right-padded
        with "N" so the shape is always (max_len, 5). Characters outside
        A/T/G/C/N fall into the "N" channel.
        """
        sequence = sequence.upper()[:max_len]
        sequence += "N" * (max_len - len(sequence))
        encoded = torch.zeros(max_len, 5)
        for i, nucleotide in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(nucleotide, 4)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, chrom=None, pos=None,
                        exon_flag=None, intron_flag=None):
        """Encode a reference/mutant pair as a 1103-dim feature vector.

        Args:
            ref_seq: reference nucleotide string.
            mut_seq: mutated nucleotide string, aligned with ref_seq.
            chrom, pos: accepted for interface compatibility; currently
                unused by the encoding.
            exon_flag, intron_flag: optional truthy annotations, emitted as
                the final two 0/1 features.

        Returns:
            1-D float tensor of length 1103.
        """
        max_len = 99
        ref_enc = self.encode_sequence(ref_seq, max_len)
        mut_enc = self.encode_sequence(mut_seq, max_len)

        # Per-position mismatch flags. Compare uppercased characters so the
        # diff agrees with encode_sequence(), which uppercases before
        # encoding — otherwise "a" vs "A" would be flagged as a mutation.
        ref_u = ref_seq.upper()
        mut_u = mut_seq.upper()
        diff = torch.zeros(max_len, 1)
        ref_base = mut_base = None
        for i in range(min(len(ref_u), len(mut_u), max_len)):
            if ref_u[i] != mut_u[i]:
                diff[i, 0] = 1.0
                if ref_base is None:  # remember the first mismatch only
                    ref_base = ref_u[i]
                    mut_base = mut_u[i]

        # 12-dim one-hot of the first substitution type. Stays all-zero for
        # identical sequences or non-SNV characters (e.g. "N").
        onehot = torch.zeros(12)
        if ref_base is not None and mut_base is not None:
            idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if idx is not None:
                onehot[idx] = 1.0

        # (99, 5+5+1) -> 1089 flat features, then 12 + 2 scalar features.
        seq_flat = torch.cat([ref_enc, mut_enc, diff], dim=1).flatten()
        ef = int(exon_flag) if exon_flag is not None else 0
        if_ = int(intron_flag) if intron_flag is not None else 0
        return torch.cat([seq_flat, onehot,
                          torch.tensor([ef, if_], dtype=torch.float32)])

    def find_mutation_position(self, ref_seq, mut_seq, window_size=99):
        """Return the index of the first mismatch between the two sequences.

        Comparison is case-insensitive, consistent with encode_mutation().
        Falls back to the window centre when no mismatch exists within the
        compared prefix.
        """
        ref_u = ref_seq.upper()
        mut_u = mut_seq.upper()
        for i in range(min(len(ref_u), len(mut_u), window_size)):
            if ref_u[i] != mut_u[i]:
                return i
        return window_size // 2