File size: 2,484 Bytes
ab6c03c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 | import os
import torch
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
import argparse
# Registry of model families: each key maps to the
# (config class, model class, tokenizer class) triple used to load it.
# loadmodel() reads the "dna" entry from this table.
MODEL_CLASSES = {
    "dna": (
        BertConfig,
        BertForMaskedLM,
        DNATokenizer,
    ),
    # ... (other classes omitted for brevity)
}
def loadmodel(model_dir, model_type="dna"):
    """Load the config, base model and tokenizer found in *model_dir*.

    The classes to use are looked up in ``MODEL_CLASSES`` under *model_type*.
    If the registered model class is ``BertForMaskedLM``, ``BertModel`` is
    loaded in its place: the caller extracts embeddings, so the base
    transformer without the MLM head is what is wanted.  The model is moved
    to CUDA when available (otherwise CPU) and switched to evaluation mode.

    Args:
        model_dir: Location handed to every ``from_pretrained`` call.  When
            the string contains ``".ckpt"`` the weights are loaded with
            ``from_tf=True``.
        model_type: Key into ``MODEL_CLASSES``.  Defaults to ``"dna"``, the
            only entry this function used before the parameter existed.

    Returns:
        A ``(config, model, tokenizer)`` tuple.

    Raises:
        ValueError: If *model_type* is not a key of ``MODEL_CLASSES``.
    """
    try:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_CLASSES)}"
        ) from None

    # Swap the masked-LM wrapper for the bare encoder; any other registered
    # model class is used as-is.
    base_model_class = BertModel if model_class is BertForMaskedLM else model_class

    # Report the class that is actually instantiated, not the registry entry.
    print(f"Loading using: {config_class.__name__}, {base_model_class.__name__}, {tokenizer_class.__name__}")

    # 1. Load configuration
    config = config_class.from_pretrained(model_dir, cache_dir=None)

    # 2. Load model weights
    model = base_model_class.from_pretrained(
        model_dir,
        from_tf=".ckpt" in model_dir,
        config=config,
        cache_dir=None,
    )

    # 3. Place the model on the best available device, in inference mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print(f"Model loaded onto device: {device}")

    # 4. Load the tokenizer from the same directory as the model
    tokenizer = tokenizer_class.from_pretrained(model_dir)
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")

    return config, model, tokenizer
# --- Main Call ---
def main(argv=None):
    """Load a model from ``--MODEL_DIR`` and print a short sanity report.

    Parses the required ``--MODEL_DIR`` command-line option, loads the model
    and tokenizer through ``loadmodel``, then prints the shape of the input
    embedding weights and the first ten tokens of a sample DNA sequence that
    has been split into overlapping 6-mers.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when ``--MODEL_DIR``
            is missing or is still the ``/path/to/default`` placeholder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--MODEL_DIR", type=str, required=True)
    args = parser.parse_args(argv)
    model_dir = args.MODEL_DIR

    # The value comes from the command line, not the environment, so report
    # the placeholder as a usage error (stderr, non-zero exit status).
    if model_dir == "/path/to/default":
        parser.error(
            "--MODEL_DIR is still the placeholder '/path/to/default'; "
            "pass the path of a real model directory."
        )

    _, model, tokenizer = loadmodel(model_dir)
    print("Model and Tokenizer loaded successfully.")

    embedding_layer = model.get_input_embeddings()
    print(embedding_layer.weight.shape)

    # Split the sample sequence into overlapping k-mers joined by spaces
    # before handing it to the tokenizer.
    seq = "ACGTACGTACGT"
    k = 6
    kmers = [seq[i:i + k] for i in range(len(seq) - k + 1)]
    tokens = tokenizer.tokenize(" ".join(kmers))
    print(tokens[:10])


if __name__ == "__main__":
    main()
|