| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ Auto Model class.""" |
|
|
|
|
| from collections import OrderedDict |
|
|
| from ...utils import logging |
| from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update |
| from .configuration_auto import CONFIG_MAPPING_NAMES |
|
|
|
|
# Module-level logger, named after this module, obtained from the package's
# own `logging` helper (imported from `...utils`).
logger = logging.get_logger(__name__)
|
|
|
|
# Name table for the base Flax models: model identifier string -> name of the
# Flax model class. Values are plain strings, not class objects; the table is
# paired with CONFIG_MAPPING_NAMES further down to build FLAX_MODEL_MAPPING.
# Entries are kept in alphabetical order of their key.
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertModel"),
        ("bart", "FlaxBartModel"),
        ("beit", "FlaxBeitModel"),
        ("bert", "FlaxBertModel"),
        ("big_bird", "FlaxBigBirdModel"),
        ("blenderbot", "FlaxBlenderbotModel"),
        ("blenderbot-small", "FlaxBlenderbotSmallModel"),
        ("bloom", "FlaxBloomModel"),
        ("clip", "FlaxCLIPModel"),
        ("distilbert", "FlaxDistilBertModel"),
        ("electra", "FlaxElectraModel"),
        # "gpt-sw3" intentionally shares the GPT-2 class with the "gpt2" entry.
        ("gpt-sw3", "FlaxGPT2Model"),
        ("gpt2", "FlaxGPT2Model"),
        ("gpt_neo", "FlaxGPTNeoModel"),
        ("gptj", "FlaxGPTJModel"),
        ("longt5", "FlaxLongT5Model"),
        ("marian", "FlaxMarianModel"),
        ("mbart", "FlaxMBartModel"),
        ("mt5", "FlaxMT5Model"),
        ("opt", "FlaxOPTModel"),
        ("pegasus", "FlaxPegasusModel"),
        ("regnet", "FlaxRegNetModel"),
        ("resnet", "FlaxResNetModel"),
        ("roberta", "FlaxRobertaModel"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
        ("roformer", "FlaxRoFormerModel"),
        ("t5", "FlaxT5Model"),
        ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
        ("vit", "FlaxViTModel"),
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("whisper", "FlaxWhisperModel"),
        ("xglm", "FlaxXGLMModel"),
        ("xlm-roberta", "FlaxXLMRobertaModel"),
    ]
)
|
|
# Name table for pretraining: model identifier string -> name of the Flax
# class used for pretraining. Several entries point at a class whose name is a
# conditional-generation or masked-LM head rather than a dedicated
# "ForPreTraining" class. Feeds FLAX_MODEL_FOR_PRETRAINING_MAPPING.
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("electra", "FlaxElectraForPreTraining"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
|
|
# Name table for masked language modeling: model identifier string -> name of
# the Flax class to use. Feeds FLAX_MODEL_FOR_MASKED_LM_MAPPING.
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForMaskedLM"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForMaskedLM"),
        ("big_bird", "FlaxBigBirdForMaskedLM"),
        ("distilbert", "FlaxDistilBertForMaskedLM"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
|
|
# Name table for sequence-to-sequence language modeling: model identifier
# string -> name of the Flax class to use. Feeds
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING (used by FlaxAutoModelForSeq2SeqLM).
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        ("bart", "FlaxBartForConditionalGeneration"),
        ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "FlaxEncoderDecoderModel"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("marian", "FlaxMarianMTModel"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("pegasus", "FlaxPegasusForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
    ]
)
|
|
# Name table for image classification: model identifier string -> name of the
# Flax class to use. Feeds FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("beit", "FlaxBeitForImageClassification"),
        ("regnet", "FlaxRegNetForImageClassification"),
        ("resnet", "FlaxResNetForImageClassification"),
        ("vit", "FlaxViTForImageClassification"),
    ]
)
|
|
# Name table for vision-to-sequence models: currently a single entry.
# Feeds FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
    ]
)
|
|
# Name table for causal language modeling: model identifier string -> name of
# the Flax class to use. Feeds FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        ("bart", "FlaxBartForCausalLM"),
        ("bert", "FlaxBertForCausalLM"),
        ("big_bird", "FlaxBigBirdForCausalLM"),
        ("bloom", "FlaxBloomForCausalLM"),
        ("electra", "FlaxElectraForCausalLM"),
        # "gpt-sw3" intentionally shares the GPT-2 LM-head class with "gpt2".
        ("gpt-sw3", "FlaxGPT2LMHeadModel"),
        ("gpt2", "FlaxGPT2LMHeadModel"),
        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
        ("gptj", "FlaxGPTJForCausalLM"),
        ("opt", "FlaxOPTForCausalLM"),
        ("roberta", "FlaxRobertaForCausalLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
        ("xglm", "FlaxXGLMForCausalLM"),
        ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
    ]
)
|
|
# Name table for sequence classification: model identifier string -> name of
# the Flax class to use. Feeds FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("distilbert", "FlaxDistilBertForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
    ]
)
|
|
# Name table for question answering: model identifier string -> name of the
# Flax class to use. Feeds FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForQuestionAnswering"),
        ("bart", "FlaxBartForQuestionAnswering"),
        ("bert", "FlaxBertForQuestionAnswering"),
        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
        ("distilbert", "FlaxDistilBertForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
        ("roberta", "FlaxRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "FlaxRoFormerForQuestionAnswering"),
        ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
    ]
)
|
|
# Name table for token classification: model identifier string -> name of the
# Flax class to use. Feeds FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForTokenClassification"),
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("distilbert", "FlaxDistilBertForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
        ("roberta", "FlaxRobertaForTokenClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
        ("roformer", "FlaxRoFormerForTokenClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
    ]
)
|
|
# Name table for multiple choice: model identifier string -> name of the Flax
# class to use. Feeds FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "FlaxAlbertForMultipleChoice"),
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("distilbert", "FlaxDistilBertForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
        ("roberta", "FlaxRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "FlaxRoFormerForMultipleChoice"),
        ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
    ]
)
|
|
# Name table for next sentence prediction: currently a single entry.
# Feeds FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)
|
|
# Name table for speech sequence-to-sequence models: model identifier string
# -> name of the Flax class to use. Feeds FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
    ]
)
|
|
# Name table for audio classification: currently a single entry.
# Feeds FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("whisper", "FlaxWhisperForAudioClassification"),
    ]
)
|
|
# Pair every *_MAPPING_NAMES table above with CONFIG_MAPPING_NAMES through
# _LazyAutoMapping. Each resulting object is what the corresponding
# FlaxAutoModel* class below stores in its `_model_mapping` attribute.
# NOTE(review): presumably the class names are only resolved on lookup, as the
# helper's name suggests -- confirm in auto_factory.
FLAX_MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_MAPPING_NAMES,
)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
# NOTE(review): no FlaxAutoModel* class in the visible part of this file reads
# the audio-classification mapping; it is only exposed as a module constant.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
|
|
|
|
# Auto class for base Flax models; its lookup table is FLAX_MODEL_MAPPING.
class FlaxAutoModel(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_MAPPING


# Rebind the name to whatever auto_class_update returns for this class.
# No head_doc is passed here, so the helper's own default applies.
FlaxAutoModel = auto_class_update(
    FlaxAutoModel,
)
|
|
|
|
# Auto class for pretraining; its lookup table is
# FLAX_MODEL_FOR_PRETRAINING_MAPPING.
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "pretraining" head description.
FlaxAutoModelForPreTraining = auto_class_update(
    FlaxAutoModelForPreTraining,
    head_doc="pretraining",
)
|
|
|
|
# Auto class for causal language modeling; its lookup table is
# FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "causal language modeling" head description.
FlaxAutoModelForCausalLM = auto_class_update(
    FlaxAutoModelForCausalLM,
    head_doc="causal language modeling",
)
|
|
|
|
# Auto class for masked language modeling; its lookup table is
# FLAX_MODEL_FOR_MASKED_LM_MAPPING.
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "masked language modeling" head description.
FlaxAutoModelForMaskedLM = auto_class_update(
    FlaxAutoModelForMaskedLM,
    head_doc="masked language modeling",
)
|
|
|
|
# Auto class for sequence-to-sequence language modeling; its lookup table is
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to the result of auto_class_update. This is the only auto
# class in this file that overrides checkpoint_for_example (with "t5-base").
FlaxAutoModelForSeq2SeqLM = auto_class_update(
    FlaxAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
|
|
|
|
# Auto class for sequence classification; its lookup table is
# FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "sequence classification" head description.
FlaxAutoModelForSequenceClassification = auto_class_update(
    FlaxAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
|
|
|
|
# Auto class for question answering; its lookup table is
# FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "question answering" head description.
FlaxAutoModelForQuestionAnswering = auto_class_update(
    FlaxAutoModelForQuestionAnswering,
    head_doc="question answering",
)
|
|
|
|
# Auto class for token classification; its lookup table is
# FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "token classification" head description.
FlaxAutoModelForTokenClassification = auto_class_update(
    FlaxAutoModelForTokenClassification,
    head_doc="token classification",
)
|
|
|
|
# Auto class for multiple choice; its lookup table is
# FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "multiple choice" head description.
FlaxAutoModelForMultipleChoice = auto_class_update(
    FlaxAutoModelForMultipleChoice,
    head_doc="multiple choice",
)
|
|
|
|
# Auto class for next sentence prediction; its lookup table is
# FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "next sentence prediction" head description.
FlaxAutoModelForNextSentencePrediction = auto_class_update(
    FlaxAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
|
|
|
|
# Auto class for image classification; its lookup table is
# FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "image classification" head description.
FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification,
    head_doc="image classification",
)
|
|
|
|
# Auto class for vision-to-text modeling; its lookup table is
# FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.
class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "vision-to-text modeling" head description.
FlaxAutoModelForVision2Seq = auto_class_update(
    FlaxAutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)
|
|
|
|
# Auto class for sequence-to-sequence speech-to-text modeling; its lookup
# table is FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind the name to the result of auto_class_update, tagged with the
# "sequence-to-sequence speech-to-text modeling" head description.
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
|
|