import pathlib
import tempfile
from collections import Counter

import torch
import lightning as L
from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download

UNK_IDX, PAD_IDX = 0, 1
special_symbols = ['<unk>', '<pad>']


def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
    """Convert a (*batch, set_size) index tensor into a (*batch, num_classes) multi-hot tensor."""
    *bs, _ = indices.shape
    return torch.zeros((*bs, num_classes), device=device, dtype=dtype).scatter(-1, indices, 1)


class Vocab:
    def __init__(self, vocab, default_index=0):
        self.vocab = vocab
        self.default_index = default_index
        self.lookup = {token: i for i, token in enumerate(vocab)}

    def __call__(self, sentence):
        return [self.lookup.get(token, self.default_index) for token in sentence]

    @staticmethod
    def build_vocab_from_iterator(it, min_freq=1, specials=(), special_first=True):
        vocab = []
        if special_first:
            vocab += specials
        tokens = Counter()
        for sentence in it:
            tokens.update(sentence)
        for token, freq in tokens.most_common():
            if freq < min_freq:
                continue
            vocab.append(token)
        if not special_first:
            vocab += specials
        return Vocab(vocab)

    def set_default_index(self, default_index):
        self.default_index = default_index

    def __len__(self):
        return len(self.vocab)

    def __reduce__(self):
        # Preserve default_index across pickling as well.
        return (Vocab, (self.vocab, self.default_index))

    def save_txt(self, filename):
        with open(filename, 'w') as fw:
            for token in self.vocab:
                print(token, file=fw)

    @staticmethod
    def from_txt(filename):
        with open(filename, 'r') as fr:
            return Vocab([line for line in map(str.rstrip, fr) if line])

    @staticmethod
    def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
        vocab_txt = hf_hub_download(repo_id=repo_id, filename=path_in_repo)
        return Vocab.from_txt(vocab_txt)

    def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
        api = HfApi()
        api.create_repo(repo_id, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            self.save_txt(tmpdir / 'vocab.txt')
            return api.upload_file(
                path_or_fileobj=tmpdir / 'vocab.txt',
                repo_id=repo_id,
                path_in_repo=path_in_repo,
            )


class MLP(torch.nn.Module):
    def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
        super().__init__()
        activation = activation()
        dropout = torch.nn.Dropout(dropout)
        self.layers = torch.nn.ModuleList([
            layer
            for a, b in zip(dims, dims[1:])
            for layer in (
                torch.nn.Linear(a, b),
                activation,
                dropout,
            )
        ][:-2])  # the last layer doesn't need activation/dropout

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class GSFM(
    L.LightningModule,
    PyTorchModelHubMixin,
    tags=["gene", "gene set", "bioinformatics"],
):
    def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2,
                 partition=0, weighted_loss=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.depth = depth
        self.dropout = dropout
        self.partition = partition
        self.weighted_loss = weighted_loss
        # The encoder narrows from vocab_size down to d_model;
        # the decoder widens from d_model back out to vocab_size.
        self.encoder = MLP(vocab_size,
                           *[d_model * (2 ** (n - 1)) for n in range(depth, 1, -1)],
                           d_model, dropout=dropout)
        self.decoder = MLP(d_model,
                           *[d_model * (2 ** (n - 1)) for n in range(1, depth)],
                           vocab_size, dropout=dropout)
        self.save_hyperparameters()

    def encode(self, x):
        x = multihot_tensor(x, num_classes=self.vocab_size,
                            device=self.device, dtype=torch.float)
        x[:, PAD_IDX] = 0  # padding carries no signal
        return self.encoder(x)

    def forward(self, x):
        x = self.encode(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        # Autoencoding objective: reconstruct the multi-hot gene set from itself.
        x_idx = y_idx = batch
        y_ = self(x_idx)
        y = multihot_tensor(y_idx, num_classes=self.vocab_size,
                            device=self.device, dtype=torch.float)
        y[:, PAD_IDX] = 0
        criterion = torch.nn.BCEWithLogitsLoss()
        loss = criterion(y_, y)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.training_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
        return [optimizer], [{
            "scheduler": scheduler,
            "monitor": "loss",
            "frequency": 1,
        }]
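
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above): build a vocabulary from
# a toy corpus, pad a batch of gene sets, and run a forward pass. The gene
# symbols and the Hub repo id below are hypothetical placeholders.
if __name__ == '__main__':
    gene_sets = [['TP53', 'BRCA1'], ['TP53', 'EGFR', 'KRAS']]  # toy corpus
    vocab = Vocab.build_vocab_from_iterator(gene_sets, specials=special_symbols)
    vocab.set_default_index(UNK_IDX)

    # Pad every gene set to the batch maximum with PAD_IDX; GSFM.encode zeroes
    # the PAD column of the multi-hot input, so padding carries no signal.
    max_len = max(len(s) for s in gene_sets)
    batch = torch.tensor([vocab(s) + [PAD_IDX] * (max_len - len(s)) for s in gene_sets])

    model = GSFM(vocab_size=len(vocab))
    logits = model(batch)             # (batch, vocab_size) reconstruction logits
    embeddings = model.encode(batch)  # (batch, d_model) gene-set embeddings
    print(logits.shape, embeddings.shape)

    # Training and a Hub round-trip would look roughly like the following
    # (hypothetical repo id; requires a dataloader and Hub authentication):
    #   L.Trainer(max_epochs=10).fit(model, train_dataloader)
    #   model.push_to_hub('user/gsfm-demo')
    #   vocab.push_to_hub('user/gsfm-demo')
    #   model = GSFM.from_pretrained('user/gsfm-demo')
    #   vocab = Vocab.from_pretrained('user/gsfm-demo')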