import torch
import tempfile
import pathlib
import lightning as L
from collections import Counter
from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download

# index 0 = '<unk>' (out-of-vocabulary), index 1 = '<pad>' (padding)
UNK_IDX, PAD_IDX = 0, 1
special_symbols = ['<unk>', '<pad>']

def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
  """Convert a (*batch, seq_len) tensor of token indices into a (*batch, num_classes) multi-hot tensor."""
  *bs, _ = indices.shape
  # scatter 1s along the last (class) dimension at each index position, so any number of leading batch dims works
  return torch.zeros((*bs, num_classes), device=device, dtype=dtype).scatter(-1, indices, 1)

class Vocab:
  """Simple token-to-index vocabulary with a configurable default index for out-of-vocabulary tokens."""
  def __init__(self, vocab, default_index=0):
    self.vocab = vocab
    self.default_index = default_index
    self.lookup = {token: i for i, token in enumerate(vocab)}

  def __call__(self, sentence):
    return [self.lookup.get(token, self.default_index) for token in sentence]

  @staticmethod
  def build_vocab_from_iterator(it, min_freq=1, specials=(), special_first=True):
    """Build a Vocab from an iterator of token sequences, keeping tokens with frequency >= min_freq in descending-frequency order."""
    vocab = []
    if special_first:
      vocab += specials
    tokens = Counter()
    for sentence in it:
      tokens.update(sentence)
    for token, freq in tokens.most_common():
      if freq < min_freq: continue
      vocab.append(token)
    if not special_first:
      vocab += specials
    return Vocab(vocab)

  def set_default_index(self, default_index):
    self.default_index = default_index

  def __len__(self):
    return len(self.vocab)

  def __reduce__(self):
    # preserve the default (out-of-vocabulary) index through pickling
    return (Vocab, (self.vocab, self.default_index))

  def save_txt(self, filename):
    """Write one token per line."""
    with open(filename, 'w', encoding='utf-8') as fw:
      for token in self.vocab:
        print(token, file=fw)

  @staticmethod
  def from_txt(filename):
    """Read a vocabulary written by save_txt, skipping blank lines."""
    with open(filename, 'r', encoding='utf-8') as fr:
      return Vocab([line for line in map(str.rstrip, fr) if line])

  @staticmethod
  def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
    vocab_txt = hf_hub_download(
      repo_id=repo_id,
      filename=path_in_repo,
    )
    return Vocab.from_txt(vocab_txt)

  def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
    api = HfApi()
    api.create_repo(repo_id, exist_ok=True)
    with tempfile.TemporaryDirectory() as tmpdir:
      tmpdir = pathlib.Path(tmpdir)
      self.save_txt(tmpdir/'vocab.txt')
      return api.upload_file(path_or_fileobj=tmpdir/'vocab.txt', repo_id=repo_id, path_in_repo=path_in_repo)

class MLP(torch.nn.Module):
  """Feed-forward stack of Linear -> activation -> dropout blocks; the final Linear has no activation or dropout."""
  def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
    super().__init__()
    activation = activation()
    dropout = torch.nn.Dropout(dropout)
    self.layers = torch.nn.ModuleList([
      layer
      for a, b in zip(dims, dims[1:])
      for layer in (
        torch.nn.Linear(a, b),
        activation,
        dropout,
      )
    ][:-2])  # the last layer doesn't need activation/dropout

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

class GSFM(
  L.LightningModule,
  PyTorchModelHubMixin,
  tags=["gene", "gene set", "bioinformatics"],
):
  def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2, partition=0, weighted_loss=None):
    super().__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.depth = depth
    self.dropout = dropout
    self.partition = partition
    self.weighted_loss = weighted_loss
    # the encoder compresses the multi-hot gene-set vector down to d_model; the decoder maps it back to vocab-sized logits
    self.encoder = MLP(vocab_size, *[d_model*(2**(n-1)) for n in range(depth, 1, -1)], d_model, dropout=dropout)
    self.decoder = MLP(d_model, *[d_model*(2**(n-1)) for n in range(1, depth)], vocab_size, dropout=dropout)
    self.save_hyperparameters()

  def encode(self, x):
    x = multihot_tensor(x, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
    x[:, PAD_IDX] = 0  # padding tokens don't contribute to the gene-set representation
    return self.encoder(x)

  def forward(self, x):
    x = self.encode(x)
    x = self.decoder(x)
    return x

  def training_step(self, batch, batch_idx):
    # autoencoding objective: reconstruct the input gene set from its own encoding
    x_idx = y_idx = batch
    y_ = self(x_idx)
    y = multihot_tensor(y_idx, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
    y[:, PAD_IDX] = 0  # never score the padding token
    criterion = torch.nn.BCEWithLogitsLoss()
    loss = criterion(y_, y)
    self.log('loss', loss, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    # validation reuses the reconstruction loss; it is logged under the same 'loss' key the scheduler monitors
    return self.training_step(batch, batch_idx)

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters())
    # cut the learning rate by 4x whenever the monitored loss plateaus
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
    return [optimizer], [{
        "scheduler": scheduler,
        "monitor": "loss",
        "frequency": 1,
    }]
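
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal end-to-end example, assuming gene sets arrive as lists of gene
# symbols. The symbols below are hypothetical placeholders; training with a
# Lightning Trainer and a DataLoader of padded index tensors follows the
# usual Lightning pattern and is not shown here.
if __name__ == '__main__':
  gene_sets = [['TP53', 'BRCA1', 'MYC'], ['TP53', 'EGFR']]
  vocab = Vocab.build_vocab_from_iterator(gene_sets, specials=special_symbols)
  vocab.set_default_index(UNK_IDX)
  # map each gene set to index tensors and pad to a common length
  encoded = [torch.tensor(vocab(s)) for s in gene_sets]
  batch = torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True, padding_value=PAD_IDX)
  model = GSFM(vocab_size=len(vocab))
  embeddings = model.encode(batch)  # (batch, d_model) gene-set embeddings
  logits = model(batch)             # (batch, vocab_size) reconstruction logits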