File size: 2,186 Bytes
ced6e93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Registration script for Autoencoder models with Hugging Face AutoModel framework.
"""

from transformers import AutoConfig, AutoModel
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import AutoencoderModel, AutoencoderForReconstruction


def register_autoencoder_models():
    """
    Register the autoencoder classes with the Hugging Face Auto* factories.

    This function registers:
    - AutoencoderConfig with AutoConfig, under the model type "autoencoder"
    - AutoencoderModel with AutoModel, keyed by AutoencoderConfig

    AutoencoderForReconstruction is NOT registered with any Auto* class by
    this function; import it directly from ``modeling_autoencoder`` when you
    need the reconstruction model.

    After calling this function, you can use:
    - AutoConfig.from_pretrained() to load autoencoder configs
    - AutoModel.from_pretrained() to load autoencoder models (resolved to
      AutoencoderModel)

    A short confirmation message is printed to stdout.
    """

    # Register the configuration first so AutoConfig can resolve the
    # "autoencoder" model type.
    AutoConfig.register("autoencoder", AutoencoderConfig)

    # Map AutoencoderConfig -> AutoencoderModel for AutoModel.
    AutoModel.register(AutoencoderConfig, AutoencoderModel)

    # AutoencoderForReconstruction is intentionally not registered here.
    # Registering it with AutoModel would map the same AutoencoderConfig key
    # to a second model class. A task-specific head like this would normally
    # be exposed through its own auto class (e.g. a custom
    # AutoModelForReconstruction). Until then, users import it directly.

    print("✅ Autoencoder models registered with Hugging Face AutoModel framework!")
    print("You can now use:")
    print("  - AutoConfig.from_pretrained() for configs")
    print("  - AutoModel.from_pretrained() for models")
    print("  - Direct imports for task-specific models")


def register_for_auto_class():
    """
    Tag the config and model classes with the auto classes they belong to.

    The config class is tagged with its default auto class (no argument is
    passed). Both AutoencoderModel and AutoencoderForReconstruction are tagged
    with "AutoModel". The purpose is that save_pretrained()/from_pretrained()
    can discover these custom classes.

    A confirmation message is printed to stdout once all classes are tagged.
    """

    # Config first, using its default auto class.
    AutoencoderConfig.register_for_auto_class()

    # Every model class in this package is exposed through AutoModel. The
    # tuple keeps the original registration order.
    for model_cls in (AutoencoderModel, AutoencoderForReconstruction):
        model_cls.register_for_auto_class("AutoModel")

    print("✅ Models registered for auto class functionality!")


if __name__ == "__main__":
    # Script entry point: run both registration steps, in this order:
    # Auto* factory registration, then auto-class tagging.
    for _step in (register_autoencoder_models, register_for_auto_class):
        _step()