| import torch | |
| from transformers import AutoTokenizer, AutoModel, EsmForMaskedLM, AutoModelForMaskedLM | |
def load_esm2_model(esm_model_path):
    """Load an ESM-2 tokenizer and masked-language model from one location.

    Args:
        esm_model_path: HuggingFace model id or local directory holding both
            the tokenizer files and the model weights.

    Returns:
        A ``(tokenizer, model)`` pair, where ``model`` is an
        ``AutoModelForMaskedLM`` instance.
    """
    return (
        AutoTokenizer.from_pretrained(esm_model_path),
        AutoModelForMaskedLM.from_pretrained(esm_model_path),
    )
def load_mlm_model(esm_model_path, ckpt_path):
    """Load a tokenizer from the base model and MLM weights from a checkpoint.

    Useful when a fine-tuned checkpoint ships without tokenizer files: the
    tokenizer comes from ``esm_model_path`` while the model weights come
    from ``ckpt_path``.

    Args:
        esm_model_path: Model id or directory supplying the tokenizer.
        ckpt_path: Model id or directory supplying the masked-LM weights.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    mlm_tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    mlm_model = AutoModelForMaskedLM.from_pretrained(ckpt_path)
    return mlm_tokenizer, mlm_model