pantoniadis commited on
Commit
9bf9ca5
·
verified ·
1 Parent(s): 380c39d

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +84 -0
README.md ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ ---
6
+
7
+
8
+ Load the model:
9
+
10
+ ```
11
+ import torch
12
+ from transformers import AutoModel, AutoTokenizer
13
+
14
+ model_name = "rnalm/144M_MS_MM_best"
15
+
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
17
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
18
+
19
+ # Move model to GPU
20
+ model = model.cuda()
21
+ ```
22
+
23
+ Inference without using the track prediction head:
24
+ ```
25
+ # disable the track head in order to avoid providing the metadata
26
+ model.model.predict_tracks = False
27
+ inputs = tokenizer("ACGTACGT", return_tensors="pt")
28
+ # always add taxonomy information in the multispecies model
29
+ assert model.model.use_taxonomy == True
30
+ # use human taxonomy
31
+ # for a full list of taxonomies check 'rnalm/tokenizers/taxonomy_mappings/processed_taxonomy.json'
32
+ human_taxonomy = torch.tensor([2317, 2318, 2319, 2266, 2248, 2072, 2053, 1875])
33
+
34
+ with torch.no_grad():
35
+ outputs = model(input_ids=inputs["input_ids"].cuda(), masked_taxonomy=human_taxonomy.cuda())
36
+
37
+
38
+ last_hidden_state_w_taxonomy = outputs.last_hidden_state
39
+ last_hidden_state_wo_taxonomy = outputs.last_hidden_state[:, 1:, :]
40
+
41
+ last_hidden_state_w_taxonomy.shape
42
+ # torch.Size([1, 9, 768])
43
+
44
+ last_hidden_state_wo_taxonomy.shape
45
+ # torch.Size([1, 8, 768])
46
+
47
+ outputs.seq_logits.shape
48
+ # torch.Size([1, 8, 11])
49
+ ```
50
+
51
+ Predict tracks using given metadata:
52
+ ```
53
+ metadata = # path to tensor metadata
54
+ # Enable track prediction mode
55
+ model.model.predict_tracks = True
56
+
57
+ # Forward pass
58
+ with torch.no_grad():
59
+ outputs = model(
60
+ input_ids=inputs["input_ids"].cuda(),
61
+ metadata=metadata.cuda(),
62
+ masked_taxonomy=human_taxonomy.cuda()
63
+ )
64
+
65
+ outputs.track_yhat
66
+ ```
67
+
68
+ Get metadata-dependent embeddings:
69
+ ```
70
+ metadata = # path to tensor metadata
71
+ # Enable track prediction mode
72
+ model.model.predict_tracks = True
73
+
74
+ # Forward pass
75
+ with torch.no_grad():
76
+ outputs = model(
77
+ input_ids=inputs["input_ids"].cuda(),
78
+ metadata=metadata.cuda(),
79
+ masked_taxonomy=human_taxonomy.cuda()
80
+ )
81
+
82
+ outputs.last_hidden_state_track.shape
83
+ # torch.Size([1, 8, 768])
84
+ ```