u8sand committed
Commit 27515d6 · verified · 1 Parent(s): 37bae66

Create gsfm.py

Files changed (1):
  1. gsfm.py +136 -0
gsfm.py ADDED
@@ -0,0 +1,136 @@
+ import torch
+ import tempfile
+ import pathlib
+ import lightning as L
+ from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download
+
+ # indices 0 and 1 are reserved for the '<unk>' and '<pad>' special tokens
+ UNK_IDX, PAD_IDX = 0, 1
+ special_symbols = ['<unk>', '<pad>']
+
+ class Vocab:
+     # a token-to-index vocabulary with plain-text serialization and Hub helpers
+     def __init__(self, vocab, default_index=0):
+         self.vocab = vocab
+         self.default_index = default_index
+         self.lookup = {token: i for i, token in enumerate(vocab)}
+
+     def __call__(self, sentence):
+         # map a sequence of tokens to indices; unknown tokens get default_index
+         return [self.lookup.get(token, self.default_index) for token in sentence]
+
+     @staticmethod
+     def build_vocab_from_iterator(it, min_freq=1, specials=[], special_first=True):
+         from collections import Counter
+         vocab = []
+         if special_first:
+             vocab += specials
+         tokens = Counter()
+         for sentence in it:
+             tokens.update(sentence)
+         for token, freq in tokens.most_common():
+             if freq < min_freq: continue
+             vocab.append(token)
+         if not special_first:
+             vocab += specials
+         return Vocab(vocab)
+
+     def set_default_index(self, default_index):
+         self.default_index = default_index
+
+     def __len__(self):
+         return len(self.vocab)
+
+     def __reduce__(self):
+         # reconstruct from the token list so instances pickle cleanly
+         return (Vocab, (self.vocab,))
+
+     def save_txt(self, filename):
+         with open(filename, 'w') as fw:
+             for token in self.vocab:
+                 print(token, file=fw)
+
+     @staticmethod
+     def from_txt(filename):
+         with open(filename, 'r') as fr:
+             return Vocab([line for line in map(str.rstrip, fr) if line])
+
+     @staticmethod
+     def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
+         vocab_txt = hf_hub_download(
+             repo_id=repo_id,
+             filename=path_in_repo,
+         )
+         return Vocab.from_txt(vocab_txt)
+
+     def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
+         api = HfApi()
+         api.create_repo(repo_id, exist_ok=True)
+         with tempfile.TemporaryDirectory() as tmpdir:
+             tmpdir = pathlib.Path(tmpdir)
+             self.save_txt(tmpdir/'vocab.txt')
+             return api.upload_file(path_or_fileobj=tmpdir/'vocab.txt', repo_id=repo_id, path_in_repo=path_in_repo)
+
+ class MLP(torch.nn.Module):
+     # a feed-forward stack: Linear -> activation -> dropout between consecutive dims
+     def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
+         super().__init__()
+         activation = activation()
+         dropout = torch.nn.Dropout(dropout)
+         self.layers = torch.nn.ModuleList([
+             layer
+             for a, b in zip(dims, dims[1:])
+             for layer in (
+                 torch.nn.Linear(a, b),
+                 activation,
+                 dropout,
+             )
+         ][:-2])  # the last layer doesn't need activation/dropout
+
+     def forward(self, x):
+         for layer in self.layers:
+             x = layer(x)
+         return x
+
+ # embeds a gene set and decodes multi-label membership logits over the vocabulary
+ class GSFM(
+     L.LightningModule,
+     PyTorchModelHubMixin,
+     tags=["gene", "gene set", "bioinformatics"],
+ ):
+     def __init__(self, vocab_size, d_model=256, depth=2):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.depth = depth
+         self.embedding = torch.nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
+         self.encoder = MLP(*[d_model**n for n in range(1, depth)], d_model)
+         self.decoder = MLP(d_model*2, *[d_model**n for n in range(2, depth)], vocab_size)
+         self.save_hyperparameters()
+
+     def encode(self, x):
+         # mean-pool the raw and encoded token embeddings, then concatenate them
+         x = emb = self.embedding(x)
+         x = enc = self.encoder(emb)
+         x = torch.cat([enc.mean(1), emb.mean(1)], -1)
+         return x
+
+     def forward(self, x):
+         x = self.encode(x)
+         x = self.decoder(x)
+         return x
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         # NaN targets mark unknown membership: set them to 0 and zero their positive-class weight
+         is_x = torch.isnan(y)
+         y = torch.where(is_x, 0, y)
+         pos_weight = torch.where(is_x, 0, 1)
+         y_ = self(x)
+         criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+         loss = criterion(y_, y)
+         self.log('loss', loss, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         return self.training_step(batch, batch_idx)
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters())
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
+         return [optimizer], [{
+             "scheduler": scheduler,
+             "monitor": "loss",
+             "frequency": 1,
+         }]
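
For orientation, a minimal end-to-end sketch of how this file could be used (not part of the commit; the gene symbols, targets, and training setup below are illustrative assumptions):

import torch
import lightning as L
from torch.utils.data import DataLoader, TensorDataset
from gsfm import Vocab, GSFM, special_symbols, PAD_IDX

# hypothetical corpus: each "sentence" is one gene set (a list of gene symbols)
gene_sets = [
    ['TP53', 'BRCA1', 'EGFR'],
    ['TP53', 'MDM2'],
]

# build the vocabulary with '<unk>'/'<pad>' first so UNK_IDX/PAD_IDX line up
vocab = Vocab.build_vocab_from_iterator(gene_sets, specials=special_symbols)
model = GSFM(vocab_size=len(vocab))

# index each gene set and pad to a rectangular batch
x = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(vocab(s)) for s in gene_sets],
    batch_first=True, padding_value=PAD_IDX,
)

# multi-label membership targets over the vocabulary:
# 1 = member, 0 = non-member, NaN = unknown (masked in training_step)
y = torch.zeros(len(gene_sets), len(vocab))
for i, s in enumerate(gene_sets):
    y[i, vocab(s)] = 1.0
y[1, vocab(['EGFR'])[0]] = float('nan')  # pretend one membership is unknown

trainer = L.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(model, DataLoader(TensorDataset(x, y), batch_size=2))

# embed each whole gene set as a single vector of size 2*d_model
model.eval()  # disable dropout for deterministic embeddings
with torch.no_grad():
    z = model.encode(x)  # shape (2, 512) with the default d_model=256

One note on the masking design: zeroing pos_weight removes the positive-term contribution of unknown entries, but since their targets are also set to 0 they still contribute the negative term, so unknowns are treated as soft negatives rather than excluded from the loss entirely.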
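
Continuing the sketch, both classes can round-trip through a single Hub repo: GSFM gets push_to_hub/from_pretrained from PyTorchModelHubMixin, and Vocab defines matching helpers above. The repo id here is a placeholder:

# push weights and vocabulary to the same (hypothetical) repo...
model.push_to_hub('your-username/gsfm-demo')
vocab.push_to_hub('your-username/gsfm-demo')

# ...and restore them elsewhere
model = GSFM.from_pretrained('your-username/gsfm-demo')
vocab = Vocab.from_pretrained('your-username/gsfm-demo')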