Load the model:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "rnalm/446M_H_MM_best"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Move the model to the GPU
model = model.cuda()
```
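If a GPU is not guaranteed to be available, a standard PyTorch pattern (an addition here, not something the model requires) is to pick the device dynamically instead of calling `.cuda()` unconditionally:

```python
import torch

# Fall back to CPU when CUDA is not available (generic PyTorch pattern)
device = "cuda" if torch.cuda.is_available() else "cpu"

# model = model.to(device)  # would replace the model.cuda() call above
print(device)
```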

Inference without using the track prediction head:

```python
# Disable the track head so that no metadata has to be provided
model.model.predict_tracks = False
inputs = tokenizer("ACGTACGT", return_tensors="pt")

with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"].cuda())

outputs.last_hidden_state.shape
# torch.Size([1, 8, 1024])

outputs.seq_logits.shape
# torch.Size([1, 8, 11])
```
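The per-position `seq_logits` can be turned into probabilities over the 11-token vocabulary with a softmax. A minimal sketch on a dummy tensor of the same shape as above (the softmax step is an assumption, not something the model card prescribes):

```python
import torch

# Dummy stand-in for outputs.seq_logits: (batch, seq_len, vocab_size)
seq_logits = torch.randn(1, 8, 11)

# Per-position probability distribution over the vocabulary
probs = torch.softmax(seq_logits, dim=-1)
print(probs.shape)  # torch.Size([1, 8, 11])
# probs.sum(dim=-1) is 1.0 at every position
```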

Predict tracks using given metadata:

```python
metadata = ...  # load your metadata tensor here
# Enable track prediction mode
model.model.predict_tracks = True

# Forward pass
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
    )

# Track predictions
outputs.track_yhat
```
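For downstream analysis (plotting, metrics) the track predictions are typically detached and moved back to host memory. A sketch on a dummy tensor, since the exact shape of `track_yhat` is not documented here:

```python
import torch

# Dummy stand-in for outputs.track_yhat (actual shape depends on the model)
track_yhat = torch.randn(1, 8, 4)

# Detach from the autograd graph and move to CPU before NumPy conversion
preds = track_yhat.detach().cpu().numpy()
print(preds.shape)  # (1, 8, 4)
```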

Get metadata-dependent embeddings:

```python
metadata = ...  # load your metadata tensor here
# Enable track prediction mode
model.model.predict_tracks = True

# Forward pass
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
    )

outputs.last_hidden_state_track.shape
# torch.Size([1, 8, 1024])
```
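One way to use the metadata-dependent embeddings is to compare two sequences under the same metadata condition, e.g. via cosine similarity of mean-pooled vectors. A sketch on dummy tensors of the shape reported above (the pooling and similarity choices are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

# Dummy stand-ins for last_hidden_state_track of two sequences,
# mean-pooled over the sequence dimension: (1, 1024) each
emb_a = torch.randn(1, 8, 1024).mean(dim=1)
emb_b = torch.randn(1, 8, 1024).mean(dim=1)

sim = F.cosine_similarity(emb_a, emb_b, dim=-1)
print(sim.shape)  # torch.Size([1])
```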