ChrisMcCormick's picture
Debugging
2a903c8 verified
# -*- coding: utf-8 -*-
"""
Hugging Face model repository init file.

Intended to live at the root of the Hugging Face model repository as
``__init__.py`` so that importing it runs the model registration code:
``SharedSpaceDecoderConfig`` is registered with ``AutoConfig`` under the
model type ``"shared_space_decoder"``, and the decoder classes are registered
with ``AutoModel`` and ``AutoModelForCausalLM``.

Registration normally happens as a side effect of ``from models import *``
(presumably performed inside ``models/__init__.py`` -- confirm there). If that
import raises ``ImportError``, the individual classes are imported directly
and registered here instead.

NOTE(review): this code does not establish that the Hub loader actually
executes this file; the banner printed below is temporary debugging output
that checks exactly that, and should be removed once it is confirmed.
"""

# Temporary debugging banner: shows on stdout whether this module is executed.
print("\n========================================\n")
print(" root/__init__.py: Is this being run?")
print("\n========================================\n")

# Preferred path: the star import runs models/__init__.py, which is expected
# to perform the AutoConfig / AutoModel registration itself.
try:
    from models import *  # noqa: F401,F403
except ImportError:
    # Fallback: import the specific classes and register them here.
    # NOTE(review): these are still absolute imports from the same ``models``
    # package (plus ``layers``). If the star import failed because ``models``
    # itself is not importable, this branch will fail as well; Python then
    # reports both errors as a chained traceback.
    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
    from models.shared_space_config import SharedSpaceDecoderConfig
    from models.shared_space_decoder import (  # noqa: F401
        # Not used below; kept so the name stays bound on this module.
        SharedSpaceDecoderPreTrainedModel,
        SharedSpaceDecoderModel,
    )
    from layers.task_heads import SharedSpaceDecoderForCausalLM

    # Map the "shared_space_decoder" model type to its configuration class.
    AutoConfig.register("shared_space_decoder", SharedSpaceDecoderConfig)

    # Map the configuration class to the base model and the causal-LM model.
    AutoModel.register(SharedSpaceDecoderConfig, SharedSpaceDecoderModel)
    AutoModelForCausalLM.register(
        SharedSpaceDecoderConfig, SharedSpaceDecoderForCausalLM
    )