|
|
--- |
|
|
license: mit |
|
|
language: |
|
|
- en |
|
|
--- |
|
|
|
|
|
Load the model: |
|
|
|
|
|
``` |
|
|
import torch |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
model_name = "rnalm/144M_MS_MM_last" |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) |
|
|
model = AutoModel.from_pretrained(model_name, trust_remote_code=True) |
|
|
|
|
|
# Move model to GPU |
|
|
model = model.cuda() |
|
|
``` |
|
|
|
|
|
Inference without using the track prediction head: |
|
|
``` |
|
|
# disable the track head in order to avoid providing the metadata |
|
|
model.model.predict_tracks = False |
|
|
inputs = tokenizer("ACGTACGT", return_tensors="pt") |
|
|
# always add taxonomy information in the multispecies model |
|
|
assert model.model.use_taxonomy == True |
|
|
# use human taxonomy |
|
|
# for a full list of taxonomies check 'rnalm/tokenizers/taxonomy_mappings/processed_taxonomy.json' |
|
|
human_taxonomy = torch.tensor([2317, 2318, 2319, 2266, 2248, 2072, 2053, 1875]) |
|
|
|
|
|
with torch.no_grad(): |
|
|
    outputs = model(input_ids=inputs["input_ids"].cuda(), masked_taxonomy=human_taxonomy.cuda()) |
|
|
|
|
|
|
|
|
last_hidden_state_w_taxonomy = outputs.last_hidden_state |
|
|
last_hidden_state_wo_taxonomy = outputs.last_hidden_state[:, 1:, :] |
|
|
|
|
|
last_hidden_state_w_taxonomy.shape |
|
|
# torch.Size([1, 9, 768]) |
|
|
|
|
|
last_hidden_state_wo_taxonomy.shape |
|
|
# torch.Size([1, 8, 768]) |
|
|
|
|
|
outputs.seq_logits.shape |
|
|
# torch.Size([1, 8, 11]) |
|
|
``` |
|
|
|
|
|
Predict tracks using given metadata (this snippet reuses `inputs` and `human_taxonomy` from the previous example): |
|
|
``` |
|
|
metadata = torch.load("path/to/metadata.pt")  # placeholder: replace with the path to your saved metadata tensor |
|
|
# Enable track prediction mode |
|
|
model.model.predict_tracks = True |
|
|
|
|
|
# Forward pass |
|
|
with torch.no_grad(): |
|
|
    outputs = model( |
|
|
        input_ids=inputs["input_ids"].cuda(), |
|
|
        metadata=metadata.cuda(), |
|
|
        masked_taxonomy=human_taxonomy.cuda() |
|
|
    ) |
|
|
|
|
|
outputs.track_yhat |
|
|
``` |
|
|
|
|
|
Get metadata-dependent embeddings (this snippet reuses `inputs` and `human_taxonomy` from the inference example above): |
|
|
``` |
|
|
metadata = torch.load("path/to/metadata.pt")  # placeholder: replace with the path to your saved metadata tensor |
|
|
# Enable track prediction mode |
|
|
model.model.predict_tracks = True |
|
|
|
|
|
# Forward pass |
|
|
with torch.no_grad(): |
|
|
    outputs = model( |
|
|
        input_ids=inputs["input_ids"].cuda(), |
|
|
        metadata=metadata.cuda(), |
|
|
        masked_taxonomy=human_taxonomy.cuda() |
|
|
    ) |
|
|
|
|
|
outputs.last_hidden_state_track.shape |
|
|
# torch.Size([1, 8, 768]) |
|
|
``` |
|
|
|
|
|
|