# nanogpt-openwebtext / __init__.py
# Last change: "Fix model registration for AutoModel compatibility"
# (commit 063be31, verified; author: chegde)
"""NanoGPT HuggingFace Integration"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
# Import our classes.
# Relative imports are used when this file is loaded as part of a package.
# When it is loaded without a parent package (``__package__`` is None or ""),
# e.g. as a top-level module from a plain checkout, relative imports are
# impossible, so the sibling modules are imported by their absolute names.
try:
    from .configuration_nanogpt import NanoGPTConfig
    from .modeling_nanogpt import (
        NanoGPTModel,
        NanoGPTForCausalLM,
        NanoGPTPreTrainedModel,
    )
except ImportError:
    # Only fall back when there is no parent package. Inside a real package an
    # ImportError here originates in the submodules themselves (for example a
    # missing dependency). Retrying with absolute names would bury that error
    # under a misleading "No module named ..." or import an unrelated
    # top-level module that happens to share the name.
    if __package__:
        raise
    from configuration_nanogpt import NanoGPTConfig
    from modeling_nanogpt import (
        NanoGPTModel,
        NanoGPTForCausalLM,
        NanoGPTPreTrainedModel,
    )
# Register the NanoGPT classes with the transformers Auto* factories so that
# AutoConfig / AutoModel / AutoModelForCausalLM can resolve the "nanogpt"
# model type to the classes imported above. This runs once, at import time.
# NOTE(review): transformers checks that NanoGPTConfig.model_type matches the
# string passed here; confirm it is "nanogpt" in configuration_nanogpt.
AutoConfig.register("nanogpt", NanoGPTConfig)
for _auto_factory, _model_impl in (
    (AutoModel, NanoGPTModel),
    (AutoModelForCausalLM, NanoGPTForCausalLM),
):
    _auto_factory.register(NanoGPTConfig, _model_impl)
# Drop the loop temporaries so they do not leak into the package namespace.
del _auto_factory, _model_impl
# Public API of this package: these are exactly the names exposed by
# ``from <package> import *``.
__all__ = [
    "NanoGPTConfig",
    "NanoGPTModel",
    "NanoGPTForCausalLM",
    "NanoGPTPreTrainedModel",
]