File size: 2,779 Bytes
534e5a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch import nn


class ValueEmbedding(nn.Module):
    """Lift every scalar of the input to a ``hidden_dim`` vector.

    A single ``nn.Linear(1, hidden_dim)`` is shared across all positions, so
    the output has the input's shape with a trailing ``hidden_dim`` axis.

    Args:
        hidden_dim: Size of the produced embedding vectors.
        do_fft: When truthy, the input is first replaced by the real part of
            its FFT along the last axis before being projected.
    """

    def __init__(self, hidden_dim, do_fft):
        super().__init__()
        self.embedding = nn.Linear(1, hidden_dim)
        self.do_fft = do_fft

    def forward(self, x):
        # Optionally move to the frequency domain (real part only).
        signal = torch.fft.fft(x, dim=-1).real if self.do_fft else x
        # Add a size-1 feature axis so the Linear(1, hidden_dim) applies per scalar.
        return self.embedding(signal[..., None])


class PositionalEmbeddingWithDnaPosition(nn.Module):
    """Sinusoidal embedding of integer (e.g. genomic) positions.

    For a position ``p`` and ``i = 0, 2, 4, ...``, the angle is
    ``p / positional_temp / 10000 ** (i / d_model)``. The encoding holds
    ``sin(angle)`` in the even feature slots and ``cos(angle)`` in the odd
    ones.

    The encoding is computed fresh on every call, on ``x.device``. Results are
    never aliased across calls and any input shape is accepted. An earlier
    version reused one preallocated buffer and returned a view of it, so a
    kept result was overwritten by the next call.

    Args:
        d_model: Embedding width. Odd values are supported; the final even
            slot then has no cosine partner.
        batch_size: Unused; accepted for backward compatibility.
        seq_length: Unused; accepted for backward compatibility.
        positional_temp: Divisor applied to positions before the frequency
            scaling. Larger values give slower-varying embeddings.
        device: Unused; the output follows ``x.device``. Accepted for
            backward compatibility.
    """

    def __init__(self, d_model, batch_size, seq_length, positional_temp, device):
        super(PositionalEmbeddingWithDnaPosition, self).__init__()
        self.d_model = d_model
        self.positional_temp = positional_temp

    def forward(self, x):
        """Return a float32 tensor of shape ``x.shape + (d_model,)``.

        Args:
            x: Tensor of positions (any shape). It is cast to float before use.
        """
        pos = x.float().unsqueeze(-1)
        _2i = torch.arange(0, self.d_model, 2, device=x.device)
        # Same operation order as before, so values are numerically identical.
        angles = pos / self.positional_temp / (10000 ** (_2i / self.d_model))
        encoding = torch.zeros(*x.shape, self.d_model, device=x.device)
        encoding[..., 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd slot than even slot.
        encoding[..., 1::2] = torch.cos(angles[..., : self.d_model // 2])
        return encoding


class PretrainEmbeddingSimple(nn.Module):
    """Input embedding for pretraining: a value projection plus optional genomic context.

    Values are always embedded with ``ValueEmbedding`` (FFT disabled). When
    ``chromatin_embedding`` is truthy, three more terms are added: a learned
    chromosome embedding and sinusoidal embeddings of the start and end
    coordinates. Dropout is applied to the sum only when
    ``embedding_dropout > 0``.

    Args:
        embedding_dim: Width of every embedding term.
        chromosome_size: Number of rows in the chromosome ``nn.Embedding``.
        embedding_dropout: Dropout probability; ``<= 0`` disables dropout.
        positional_embedding_type: Stored only; not read by this class.
        positional_temp: Temperature forwarded to the positional embedding.
        batch_size: Forwarded to the positional embedding.
        seq_length: Forwarded to the positional embedding.
        device: Forwarded to the positional embedding.
        chromatin_embedding: Whether to add chromosome and coordinate terms.
    """

    def __init__(
        self,
        embedding_dim,
        chromosome_size,
        embedding_dropout,
        positional_embedding_type,
        positional_temp,
        batch_size,
        seq_length,
        device,
        chromatin_embedding
    ):
        super().__init__()
        # Submodules are constructed in the same order as before. Parameter
        # initialisation draws from the global RNG, so the order matters for
        # reproducibility.
        self.value_embedding = ValueEmbedding(embedding_dim, False)
        self.chromatin_embedding = chromatin_embedding
        # NOTE(review): not used in this class; presumably read elsewhere.
        self.positional_embedding_type = positional_embedding_type
        if chromatin_embedding:
            self.chromosome_embedding = nn.Embedding(chromosome_size, embedding_dim)
            self.position_embedding = PositionalEmbeddingWithDnaPosition(
                embedding_dim,
                batch_size,
                seq_length,
                positional_temp,
                device,
            )
        self.embedding_dropout = embedding_dropout
        self.dropout = nn.Dropout(p=embedding_dropout)

    def forward(self, value, chromosome, hg38_start, hg38_end):
        # assumes chromosome holds integer indices (fed to nn.Embedding) and
        # value/hg38_start/hg38_end share its leading shape -- TODO confirm
        embedded = self.value_embedding(value)
        if self.chromatin_embedding:
            # Add the terms one at a time, left to right, as before. Each
            # positional result is consumed before the next call, because
            # PositionalEmbeddingWithDnaPosition may return a view into a
            # buffer that it overwrites on every call.
            embedded = embedded + self.chromosome_embedding(chromosome)
            embedded = embedded + self.position_embedding(hg38_start)
            embedded = embedded + self.position_embedding(hg38_end)
        if self.embedding_dropout > 0:
            return self.dropout(embedded)
        return embedded