# gsfm/gsfm.py
import torch
import tempfile
import pathlib
import lightning as L
from collections import Counter
from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download

# Special token indices: index 0 catches out-of-vocabulary tokens, index 1 is padding.
UNK_IDX, PAD_IDX = 0, 1
special_symbols = ['<unk>', '<pad>']

def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
    """Convert a (*batch, seq_len) tensor of token indices into a (*batch, num_classes) multi-hot tensor."""
    *bs, _ = indices.shape
    # scatter along the last dim so the same code handles batches with extra leading dims
    return torch.zeros((*bs, num_classes), device=device, dtype=dtype).scatter(-1, indices, 1)
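
# Usage sketch (illustrative, not part of the original file): duplicate indices
# collapse to a single 1 in the multi-hot output.
# >>> multihot_tensor(torch.tensor([[2, 3, 1], [4, 4, 1]]), num_classes=5)
# tensor([[0, 1, 1, 1, 0],
#         [0, 1, 0, 0, 1]])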

class Vocab:
    """Minimal stand-in for torchtext's Vocab: maps tokens to integer ids."""
    def __init__(self, vocab, default_index=0):
        self.vocab = vocab
        self.default_index = default_index
        self.lookup = {token: i for i, token in enumerate(vocab)}

    def __call__(self, sentence):
        # map each token to its id, falling back to default_index (<unk>) for unseen tokens
        return [self.lookup.get(token, self.default_index) for token in sentence]

    @staticmethod
    def build_vocab_from_iterator(it, min_freq=1, specials=(), special_first=True):
        """Build a Vocab from an iterator of token sequences, mirroring the torchtext API."""
        vocab = []
        if special_first:
            vocab += specials
        tokens = Counter()
        for sentence in it:
            tokens.update(sentence)
        # most_common() sorts by descending frequency with stable tie order
        for token, freq in tokens.most_common():
            if freq < min_freq: continue
            vocab.append(token)
        if not special_first:
            vocab += specials
        return Vocab(vocab)

    def set_default_index(self, default_index):
        self.default_index = default_index

    def __len__(self):
        return len(self.vocab)

    def __reduce__(self):
        # preserve default_index across pickling, not just the token list
        return (Vocab, (self.vocab, self.default_index))

    def save_txt(self, filename):
        with open(filename, 'w') as fw:
            for token in self.vocab:
                print(token, file=fw)

    @staticmethod
    def from_txt(filename):
        with open(filename, 'r') as fr:
            # strip trailing newlines and skip blank lines
            return Vocab([line for line in map(str.rstrip, fr) if line])

    @staticmethod
    def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
        vocab_txt = hf_hub_download(
            repo_id=repo_id,
            filename=path_in_repo,
        )
        return Vocab.from_txt(vocab_txt)

    def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
        api = HfApi()
        api.create_repo(repo_id, exist_ok=True)
        # write the vocab to a temporary file, then upload it before the directory is cleaned up
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            self.save_txt(tmpdir/'vocab.txt')
            return api.upload_file(path_or_fileobj=tmpdir/'vocab.txt', repo_id=repo_id, path_in_repo=path_in_repo)
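
# Usage sketch (illustrative, not part of the original file): build a vocabulary
# from tokenized gene sets, reserving the special symbols at indices 0 and 1.
# >>> corpus = [['TP53', 'BRCA1'], ['TP53', 'EGFR']]
# >>> vocab = Vocab.build_vocab_from_iterator(corpus, specials=special_symbols)
# >>> vocab(['TP53', 'NOT_A_GENE'])
# [2, 0]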

class MLP(torch.nn.Module):
    """Plain multi-layer perceptron: Linear -> activation -> dropout for each hidden layer."""
    def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
        super().__init__()
        # activation/dropout modules are stateless, so a single instance is safely reused
        activation = activation()
        dropout = torch.nn.Dropout(dropout)
        self.layers = torch.nn.ModuleList([
            layer
            for a, b in zip(dims, dims[1:])
            for layer in (
                torch.nn.Linear(a, b),
                activation,
                dropout,
            )
        ][:-2])  # the last layer doesn't need activation/dropout

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
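
# Usage sketch (illustrative, not part of the original file): MLP(8, 4, 2) builds
# Linear(8->4), ReLU, Dropout, Linear(4->2); the trailing activation/dropout are trimmed.
# >>> mlp = MLP(8, 4, 2)
# >>> mlp(torch.randn(3, 8)).shape
# torch.Size([3, 2])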

class GSFM(
    L.LightningModule,
    PyTorchModelHubMixin,
    tags=["gene", "gene set", "bioinformatics"],
):
    """Autoencoder over multi-hot gene-set vectors, trained to reconstruct its input."""
    def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2, partition=0, weighted_loss=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.depth = depth
        self.dropout = dropout
        self.partition = partition
        self.weighted_loss = weighted_loss
        # the encoder narrows from vocab_size down to d_model; the decoder widens back out
        self.encoder = MLP(vocab_size, *[d_model*(2**(n-1)) for n in range(depth, 1, -1)], d_model, dropout=dropout)
        self.decoder = MLP(d_model, *[d_model*(2**(n-1)) for n in range(1, depth)], vocab_size, dropout=dropout)
        self.save_hyperparameters()

    def encode(self, x):
        # x: (batch, seq_len) padded gene indices -> (batch, d_model) embedding
        x = multihot_tensor(x, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
        x[:, PAD_IDX] = 0  # padding carries no signal, so zero out its column
        return self.encoder(x)

    def forward(self, x):
        x = self.encode(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        # self-supervised reconstruction: the input gene set is also the target
        x_idx = y_idx = batch
        y_ = self(x_idx)
        y = multihot_tensor(y_idx, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
        y[:, PAD_IDX] = 0
        criterion = torch.nn.BCEWithLogitsLoss()
        loss = criterion(y_, y)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.training_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        # anneal the learning rate when the monitored loss plateaus
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
        return [optimizer], [{
            "scheduler": scheduler,
            "monitor": "loss",
            "frequency": 1,
        }]
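
# Training sketch (illustrative, not part of the original file): the gene symbols,
# dimensions, and Trainer settings below are placeholder assumptions; batches are
# (batch, seq_len) LongTensors padded with PAD_IDX, as the model expects.
if __name__ == '__main__':
    corpus = [['TP53', 'BRCA1', 'EGFR'], ['TP53', 'MYC'], ['BRCA1', 'MYC', 'EGFR']]
    vocab = Vocab.build_vocab_from_iterator(corpus, specials=special_symbols)
    seq_len = max(map(len, corpus))
    # pad every gene set to a common length with PAD_IDX
    batches = torch.tensor([vocab(s) + [PAD_IDX] * (seq_len - len(s)) for s in corpus])
    loader = torch.utils.data.DataLoader(batches, batch_size=2)
    model = GSFM(vocab_size=len(vocab), d_model=16, depth=2)
    trainer = L.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, loader, val_dataloaders=loader)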