# -*- coding: utf-8 -*-
"""Expose the Transformer config/model classes and register them with the
Hugging Face ``Auto*`` factories.

Importing this module has a side effect: ``TransformerConfig`` is registered
with ``AutoConfig`` under ``TransformerConfig.model_type``, and is mapped to
``TransformerModel`` in ``AutoModel`` and to ``TransformerForCausalLM`` in
``AutoModelForCausalLM``.
"""

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel

# Register the config class under its own `model_type` string.
AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
# Map the config class to the bare model and to the causal-LM variant.
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)

# Public API of this package.
__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']