File size: 2,186 Bytes
ced6e93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
"""
Registration script for the Autoencoder models with the Hugging Face Auto classes (AutoConfig and AutoModel).
"""
from transformers import AutoConfig, AutoModel
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import AutoencoderModel, AutoencoderForReconstruction
def register_autoencoder_models():
    """
    Register the autoencoder classes with the Hugging Face Auto classes.

    This function registers:
    - AutoencoderConfig with AutoConfig, under the model type "autoencoder"
    - AutoencoderModel with AutoModel

    AutoencoderForReconstruction is NOT registered here. AutoModel's mapping
    is keyed by config class, so AutoencoderConfig can resolve to only one
    model class through AutoModel, and that slot goes to AutoencoderModel.
    Import AutoencoderForReconstruction directly to use it.

    After calling this function, you can use:
    - AutoConfig.from_pretrained() to load autoencoder configs
    - AutoModel.from_pretrained() to load AutoencoderModel checkpoints

    The registration is held in memory, so it must be repeated in every new
    Python process. A confirmation message is printed to stdout.

    Raises:
        ValueError: propagated from transformers if the classes are
            inconsistent with each other, e.g. AutoencoderConfig.model_type
            is not "autoencoder" or AutoencoderModel.config_class is not
            AutoencoderConfig.
    """
    # Map the "autoencoder" model type (as written in config.json) to its
    # config class so AutoConfig can resolve it.
    AutoConfig.register("autoencoder", AutoencoderConfig)

    # Bind the config class to the base model for AutoModel lookups.
    AutoModel.register(AutoencoderConfig, AutoencoderModel)

    # AutoencoderForReconstruction is deliberately left out. Registering it
    # with AutoModel too would compete with AutoencoderModel for the same
    # AutoencoderConfig slot. A task-specific Auto class would be needed
    # to expose it through the Auto API.

    print("✅ Autoencoder models registered with Hugging Face AutoModel framework!")
    print("You can now use:")
    print(" - AutoConfig.from_pretrained() for configs")
    print(" - AutoModel.from_pretrained() for models")
    print(" - Direct imports for task-specific models")
def register_for_auto_class():
    """
    Mark the autoencoder classes for custom-code auto-class support.

    Records on each class which Auto class it belongs to. This lets
    save_pretrained() and from_pretrained() discover the custom code for
    these models. A confirmation message is printed to stdout.

    NOTE(review): AutoencoderModel and AutoencoderForReconstruction are
    both tied to "AutoModel". Presumably a saved config's auto_map will
    point at whichever of the two was saved. TODO confirm this is intended.
    """
    # With no argument, the config class uses the default auto class.
    AutoencoderConfig.register_for_auto_class()

    # The base model and the task-specific model share the generic
    # "AutoModel" entry. Registration order matches the original.
    for model_cls in (AutoencoderModel, AutoencoderForReconstruction):
        model_cls.register_for_auto_class("AutoModel")

    print("✅ Models registered for auto class functionality!")
if __name__ == "__main__":
    # Script entry point. First do the in-process Auto registration, then
    # mark the classes for auto-class support when saving and loading.
    for _register_step in (register_autoencoder_models, register_for_auto_class):
        _register_step()
|