"""encoder.py — MutationEncoder (1103-dim feature vectors for sequence mutations)."""
import torch


class MutationEncoder:
    """Encode a reference/mutant sequence pair as a flat float32 tensor.

    With the default window of 99 positions the output has 1103 entries:

    * 99 * 11 = 1089 per-position features, laid out position-major as
      [5 ref one-hot | 5 mut one-hot | 1 diff flag] for each position;
    * 12 entries: one-hot of the first substitution type (see MUTATION_TYPES);
    * 2 entries: exon flag, intron flag.
    """

    # (ref_base, mut_base) -> index into the 12-way substitution one-hot.
    MUTATION_TYPES = {
        ("A", "T"): 0, ("A", "C"): 1, ("A", "G"): 2,
        ("T", "A"): 3, ("T", "C"): 4, ("T", "G"): 5,
        ("C", "A"): 6, ("C", "T"): 7, ("C", "G"): 8,
        ("G", "A"): 9, ("G", "T"): 10, ("G", "C"): 11,
    }

    def __init__(self):
        # Column index for each nucleotide; unknown characters fall back to "N" (4).
        self.nucl_to_idx = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode *sequence* into a (max_len, 5) float tensor.

        The sequence is upper-cased, truncated to ``max_len`` characters and
        right-padded with "N". Characters other than A/T/G/C/N encode as N.
        """
        # "N" * negative == "", so over-long inputs are only truncated.
        sequence = sequence.upper()[:max_len] + "N" * (max_len - len(sequence))
        encoded = torch.zeros(max_len, 5)
        for i, n in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(n, 4)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, chrom=None, pos=None,
                        exon_flag=None, intron_flag=None, max_len=99):
        """Build the flat feature vector for a ref/mut sequence pair.

        Args:
            ref_seq: reference sequence string.
            mut_seq: mutant sequence string.
            chrom, pos: accepted for interface compatibility; not used.
            exon_flag, intron_flag: truthy/falsy values converted with
                ``int()``; ``None`` encodes as 0.
            max_len: window length (default 99 gives a 1103-dim vector).

        Returns:
            1-D float32 tensor of length ``max_len * 11 + 12 + 2``.

        Positions are compared case-insensitively over the common prefix of
        both sequences (capped at ``max_len``). The substitution one-hot uses
        the first differing position only; it stays all-zero when there is no
        difference or the pair is not one of the 12 A/C/G/T substitutions.
        """
        ref_enc = self.encode_sequence(ref_seq, max_len)
        mut_enc = self.encode_sequence(mut_seq, max_len)

        # Compare upper-cased text so the diff channel agrees with the
        # case-insensitive one-hot channels (lower-case, soft-masked bases
        # would otherwise be flagged as mutations).
        ref_up = ref_seq.upper()
        mut_up = mut_seq.upper()

        diff = torch.zeros(max_len, 1)
        ref_base = mut_base = None
        for i in range(min(len(ref_up), len(mut_up), max_len)):
            if ref_up[i] != mut_up[i]:
                diff[i, 0] = 1.0
                if ref_base is None:  # remember only the first substitution
                    ref_base, mut_base = ref_up[i], mut_up[i]

        onehot = torch.zeros(12)
        if ref_base is not None:
            idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if idx is not None:
                onehot[idx] = 1.0

        # (max_len, 11) -> position-major flat vector.
        seq_flat = torch.cat([ref_enc, mut_enc, diff], dim=1).flatten()
        ef = int(exon_flag) if exon_flag is not None else 0
        if_ = int(intron_flag) if intron_flag is not None else 0
        flags = torch.tensor([ef, if_], dtype=torch.float32)
        return torch.cat([seq_flat, onehot, flags])

    def find_mutation_position(self, ref_seq, mut_seq, window_size=99):
        """Return the index of the first differing base (case-insensitive).

        Only the first ``window_size`` positions of the common prefix are
        scanned. When no difference is found, ``window_size // 2`` (the
        window midpoint) is returned.
        """
        ref_up = ref_seq.upper()
        mut_up = mut_seq.upper()
        for i in range(min(len(ref_up), len(mut_up), window_size)):
            if ref_up[i] != mut_up[i]:
                return i
        return window_size // 2