---
license: mit
language:
- en
---

Load the model:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "rnalm/144M_MS_MM_last"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Move the model to the GPU
model = model.cuda()
```

Inference without using the track prediction head:

```python
# Disable the track head so that no metadata has to be provided
model.model.predict_tracks = False

inputs = tokenizer("ACGTACGT", return_tensors="pt")

# The multispecies model always expects taxonomy information
assert model.model.use_taxonomy

# Use the human taxonomy.
# For the full list of taxonomies, see 'rnalm/tokenizers/taxonomy_mappings/processed_taxonomy.json'
human_taxonomy = torch.tensor([2317, 2318, 2319, 2266, 2248, 2072, 2053, 1875])

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

# Position 0 holds the taxonomy token; the remaining positions correspond to the input tokens
last_hidden_state_w_taxonomy = outputs.last_hidden_state
last_hidden_state_wo_taxonomy = outputs.last_hidden_state[:, 1:, :]

print(last_hidden_state_w_taxonomy.shape)   # torch.Size([1, 9, 768])
print(last_hidden_state_wo_taxonomy.shape)  # torch.Size([1, 8, 768])
print(outputs.seq_logits.shape)             # torch.Size([1, 8, 11])
```

Predict tracks using the given metadata:

```python
# Placeholder path: assumes the metadata tensor was saved with torch.save
metadata = torch.load("path/to/metadata.pt")

# Enable track prediction mode
model.model.predict_tracks = True

# Forward pass
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

track_predictions = outputs.track_yhat
```

Get metadata-dependent embeddings:

```python
# Placeholder path: assumes the metadata tensor was saved with torch.save
metadata = torch.load("path/to/metadata.pt")

# Enable track prediction mode
model.model.predict_tracks = True

# Forward pass
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"].cuda(),
        metadata=metadata.cuda(),
        masked_taxonomy=human_taxonomy.cuda(),
    )

print(outputs.last_hidden_state_track.shape)  # torch.Size([1, 8, 768])
```
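The taxonomy handling in the inference example can be made explicit with a small helper. This is only a sketch: `split_taxonomy_token` and `mean_pool` are hypothetical names that the model does not provide. It assumes, as the `[:, 1:, :]` slicing in the example suggests, that position 0 of `last_hidden_state` is the taxonomy token. Mean pooling is just one common way to get a single vector per sequence; this model card does not prescribe it. The sketch runs on a random stand-in tensor, so it needs neither a GPU nor the checkpoint.

```python
import torch
from typing import Optional, Tuple


def split_taxonomy_token(last_hidden_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split hidden states into the taxonomy position and the per-token states.

    Assumption: position 0 holds the prepended taxonomy token and positions 1:
    align one-to-one with the input tokens.
    """
    taxonomy_state = last_hidden_state[:, 0, :]  # (batch, hidden)
    token_states = last_hidden_state[:, 1:, :]   # (batch, seq_len, hidden)
    return taxonomy_state, token_states


def mean_pool(token_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average per-token states into one vector per sequence (one possible pooling choice)."""
    if attention_mask is None:
        return token_states.mean(dim=1)
    # Zero out padded positions, then divide by the number of real tokens
    mask = attention_mask.unsqueeze(-1).to(token_states.dtype)  # (batch, seq_len, 1)
    summed = (token_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


# Random stand-in with the shape reported for "ACGTACGT" (not real model output)
hidden = torch.randn(1, 9, 768)
tax, tokens = split_taxonomy_token(hidden)
pooled = mean_pool(tokens)
print(tax.shape, tokens.shape, pooled.shape)
# torch.Size([1, 768]) torch.Size([1, 8, 768]) torch.Size([1, 768])
```

With the real model, you would pass `outputs.last_hidden_state` from the inference example to `split_taxonomy_token`. For padded batches, pass an attention mask on the same device as the hidden states to `mean_pool`. Whether this custom tokenizer returns an `attention_mask` is not stated in this card.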