# File size: 2,105 Bytes
# be1dba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
encoder.py — MutationEncoder (1103-dim)
"""
import torch

class MutationEncoder:
    """Encode reference/mutant DNA sequence pairs into fixed-size tensors.

    :meth:`encode_mutation` produces a 1103-dim float vector:
    99 positions x (5 ref one-hot + 5 mut one-hot + 1 diff flag) = 1089,
    plus a 12-dim substitution-type one-hot and 2 region flags (exon, intron).
    """

    # One-hot index for each of the 12 possible single-base substitutions
    # (ref_base, mut_base); non-ACGT pairs have no entry.
    MUTATION_TYPES = {
        ("A","T"):0,("A","C"):1,("A","G"):2,
        ("T","A"):3,("T","C"):4,("T","G"):5,
        ("C","A"):6,("C","T"):7,("C","G"):8,
        ("G","A"):9,("G","T"):10,("G","C"):11,
    }

    def __init__(self):
        # Column index per nucleotide; "N" (index 4) doubles as the
        # catch-all for unknown characters and for right-padding.
        self.nucl_to_idx = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode *sequence* into a (max_len, 5) float tensor.

        The sequence is upper-cased, truncated to *max_len*, and
        right-padded with "N" so the output shape is always fixed.
        Characters outside ATGCN map to the "N" column.
        """
        sequence = sequence.upper()[:max_len]
        sequence = sequence + "N" * (max_len - len(sequence))
        encoded = torch.zeros(max_len, 5)
        for i, base in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(base, 4)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, chrom=None, pos=None,
                        exon_flag=None, intron_flag=None):
        """Encode a (ref, mut) sequence pair into a 1103-dim feature vector.

        Layout: [99 x (5 ref + 5 mut + 1 diff) per-position features,
        flattened row-major (1089 values); 12-dim substitution-type
        one-hot; exon flag; intron flag].

        *chrom* and *pos* are accepted for interface compatibility but do
        not affect the encoding. Missing flags default to 0.

        Fix: sequences are case-normalized before the positional diff, so
        'a' vs 'A' is no longer reported as a mutation (consistent with
        the upper-casing already applied by encode_sequence).
        """
        max_len = 99
        # Normalize once so the diff below agrees with the one-hot encodings.
        ref_seq = ref_seq.upper()
        mut_seq = mut_seq.upper()
        ref_enc = self.encode_sequence(ref_seq, max_len)
        mut_enc = self.encode_sequence(mut_seq, max_len)

        # Per-position mismatch flag; the FIRST mismatch also defines the
        # substitution type (ref base -> mut base).
        diff = torch.zeros(max_len, 1)
        ref_base = mut_base = None
        for i in range(min(len(ref_seq), len(mut_seq), max_len)):
            if ref_seq[i] != mut_seq[i]:
                diff[i, 0] = 1.0
                if ref_base is None:
                    ref_base = ref_seq[i]
                    mut_base = mut_seq[i]

        onehot = torch.zeros(12)
        if ref_base is not None and mut_base is not None:
            idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if idx is not None:  # skip non-ACGT substitutions (e.g. N)
                onehot[idx] = 1.0

        seq_flat = torch.cat([ref_enc, mut_enc, diff], dim=1).flatten()
        ef = int(exon_flag) if exon_flag is not None else 0
        inf = int(intron_flag) if intron_flag is not None else 0
        return torch.cat([seq_flat, onehot,
                          torch.tensor([ef, inf], dtype=torch.float32)])

    def find_mutation_position(self, ref_seq, mut_seq, window_size=99):
        """Return the index of the first differing base within *window_size*
        (case-insensitive), or the window midpoint when no difference is
        found — callers get a usable center position either way.
        """
        ref_seq = ref_seq.upper()
        mut_seq = mut_seq.upper()
        for i in range(min(len(ref_seq), len(mut_seq), window_size)):
            if ref_seq[i] != mut_seq[i]:
                return i
        return window_size // 2