"""Load a DNA BERT checkpoint and run a quick embedding/tokenizer sanity check.

Usage:
    python this_script.py --MODEL_DIR /path/to/model
    MODEL_DIR=/path/to/model python this_script.py
"""
import argparse
import os

import torch
from transformers import BertConfig, BertForMaskedLM, BertModel, DNATokenizer

# Maps a model-type key to its (config class, model class, tokenizer class).
# loadmodel() looks entries up by key; only "dna" is registered here.
MODEL_CLASSES = {
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
    # ... (other classes omitted for brevity)
}

# Template placeholder; treated the same as "no model directory given".
_PLACEHOLDER_MODEL_DIR = "/path/to/default"


def seq2kmer(seq, k=6):
    """Split *seq* into overlapping k-mers joined by single spaces.

    Args:
        seq: Sequence string, e.g. ``"ACGTACGT"``.
        k: k-mer length; must be >= 1. Defaults to 6.

    Returns:
        A space-separated string of the ``len(seq) - k + 1`` overlapping
        k-mers, or ``""`` when *seq* is shorter than *k*.

    Raises:
        ValueError: If *k* < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))


def loadmodel(model_dir, model_type="dna"):
    """Load the config, base encoder and tokenizer stored in *model_dir*.

    If the registered model class is ``BertForMaskedLM``, the plain
    ``BertModel`` (no MLM head) is loaded instead, because this script only
    needs the encoder (e.g. for embeddings). The model is moved to CUDA when
    available (otherwise CPU) and switched to eval mode.

    Args:
        model_dir: Directory or checkpoint path passed to every
            ``from_pretrained`` call.
        model_type: Key into ``MODEL_CLASSES``. Defaults to ``"dna"``.

    Returns:
        A ``(config, model, tokenizer)`` tuple.

    Raises:
        ValueError: If *model_type* is not a key of ``MODEL_CLASSES``.
    """
    try:
        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_CLASSES)}"
        ) from None
    print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")

    # 1. Configuration
    config = config_class.from_pretrained(model_dir, cache_dir=None)

    # 2. Model weights: base transformer without the MLM head.
    base_model_class = BertModel if model_class is BertForMaskedLM else model_class
    model = base_model_class.from_pretrained(
        model_dir,
        # Paths containing ".ckpt" are loaded as TensorFlow checkpoints.
        from_tf=".ckpt" in model_dir,
        config=config,
        cache_dir=None,
    )

    # 3. Device placement and inference mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disables dropout so outputs are deterministic
    print(f"Model loaded onto device: {device}")

    # 4. Tokenizer, loaded from the same directory as the model
    tokenizer = tokenizer_class.from_pretrained(model_dir)
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")

    return config, model, tokenizer


def main(argv=None):
    """Parse CLI args, load the model and print a short sanity check.

    The model directory comes from ``--MODEL_DIR``. If the flag is omitted,
    it falls back to the ``MODEL_DIR`` environment variable. If neither is
    set (or the template placeholder is passed), exit with a usage error.
    """
    parser = argparse.ArgumentParser(
        description="Load a DNA BERT model and inspect its embeddings."
    )
    parser.add_argument(
        "--MODEL_DIR",
        type=str,
        default=os.getenv("MODEL_DIR"),
        help="Model directory (defaults to the MODEL_DIR environment variable).",
    )
    args = parser.parse_args(argv)
    model_dir = args.MODEL_DIR

    if not model_dir or model_dir == _PLACEHOLDER_MODEL_DIR:
        # parser.error prints usage to stderr and exits with status 2.
        parser.error(
            "MODEL_DIR was not set: pass --MODEL_DIR or export the "
            "MODEL_DIR environment variable."
        )

    _, model, tokenizer = loadmodel(model_dir)
    print("Model and Tokenizer loaded successfully.")

    embedding_layer = model.get_input_embeddings()
    print(embedding_layer.weight.shape)

    seq = "ACGTACGTACGT"
    tokens = tokenizer.tokenize(seq2kmer(seq, k=6))
    print(tokens[:10])


if __name__ == "__main__":
    main()