import torch
from torch import nn


class ValueEmbedding(nn.Module):
    """Projects each scalar signal value into the model's hidden dimension."""

    def __init__(self, hidden_dim, do_fft):
        super(ValueEmbedding, self).__init__()
        self.embedding = nn.Linear(1, hidden_dim)
        self.do_fft = do_fft

    def forward(self, x):
        if self.do_fft:
            # Optionally embed the real part of the FFT taken along the
            # sequence axis instead of the raw values.
            x = torch.fft.fft(x, dim=-1).real
        return self.embedding(x.unsqueeze(-1))
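

# A minimal shape check for ValueEmbedding. The helper name
# _demo_value_embedding and the dimensions used are illustrative only,
# not part of the original module.
def _demo_value_embedding():
    emb = ValueEmbedding(hidden_dim=32, do_fft=False)
    x = torch.randn(4, 128)           # (batch, seq_length) of scalar values
    out = emb(x)
    assert out.shape == (4, 128, 32)  # each scalar mapped to a 32-dim vector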


class PositionalEmbeddingWithDnaPosition(nn.Module):
    """Sinusoidal positional embedding driven by genomic (hg38) coordinates
    rather than sequence indices."""

    def __init__(self, d_model, batch_size, seq_length, positional_temp, device):
        super(PositionalEmbeddingWithDnaPosition, self).__init__()
        self.d_model = d_model
        # Preallocated encoding buffer, overwritten in place on every forward pass.
        self.encoding = torch.zeros(batch_size, seq_length, self.d_model, device=device)
        self.encoding.requires_grad_(False)
        self.positional_temp = positional_temp

    def forward(self, x):
        # x holds genomic positions with shape (batch, seq_length).
        pos = x.float().unsqueeze(-1)
        _2i = torch.arange(0, self.d_model, 2, device=x.device)
        # Standard transformer sin/cos encoding, with an extra temperature
        # factor scaling the genomic positions.
        self.encoding[:x.shape[0], :, 0::2] = torch.sin(
            pos / self.positional_temp / (10000 ** (_2i / self.d_model))
        )
        self.encoding[:x.shape[0], :, 1::2] = torch.cos(
            pos / self.positional_temp / (10000 ** (_2i / self.d_model))
        )
        return self.encoding[:x.shape[0], :, :]
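

# A minimal sketch exercising PositionalEmbeddingWithDnaPosition on CPU.
# The helper name and all parameter values are assumptions for illustration.
def _demo_positional_embedding():
    pos_emb = PositionalEmbeddingWithDnaPosition(
        d_model=32, batch_size=4, seq_length=128,
        positional_temp=1.0, device="cpu",
    )
    hg38_start = torch.randint(0, 1_000_000, (4, 128))  # fake genomic coordinates
    out = pos_emb(hg38_start)
    assert out.shape == (4, 128, 32)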


class PretrainEmbeddingSimple(nn.Module):
    """Combines value, chromosome, and genomic-position embeddings for
    pretraining, with optional dropout."""

    def __init__(
        self,
        embedding_dim,
        chromosome_size,
        embedding_dropout,
        positional_embedding_type,
        positional_temp,
        batch_size,
        seq_length,
        device,
        chromatin_embedding,
    ):
        super(PretrainEmbeddingSimple, self).__init__()
        self.value_embedding = ValueEmbedding(embedding_dim, False)
        self.chromatin_embedding = chromatin_embedding
        self.positional_embedding_type = positional_embedding_type
        if self.chromatin_embedding:
            self.chromosome_embedding = nn.Embedding(chromosome_size, embedding_dim)
            self.position_embedding = PositionalEmbeddingWithDnaPosition(
                embedding_dim,
                batch_size,
                seq_length,
                positional_temp,
                device,
            )
        self.embedding_dropout = embedding_dropout
        self.dropout = nn.Dropout(p=self.embedding_dropout)

    def forward(self, value, chromosome, hg38_start, hg38_end):
        if self.chromatin_embedding:
            # Sum the value embedding with the chromosome identity embedding
            # and the positional encodings of the interval's start and end
            # genomic coordinates.
            embedding = (
                self.value_embedding(value)
                + self.chromosome_embedding(chromosome)
                + self.position_embedding(hg38_start)
                + self.position_embedding(hg38_end)
            )
        else:
            embedding = self.value_embedding(value)
        if self.embedding_dropout > 0:
            return self.dropout(embedding)
        return embedding
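

# End-to-end smoke test. All constructor arguments below are illustrative
# assumptions, not values taken from the original training configuration.
if __name__ == "__main__":
    _demo_value_embedding()
    _demo_positional_embedding()

    model = PretrainEmbeddingSimple(
        embedding_dim=32,
        chromosome_size=25,  # e.g. chr1-22, X, Y, M (assumed vocabulary size)
        embedding_dropout=0.1,
        positional_embedding_type="dna_position",
        positional_temp=1.0,
        batch_size=4,
        seq_length=128,
        device="cpu",
        chromatin_embedding=True,
    )
    value = torch.randn(4, 128)
    chromosome = torch.randint(0, 25, (4, 128))
    hg38_start = torch.randint(0, 1_000_000, (4, 128))
    hg38_end = hg38_start + 200
    out = model(value, chromosome, hg38_start, hg38_end)
    print(out.shape)  # torch.Size([4, 128, 32])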