Update README.md
Browse files
README.md
CHANGED
|
@@ -28,21 +28,21 @@ from gsfm import Vocab, GSFM
|
|
| 28 |
|
| 29 |
# load gsfm vocabulary and model weights
|
| 30 |
vocab = Vocab.from_pretrained('maayanlab/gsfm')
|
| 31 |
-
|
| 32 |
|
| 33 |
# convert gene symbols into token ids
|
| 34 |
token_ids = torch.tensor(vocab(['ACE1', 'ACE2']))[None, :]
|
| 35 |
|
| 36 |
# use model to predict missing genes from the set
|
| 37 |
-
logits = torch.squeeze(
|
| 38 |
top_10 = sorted(zip(logits, vocab.vocab))[-10:]
|
| 39 |
top_10
|
| 40 |
|
| 41 |
# get gene embedding
|
| 42 |
-
gene_embeddings =
|
| 43 |
gene_embeddings
|
| 44 |
|
| 45 |
# get model middle layer
|
| 46 |
-
gene_set_encoding =
|
| 47 |
gene_set_encoding
|
| 48 |
```
|
|
|
|
| 28 |
|
| 29 |
# load gsfm vocabulary and model weights
|
| 30 |
vocab = Vocab.from_pretrained('maayanlab/gsfm')
|
| 31 |
+
gsfm = GSFM.from_pretrained('maayanlab/gsfm')
|
| 32 |
|
| 33 |
# convert gene symbols into token ids
|
| 34 |
token_ids = torch.tensor(vocab(['ACE1', 'ACE2']))[None, :]
|
| 35 |
|
| 36 |
# use model to predict missing genes from the set
|
| 37 |
+
logits = torch.squeeze(gsfm(token_ids))
|
| 38 |
top_10 = sorted(zip(logits, vocab.vocab))[-10:]
|
| 39 |
top_10
|
| 40 |
|
| 41 |
# get gene embedding
|
| 42 |
+
gene_embeddings = gsfm.embedding(token_ids)
|
| 43 |
gene_embeddings
|
| 44 |
|
| 45 |
# get model middle layer
|
| 46 |
+
gene_set_encoding = gsfm.encode(token_ids)
|
| 47 |
gene_set_encoding
|
| 48 |
```
|