pantoniadis commited on
Commit
080b8fc
·
verified ·
1 Parent(s): 56c05b2

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +83 -0
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ ---
6
+
7
+ Load the model:
8
+
9
+ ```
10
+ import torch
11
+ from transformers import AutoModel, AutoTokenizer
12
+
13
+ model_name = "rnalm/446M_MS_MM_last"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
16
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
17
+
18
+ # Move model to GPU
19
+ model = model.cuda()
20
+ ```
21
+
22
+ Inference without using the track prediction head:
23
+ ```
24
+ # disable the track head in order to avoid providing the metadata
25
+ model.model.predict_tracks = False
26
+ inputs = tokenizer("ACGTACGT", return_tensors="pt")
27
+ # always add taxonomy information in the multispecies model
28
+ assert model.model.use_taxonomy == True
29
+ # use human taxonomy
30
+ # for a full list of taxonomies check 'rnalm/tokenizers/taxonomy_mappings/processed_taxonomy.json'
31
+ human_taxonomy = torch.tensor([2317, 2318, 2319, 2266, 2248, 2072, 2053, 1875])
32
+
33
+ with torch.no_grad():
34
+ outputs = model(input_ids=inputs["input_ids"].cuda(), masked_taxonomy=human_taxonomy.cuda())
35
+
36
+
37
+ last_hidden_state_w_taxonomy = outputs.last_hidden_state
38
+ last_hidden_state_wo_taxonomy = outputs.last_hidden_state[:, 1:, :]
39
+
40
+ last_hidden_state_w_taxonomy.shape
41
+ # torch.Size([1, 9, 1024])
42
+
43
+ last_hidden_state_wo_taxonomy.shape
44
+ # torch.Size([1, 8, 1024])
45
+
46
+ outputs.seq_logits.shape
47
+ # torch.Size([1, 8, 11])
48
+ ```
49
+
50
+ Predict tracks using given metadata:
51
+ ```
52
+ metadata = # path to tensor metadata
53
+ # Enable track prediction mode
54
+ model.model.predict_tracks = True
55
+
56
+ # Forward pass
57
+ with torch.no_grad():
58
+ outputs = model(
59
+ input_ids=inputs["input_ids"].cuda(),
60
+ metadata=metadata.cuda(),
61
+ masked_taxonomy=human_taxonomy.cuda()
62
+ )
63
+
64
+ outputs.track_yhat
65
+ ```
66
+
67
+ Get metadata-dependent embeddings:
68
+ ```
69
+ metadata = # path to tensor metadata
70
+ # Enable track prediction mode
71
+ model.model.predict_tracks = True
72
+
73
+ # Forward pass
74
+ with torch.no_grad():
75
+ outputs = model(
76
+ input_ids=inputs["input_ids"].cuda(),
77
+ metadata=metadata.cuda(),
78
+ masked_taxonomy=human_taxonomy.cuda()
79
+ )
80
+
81
+ outputs.last_hidden_state_track.shape
82
+ # torch.Size([1, 8, 1024])
83
+ ```