# nileshhanotia's picture
# Add encoder.py
# be1dba9 verified
"""
encoder.py — MutationEncoder (1103-dim)
"""
import torch
class MutationEncoder:
    """Encode a reference/mutant DNA sequence pair as a 1103-dim feature vector.

    Output layout of :meth:`encode_mutation` (99-bp window):
      - 99 x 11 = 1089 values: per position [ref one-hot (5), mut one-hot (5),
        mismatch flag (1)], flattened row-major;
      - 12 values: one-hot of the (ref_base, mut_base) substitution type;
      - 2 values: exon flag, intron flag.
    Total: 1089 + 12 + 2 = 1103.
    """

    # All 12 possible single-nucleotide substitutions, (ref, alt) -> one-hot index.
    MUTATION_TYPES = {
        ("A", "T"): 0, ("A", "C"): 1, ("A", "G"): 2,
        ("T", "A"): 3, ("T", "C"): 4, ("T", "G"): 5,
        ("C", "A"): 6, ("C", "T"): 7, ("C", "G"): 8,
        ("G", "A"): 9, ("G", "T"): 10, ("G", "C"): 11,
    }

    def __init__(self):
        # "N" (index 4) doubles as the padding / unknown-base symbol.
        self.nucl_to_idx = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}

    def encode_sequence(self, sequence: str, max_len: int = 99) -> torch.Tensor:
        """One-hot encode *sequence* into a ``(max_len, 5)`` float tensor.

        The sequence is upper-cased, truncated to *max_len*, and right-padded
        with "N"; any character outside ATGCN maps to the "N" channel.
        """
        sequence = sequence.upper()[:max_len] + "N" * (max_len - len(sequence))
        encoded = torch.zeros(max_len, 5)
        for i, nucleotide in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(nucleotide, 4)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, chrom=None, pos=None,
                        exon_flag=None, intron_flag=None) -> torch.Tensor:
        """Encode a ref/mut pair as a flat 1103-dim feature vector.

        Args:
            ref_seq: reference sequence (window of up to 99 bp).
            mut_seq: mutated sequence, aligned position-wise with *ref_seq*.
            chrom, pos: accepted for interface compatibility; not encoded.
            exon_flag, intron_flag: optional truthy/int flags appended as the
                last two features (0 when ``None``).

        Returns:
            1-D float tensor of length 1103 (see class docstring for layout).
        """
        max_len = 99
        ref_enc = self.encode_sequence(ref_seq, max_len)
        mut_enc = self.encode_sequence(mut_seq, max_len)
        # BUGFIX: compare case-insensitively. encode_sequence upper-cases both
        # inputs, so e.g. "a" vs "A" yields identical encodings and must NOT be
        # flagged as a mutation in the diff mask.
        ref_u = ref_seq.upper()
        mut_u = mut_seq.upper()
        diff = torch.zeros(max_len, 1)
        ref_base = mut_base = None
        for i in range(min(len(ref_u), len(mut_u), max_len)):
            if ref_u[i] != mut_u[i]:
                diff[i, 0] = 1.0
                if ref_base is None:  # record the first mismatch only
                    ref_base = ref_u[i]
                    mut_base = mut_u[i]
        # 12-dim substitution one-hot; stays all-zero when there is no mismatch
        # or the pair involves a non-ATGC base (not in MUTATION_TYPES).
        onehot = torch.zeros(12)
        if ref_base and mut_base:
            idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if idx is not None:
                onehot[idx] = 1.0
        # (99, 5+5+1) -> 1089 flat features.
        seq_flat = torch.cat([ref_enc, mut_enc, diff], dim=1).flatten()
        ef = int(exon_flag) if exon_flag is not None else 0
        intron = int(intron_flag) if intron_flag is not None else 0
        return torch.cat([seq_flat, onehot,
                          torch.tensor([ef, intron], dtype=torch.float32)])

    def find_mutation_position(self, ref_seq, mut_seq, window_size=99):
        """Return the index of the first mismatch between the two sequences.

        Comparison is case-insensitive, consistent with encode_mutation.
        Falls back to the window midpoint when no mismatch is found within
        *window_size* positions.
        """
        ref_u = ref_seq.upper()
        mut_u = mut_seq.upper()
        for i in range(min(len(ref_u), len(mut_u), window_size)):
            if ref_u[i] != mut_u[i]:
                return i
        return window_size // 2