pantoniadis commited on
Commit
d3a7262
·
verified ·
1 Parent(s): ce0c667

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +69 -0
README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ ---
6
+
7
+ Load the model:
8
+
9
+ ```
10
+ import torch
11
+ from transformers import AutoModel, AutoTokenizer
12
+
13
+ model_name = "rnalm/446M_H_MM_best"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
16
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
17
+
18
+ # Move model to GPU
19
+ model = model.cuda()
20
+ ```
21
+
22
+ Inference without using the track prediction head:
23
+ ```
24
+ # disable the track head in order to avoid providing the metadata
25
+ model.model.predict_tracks = False
26
+ inputs = tokenizer("ACGTACGT", return_tensors="pt")
27
+
28
+ with torch.no_grad():
29
+ outputs = model(input_ids=inputs["input_ids"].cuda())
30
+
31
+ outputs.last_hidden_state.shape
32
+ # torch.Size([1, 8, 1024])
33
+
34
+ outputs.seq_logits.shape
35
+ # torch.Size([1, 8, 11])
36
+ ```
37
+
38
+ Predict tracks using given metadata:
39
+ ```
40
+ metadata = # path to tensor metadata
41
+ # Enable track prediction mode
42
+ model.model.predict_tracks = True
43
+
44
+ # Forward pass
45
+ with torch.no_grad():
46
+ outputs = model(
47
+ input_ids=inputs["input_ids"].cuda(),
48
+ metadata=metadata.cuda()
49
+ )
50
+
51
+ outputs.track_yhat
52
+ ```
53
+
54
+ Get metadata-dependent embeddings:
55
+ ```
56
+ metadata = # path to tensor metadata
57
+ # Enable track prediction mode
58
+ model.model.predict_tracks = True
59
+
60
+ # Forward pass
61
+ with torch.no_grad():
62
+ outputs = model(
63
+ input_ids=inputs["input_ids"].cuda(),
64
+ metadata=metadata.cuda()
65
+ )
66
+
67
+ outputs.last_hidden_state_track.shape
68
+ # torch.Size([1, 8, 1024])
69
+ ```