"""Register the custom Llama classes with the Transformers Auto* factories.

Importing this module has side effects: ``CustomLlamaConfig`` is registered
with ``AutoConfig`` under the ``"custom_llama"`` key, and that config class is
mapped to ``CustomLlamaForMaskedLM`` in ``AutoModelForMaskedLM``.
"""

from transformers import (
    CONFIG_MAPPING,  # noqa: F401  (kept so existing imports of this name keep working)
    MODEL_MAPPING,  # noqa: F401  (kept so existing imports of this name keep working)
    AutoConfig,
    AutoModelForMaskedLM,
)

from .modeling_custom_llama import (
    CustomLlamaConfig,
    CustomLlamaForCausalLM,  # noqa: F401  (kept so existing imports of this name keep working)
    CustomLlamaForMaskedLM,
)

# Register the config through the public API, and do it before the model
# mapping that depends on it. The previous ``CONFIG_MAPPING.update({...})``
# call appears not to be picked up by ``AutoConfig`` lookups, because the lazy
# config mapping resolves keys through its own registry rather than the plain
# dict storage that ``update`` writes to.
# NOTE(review): ``AutoConfig.register`` validates the key against
# ``CustomLlamaConfig.model_type`` -- confirm that it is "custom_llama".
AutoConfig.register("custom_llama", CustomLlamaConfig)

# Map the config class to its masked-LM implementation.
AutoModelForMaskedLM.register(CustomLlamaConfig, CustomLlamaForMaskedLM)

# Left disabled, as in the original:
# MODEL_MAPPING.update({"custom_llama": CustomLlamaForMaskedLM})