import torch
import tempfile
import pathlib
import lightning as L
from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download
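# Reserved vocabulary slots: index 0 is the unknown token, index 1 is padding.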
UNK_IDX, PAD_IDX = 0, 1
special_symbols = ['<unk>', '<pad>']
def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
    """Convert a batch of index sequences of shape (..., seq_len) into multi-hot
    vectors of shape (..., num_classes): positions listed in `indices` are set
    to 1, everything else stays 0."""
    *bs, _ = indices.shape
    return torch.zeros((*bs, num_classes), device=device, dtype=dtype).scatter(-1, indices, 1)
class Vocab:
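    """Minimal token-to-index vocabulary with optional persistence to a plain
    text file (one token per line) and to the Hugging Face Hub."""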
def __init__(self, vocab, default_index=0):
self.vocab = vocab
self.default_index = default_index
self.lookup = {token: i for i, token in enumerate(vocab)}
def __call__(self, sentence):
return [self.lookup.get(token, self.default_index) for token in sentence]
@staticmethod
def build_vocab_from_iterator(it, min_freq=1, specials=[], special_first=True):
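        """Build a Vocab from an iterator of token sequences, keeping tokens
        whose frequency is at least `min_freq` and placing `specials` first
        (or last if `special_first` is False)."""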
vocab = []
if special_first:
vocab += specials
from collections import Counter
tokens = Counter()
for sentence in it:
tokens.update(sentence)
for token, freq in tokens.most_common():
if freq < min_freq: continue
vocab.append(token)
if not special_first:
vocab += specials
return Vocab(vocab)
def set_default_index(self, default_index):
self.default_index = default_index
def __len__(self):
return len(self.vocab)
    def __reduce__(self):
        # Preserve the default index as well when pickling.
        return (Vocab, (self.vocab, self.default_index))
def save_txt(self, filename):
with open(filename, 'w') as fw:
for token in self.vocab:
print(token, file=fw)
@staticmethod
def from_txt(filename):
with open(filename, 'r') as fr:
return Vocab([line for line in map(str.rstrip, fr) if line])
@staticmethod
def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
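        """Download a saved vocabulary file from a Hugging Face Hub repo and load it."""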
vocab_txt = hf_hub_download(
repo_id=repo_id,
filename=path_in_repo,
)
return Vocab.from_txt(vocab_txt)
def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
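        """Upload the vocabulary as a text file to a Hugging Face Hub repo,
        creating the repo if it does not exist."""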
api = HfApi()
api.create_repo(repo_id, exist_ok=True)
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
self.save_txt(tmpdir/'vocab.txt')
return api.upload_file(path_or_fileobj=tmpdir/'vocab.txt', repo_id=repo_id, path_in_repo=path_in_repo)
class MLP(torch.nn.Module):
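    """Simple feed-forward stack: Linear -> activation -> dropout for every
    pair of consecutive sizes in `dims`, with no activation or dropout after
    the final Linear layer."""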
def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
super().__init__()
activation = activation()
dropout = torch.nn.Dropout(dropout)
self.layers = torch.nn.ModuleList([
layer
for a, b in zip(dims, dims[1:])
for layer in (
torch.nn.Linear(a, b),
activation,
dropout,
)
][:-2]) # the last layer doesn't need activation/dropout
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class GSFM(
L.LightningModule,
PyTorchModelHubMixin,
tags=["gene", "gene set", "bioinformatics"],
):
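    """Gene-set autoencoder: gene sets are encoded as multi-hot vectors over the
    vocabulary, compressed to a d_model-dimensional embedding by an MLP encoder,
    and reconstructed by an MLP decoder trained with a multi-label BCE loss."""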
def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2, partition=0, weighted_loss=None):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.depth = depth
self.dropout = dropout
self.partition = partition
self.weighted_loss = weighted_loss
self.encoder = MLP(vocab_size, *[d_model*(2**(n-1)) for n in range(depth, 1, -1)], d_model, dropout=dropout)
self.decoder = MLP(d_model, *[d_model*(2**(n-1)) for n in range(1, depth)], vocab_size, dropout=dropout)
self.save_hyperparameters()
def encode(self, x):
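        """Multi-hot encode a padded batch of token indices, zero out the
        padding column, and project to the d_model-dimensional embedding."""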
x = multihot_tensor(x, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
x[:, PAD_IDX] = 0
return self.encoder(x)
def forward(self, x):
x = self.encode(x)
x = self.decoder(x)
return x
def training_step(self, batch, batch_idx):
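        # Reconstruction objective: the input gene set is also the target,
        # compared as multi-hot vectors with a multi-label BCE loss.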
x_idx = y_idx = batch
y_ = self(x_idx)
y = multihot_tensor(y_idx, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
y[:, PAD_IDX] = 0
criterion = torch.nn.BCEWithLogitsLoss()
loss = criterion(y_, y)
self.log('loss', loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
return self.training_step(batch, batch_idx)
def configure_optimizers(self):
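        # Adam with default settings; ReduceLROnPlateau lowers the learning
        # rate by 4x when the monitored 'loss' stops improving.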
optimizer = torch.optim.Adam(self.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
return [optimizer], [{
"scheduler": scheduler,
"monitor": "loss",
"frequency": 1,
}]
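# Illustrative usage sketch (the repo id and gene symbols below are placeholders,
# not taken from this file):
#
#   vocab = Vocab.from_pretrained("user/gsfm")
#   model = GSFM.from_pretrained("user/gsfm")
#   gene_sets = [["TP53", "BRCA1"], ["EGFR"]]
#   batch = torch.nn.utils.rnn.pad_sequence(
#       [torch.tensor(vocab(s)) for s in gene_sets],
#       batch_first=True, padding_value=PAD_IDX,
#   )
#   logits = model(batch)  # (batch_size, vocab_size) reconstruction logits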