File size: 2,329 Bytes
87460fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
MutationEncoder for DNA mutation encoding
"""
import torch

class MutationEncoder:
    """Encode DNA point mutations as fixed-size tensors for neural network input.

    A mutation is represented as the concatenation of:
      * one-hot reference sequence   (max_len x 5),
      * one-hot mutated sequence     (max_len x 5),
      * per-position difference mask (max_len x 1),
      * one-hot substitution type    (12,).
    For the default max_len of 99 this yields a 99*11 + 12 = 1101-dim vector.
    """

    # All 12 possible single-nucleotide substitutions, (ref, alt) -> index.
    MUTATION_TYPES = {
        ('A', 'T'): 0,  ('A', 'C'): 1,  ('A', 'G'): 2,
        ('T', 'A'): 3,  ('T', 'C'): 4,  ('T', 'G'): 5,
        ('C', 'A'): 6,  ('C', 'T'): 7,  ('C', 'G'): 8,
        ('G', 'A'): 9,  ('G', 'T'): 10, ('G', 'C'): 11,
    }

    def __init__(self):
        # 'N' (unknown base) has its own channel (index 4) and is also the
        # catch-all for any unrecognized character.
        self.nucl_to_idx = {'A': 0, 'T': 1, 'G': 2, 'C': 3, 'N': 4}

    def encode_sequence(self, sequence, max_len=99):
        """One-hot encode a DNA sequence into a (max_len, 5) float tensor.

        The sequence is upper-cased, truncated to max_len, and right-padded
        with 'N'; unknown characters map to the 'N' channel.
        """
        sequence = sequence.upper()[:max_len]
        sequence = sequence + 'N' * (max_len - len(sequence))
        encoded = torch.zeros(max_len, 5)
        for i, nucl in enumerate(sequence):
            encoded[i, self.nucl_to_idx.get(nucl, 4)] = 1.0
        return encoded

    def encode_mutation(self, ref_seq, mut_seq, max_len=99):
        """Encode a ref/mut sequence pair as an (11*max_len + 12,) tensor.

        Args:
            ref_seq: Reference DNA sequence (nominally max_len bp).
            mut_seq: Mutated DNA sequence (nominally max_len bp).
            max_len: Window length; the default 99 gives the 1101-dim encoding.

        Returns:
            torch.Tensor of shape (11*max_len + 12,): per-position
            [ref one-hot | mut one-hot | diff mask] flattened, followed by a
            12-dim one-hot of the first substitution's type (all zeros when
            there is no substitution or it involves a non-ACGT base).
        """
        ref_encoded = self.encode_sequence(ref_seq, max_len)
        mut_encoded = self.encode_sequence(mut_seq, max_len)

        # Per-position difference mask. The comparison is case-insensitive to
        # match the case-folding done in encode_sequence (bugfix: the raw
        # comparison previously flagged 'a' vs 'A' as a mutation).
        diff_mask = torch.zeros(max_len, 1)
        ref_base = mut_base = None
        for i in range(min(len(ref_seq), len(mut_seq), max_len)):
            r, m = ref_seq[i].upper(), mut_seq[i].upper()
            if r != m:
                diff_mask[i, 0] = 1.0
                if ref_base is None:
                    # Record the first differing position's substitution.
                    ref_base, mut_base = r, m

        # One-hot of the substitution type; stays all-zero when no difference
        # was found or the pair involves 'N'/unknown bases.
        mutation_onehot = torch.zeros(12)
        if ref_base is not None:
            idx = self.MUTATION_TYPES.get((ref_base, mut_base))
            if idx is not None:
                mutation_onehot[idx] = 1.0

        # (max_len, 11) per-position features, flattened, plus the type one-hot.
        seq_part = torch.cat([ref_encoded, mut_encoded, diff_mask], dim=1)
        return torch.cat([seq_part.flatten(), mutation_onehot])