import pathlib
import tempfile
from collections import Counter

import torch
import lightning as L
from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download

# Reserved vocabulary slots: <unk> catches out-of-vocabulary tokens,
# <pad> fills ragged batches up to a common length.
UNK_IDX, PAD_IDX = 0, 1
special_symbols = ['<unk>', '<pad>']
|
|
def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
    """Convert a (*batch, seq) tensor of token indices into a (*batch, num_classes) multi-hot tensor."""
    *bs, _ = indices.shape
    # Scatter 1s along the last dimension so arbitrary batch shapes work, not just 2-D inputs.
    return torch.zeros((*bs, num_classes), device=device, dtype=dtype).scatter(-1, indices, 1)
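
# A minimal usage sketch (the index values below are illustrative only):
#   >>> multihot_tensor(torch.tensor([[2, 3, 3, 1]]), num_classes=5)
#   tensor([[0, 1, 1, 1, 0]])
# Repeated indices collapse to a single 1; callers below zero out the PAD column.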


class Vocab:
    """Token-to-index mapping with torchtext-style construction helpers."""

    def __init__(self, vocab, default_index=0):
        self.vocab = vocab
        self.default_index = default_index
        self.lookup = {token: i for i, token in enumerate(vocab)}

    def __call__(self, sentence):
        # Map each token to its index, falling back to default_index (UNK) for unknown tokens.
        return [self.lookup.get(token, self.default_index) for token in sentence]

    @staticmethod
    def build_vocab_from_iterator(it, min_freq=1, specials=(), special_first=True):
        vocab = []
        if special_first:
            vocab += specials
        # Count token frequencies over the whole corpus, then keep tokens
        # ordered from most to least frequent.
        tokens = Counter()
        for sentence in it:
            tokens.update(sentence)
        for token, freq in tokens.most_common():
            if freq < min_freq:
                continue
            vocab.append(token)
        if not special_first:
            vocab += specials
        return Vocab(vocab)

    def set_default_index(self, default_index):
        self.default_index = default_index

    def __len__(self):
        return len(self.vocab)

    def __reduce__(self):
        # Include default_index so pickling round-trips the full state.
        return (Vocab, (self.vocab, self.default_index))

    def save_txt(self, filename):
        # One token per line, in index order.
        with open(filename, 'w') as fw:
            for token in self.vocab:
                print(token, file=fw)

    @staticmethod
    def from_txt(filename):
        with open(filename, 'r') as fr:
            return Vocab([line for line in map(str.rstrip, fr) if line])

    @staticmethod
    def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
        # Fetch the vocabulary file from the Hugging Face Hub and load it.
        vocab_txt = hf_hub_download(
            repo_id=repo_id,
            filename=path_in_repo,
        )
        return Vocab.from_txt(vocab_txt)

    def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
        api = HfApi()
        api.create_repo(repo_id, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            self.save_txt(tmpdir/'vocab.txt')
            return api.upload_file(path_or_fileobj=tmpdir/'vocab.txt', repo_id=repo_id, path_in_repo=path_in_repo)
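
# A minimal usage sketch (the gene symbols and repo id are hypothetical;
# push_to_hub requires an authenticated Hugging Face token):
#   >>> corpus = [['TP53', 'BRCA1'], ['TP53', 'EGFR']]
#   >>> vocab = Vocab.build_vocab_from_iterator(corpus, specials=special_symbols)
#   >>> vocab(['TP53', 'NOT_IN_VOCAB'])
#   [2, 0]
#   >>> vocab.push_to_hub('username/my-gene-vocab')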


class MLP(torch.nn.Module):
    """Feed-forward stack: Linear -> activation -> dropout per layer, ending in a bare Linear."""

    def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
        super().__init__()
        # ReLU and Dropout hold no parameters, so one shared instance can
        # safely appear after every Linear layer.
        activation = activation()
        dropout = torch.nn.Dropout(dropout)
        self.layers = torch.nn.ModuleList([
            layer
            for a, b in zip(dims, dims[1:])
            for layer in (
                torch.nn.Linear(a, b),
                activation,
                dropout,
            )
        ][:-2])  # drop the trailing activation/dropout so the output layer stays linear

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
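
# A minimal usage sketch (the dimensions are illustrative only):
#   >>> mlp = MLP(512, 256, 128)   # Linear(512,256) -> ReLU -> Dropout -> Linear(256,128)
#   >>> mlp(torch.randn(4, 512)).shape
#   torch.Size([4, 128])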


class GSFM(
    L.LightningModule,
    PyTorchModelHubMixin,
    tags=["gene", "gene set", "bioinformatics"],
):
    """Gene-set autoencoder: multi-hot encode a gene set, compress to d_model, reconstruct."""

    def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2, partition=0, weighted_loss=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.depth = depth
        self.dropout = dropout
        self.partition = partition
        self.weighted_loss = weighted_loss
        # Hidden widths halve on the way down to d_model and widen again on the way up.
        self.encoder = MLP(vocab_size, *[d_model*(2**(n-1)) for n in range(depth, 1, -1)], d_model, dropout=dropout)
        self.decoder = MLP(d_model, *[d_model*(2**(n-1)) for n in range(1, depth)], vocab_size, dropout=dropout)
        self.save_hyperparameters()

    def encode(self, x):
        x = multihot_tensor(x, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
        # Zero the padding column so <pad> never contributes signal.
        x[:, PAD_IDX] = 0
        return self.encoder(x)

    def forward(self, x):
        x = self.encode(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        # Autoencoding objective: the input gene set is also the reconstruction target.
        x_idx = y_idx = batch
        y_ = self(x_idx)
        y = multihot_tensor(y_idx, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
        y[:, PAD_IDX] = 0
        # Multi-label reconstruction: one independent logit per gene.
        criterion = torch.nn.BCEWithLogitsLoss()
        loss = criterion(y_, y)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.training_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        # Quarter the learning rate whenever the monitored loss plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
        return [optimizer], [{
            "scheduler": scheduler,
            "monitor": "loss",
            "frequency": 1,
        }]
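
# A minimal end-to-end sketch, runnable as a script. The toy corpus, padding
# scheme, and trainer settings below are illustrative assumptions, not part
# of the module above.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    corpus = [['TP53', 'BRCA1', 'EGFR'], ['TP53', 'EGFR'], ['BRCA1', 'MYC']]
    vocab = Vocab.build_vocab_from_iterator(corpus, specials=special_symbols)

    # Pad every gene set to a common length so sets stack into one index tensor.
    max_len = max(map(len, corpus))
    data = torch.tensor([
        vocab(genes) + [PAD_IDX] * (max_len - len(genes))
        for genes in corpus
    ])

    model = GSFM(vocab_size=len(vocab), d_model=16, depth=2)
    trainer = L.Trainer(max_epochs=2, logger=False, enable_checkpointing=False)
    # A tensor is a valid map-style dataset; default collation stacks rows
    # back into a 2-D (batch, max_len) index tensor.
    trainer.fit(model, DataLoader(data, batch_size=len(data)))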