# -*- coding: utf-8 -*-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from fla.models.bitnet.configuration_bitnet import BitNetConfig
from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel
# Register the BitNet classes with the Hugging Face Auto* factories so that
# AutoConfig / AutoModel / AutoModelForCausalLM can resolve them from
# `BitNetConfig.model_type` and from a `BitNetConfig` instance respectively.
#
# `exist_ok=True` keeps this import from raising ValueError when the same
# model type or config class name is already present in the Auto* mappings
# (e.g. registered by another package, or shipped by the installed
# transformers release); in that case the classes imported here are the ones
# registered.
AutoConfig.register(BitNetConfig.model_type, BitNetConfig, exist_ok=True)
AutoModel.register(BitNetConfig, BitNetModel, exist_ok=True)
AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM, exist_ok=True)

# Public names re-exported by this package.
__all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel']
|