| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""Auto Model class."""
|
|
| import warnings |
| from collections import OrderedDict |
|
|
| from ...utils import logging |
| from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update |
| from .configuration_auto import CONFIG_MAPPING_NAMES |
|
|
|
|
# Module-level logger, created through the package's own `logging` helper.
logger = logging.get_logger(__name__)
|
|
|
|
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping: model-type key -> model class name (as a string).
        # A tuple value (see "funnel") names more than one class for the same key.
        ("albert", "AlbertModel"),
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("audio-spectrogram-transformer", "ASTModel"),
        ("autoformer", "AutoformerModel"),
        ("bark", "BarkModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("biogpt", "BioGptModel"),
        ("bit", "BitModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("blip", "BlipModel"),
        ("blip-2", "Blip2Model"),
        ("bloom", "BloomModel"),
        ("bridgetower", "BridgeTowerModel"),
        ("bros", "BrosModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("clap", "ClapModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("code_llama", "LlamaModel"),
        ("codegen", "CodeGenModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("cpmant", "CpmAntModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("electra", "ElectraModel"),
        ("encodec", "EncodecModel"),
        ("ernie", "ErnieModel"),
        ("ernie_m", "ErnieMModel"),
        ("esm", "EsmModel"),
        ("falcon", "FalconModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("focalnet", "FocalNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("git", "GitModel"),
        ("glpn", "GLPNModel"),
        ("gpt-sw3", "GPT2Model"),
        ("gpt2", "GPT2Model"),
        ("gpt_bigcode", "GPTBigCodeModel"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("graphormer", "GraphormerModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("idefics", "IdeficsModel"),
        ("imagegpt", "ImageGPTModel"),
        ("informer", "InformerModel"),
        ("jukebox", "JukeboxModel"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lilt", "LiltModel"),
        ("llama", "LlamaModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("mask2former", "Mask2FormerModel"),
        ("maskformer", "MaskFormerModel"),
        ("maskformer-swin", "MaskFormerSwinModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("mega", "MegaModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mgp-str", "MgpstrForSceneTextRecognition"),
        ("mistral", "MistralModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("mpnet", "MPNetModel"),
        ("mpt", "MptModel"),
        ("mra", "MraModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nat", "NatModel"),
        ("nezha", "NezhaModel"),
        ("nllb-moe", "NllbMoeModel"),
        ("nystromformer", "NystromformerModel"),
        ("oneformer", "OneFormerModel"),
        ("open-llama", "OpenLlamaModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("persimmon", "PersimmonModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("pvt", "PvtModel"),
        ("qdqbert", "QDQBertModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("rwkv", "RwkvModel"),
        ("sam", "SamModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("speecht5", "SpeechT5Model"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("switch_transformers", "SwitchTransformersModel"),
        ("t5", "T5Model"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("tvlt", "TvltModel"),
        ("umt5", "UMT5Model"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vits", "VitsModel"),
        ("vivit", "VivitModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("xmod", "XmodModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)
|
|
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Pre-training mapping: model-type key -> class name (as a string).
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("ibert", "IBertForMaskedLM"),
        ("idefics", "IdeficsForVisionText2Text"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("rwkv", "RwkvForCausalLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("tvlt", "TvltForPreTraining"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForMaskedLM"),
    ]
)
|
|
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # LM-head mapping: model-type key -> class name (as a string).
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("git", "GitForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
|
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Causal LM mapping: model-type key -> class name (as a string).
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("biogpt", "BioGptForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("code_llama", "LlamaForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("falcon", "FalconForCausalLM"),
        ("git", "GitForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("llama", "LlamaForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("mega", "MegaForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mistral", "MistralForCausalLM"),
        ("mpt", "MptForCausalLM"),
        ("musicgen", "MusicgenForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("open-llama", "OpenLlamaForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("persimmon", "PersimmonForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForCausalLM"),
    ]
)
|
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Masked image modeling mapping: model-type key -> class name (as a string).
        ("deit", "DeiTForMaskedImageModeling"),
        ("focalnet", "FocalNetForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)
|
|
|
|
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    # Causal image modeling mapping: model-type key -> class name (as a string).
    [
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)
|
|
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Image classification mapping: model-type key -> class name (as a string).
        # A tuple value names several classes for the same key.
        ("beit", "BeitForImageClassification"),
        ("bit", "BitForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("convnextv2", "ConvNextV2ForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
        ("dinat", "DinatForImageClassification"),
        ("dinov2", "Dinov2ForImageClassification"),
        (
            "efficientformer",
            (
                "EfficientFormerForImageClassification",
                "EfficientFormerForImageClassificationWithTeacher",
            ),
        ),
        ("efficientnet", "EfficientNetForImageClassification"),
        ("focalnet", "FocalNetForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("mobilevitv2", "MobileViTV2ForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("pvt", "PvtForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("swiftformer", "SwiftFormerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_hybrid", "ViTHybridForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)
|
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Image segmentation mapping: model-type key -> class name (as a string).
        ("detr", "DetrForSegmentation"),
    ]
)
|
|
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Semantic segmentation mapping: model-type key -> class name (as a string).
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
        ("upernet", "UperNetForSemanticSegmentation"),
    ]
)
|
|
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Instance segmentation mapping: model-type key -> class name (as a string).
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    ]
)
|
|
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Universal segmentation mapping: model-type key -> class name (as a string).
        ("detr", "DetrForSegmentation"),
        ("mask2former", "Mask2FormerForUniversalSegmentation"),
        ("maskformer", "MaskFormerForInstanceSegmentation"),
        ("oneformer", "OneFormerForUniversalSegmentation"),
    ]
)
|
|
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Video classification mapping: model-type key -> class name (as a string).
        ("timesformer", "TimesformerForVideoClassification"),
        ("videomae", "VideoMAEForVideoClassification"),
        ("vivit", "VivitForVideoClassification"),
    ]
)
|
|
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Vision-to-sequence mapping: model-type key -> class name (as a string).
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("git", "GitForCausalLM"),
        ("instructblip", "InstructBlipForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)
|
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Masked LM mapping: model-type key -> class name (as a string).
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
|
|
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Object detection mapping: model-type key -> class name (as a string).
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("deta", "DetaForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)
|
|
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Zero-shot object detection mapping: model-type key -> class name (as a string).
        # Trailing comma kept so that appending an entry is a one-line change.
        ("owlvit", "OwlViTForObjectDetection"),
    ]
)
|
|
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    [
        # Depth estimation mapping: model-type key -> class name (as a string).
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
    ]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Seq2seq causal LM mapping: model-type key -> class name (as a string).
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("umt5", "UMT5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)
|
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Speech seq2seq mapping: model-type key -> class name (as a string).
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("speecht5", "SpeechT5ForSpeechToText"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)
|
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Sequence classification mapping: model-type key -> class name (as a string).
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("biogpt", "BioGptForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("code_llama", "LlamaForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("ernie_m", "ErnieMForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("falcon", "FalconForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gpt-sw3", "GPT2ForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gpt_neox", "GPTNeoXForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("llama", "LlamaForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("mega", "MegaForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mistral", "MistralForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mpt", "MptForSequenceClassification"),
        ("mra", "MraForSequenceClassification"),
        ("mt5", "MT5ForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("open-llama", "OpenLlamaForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("persimmon", "PersimmonForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
        ("roc_bert", "RoCBertForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("t5", "T5ForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("umt5", "UMT5ForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("xmod", "XmodForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)
|
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Question answering mapping: model-type key -> class name (as a string).
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("ernie_m", "ErnieMForQuestionAnswering"),
        ("falcon", "FalconForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gpt2", "GPT2ForQuestionAnswering"),
        ("gpt_neo", "GPTNeoForQuestionAnswering"),
        ("gpt_neox", "GPTNeoXForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("mega", "MegaForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mpt", "MptForQuestionAnswering"),
        ("mra", "MraForQuestionAnswering"),
        ("mt5", "MT5ForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
        ("roc_bert", "RoCBertForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("t5", "T5ForQuestionAnswering"),
        ("umt5", "UMT5ForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("xmod", "XmodForQuestionAnswering"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)
|
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Table question answering mapping: model-type key -> class name (as a string).
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
|
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Visual question answering mapping: model-type key -> class name (as a string).
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
|
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Document question answering mapping: model-type key -> class name (as a string).
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)
|
|
# Token classification: model type -> name of the class carrying that head.
# Entries are kept in sorted key order; "gpt-sw3" intentionally reuses the GPT2 class.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("biogpt", "BioGptForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("bros", "BrosForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("ernie_m", "ErnieMForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("falcon", "FalconForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
        ("gpt_neo", "GPTNeoForTokenClassification"),
        ("gpt_neox", "GPTNeoXForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("mega", "MegaForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("mpt", "MptForTokenClassification"),
        ("mra", "MraForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
        ("roc_bert", "RoCBertForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("xmod", "XmodForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)
|
|
# Multiple choice: model type -> name of the class carrying that head.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("ernie_m", "ErnieMForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("mega", "MegaForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("mra", "MraForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("xmod", "XmodForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)
|
|
# Next sentence prediction: model type -> name of the class carrying that head.
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)
|
|
# Audio classification: model type -> class name. Most speech models expose
# this head under a `...ForSequenceClassification` name; only AST and Whisper
# use `...ForAudioClassification`.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
|
|
# Connectionist temporal classification (CTC): model type -> class name.
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)
|
|
# Audio frame classification: model type -> name of the class carrying that head.
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)
|
|
# Audio x-vector: model type -> name of the class carrying that head.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
|
|
# Text-to-spectrogram: model type -> class name.
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
    [
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)
|
|
# Text-to-waveform: model type -> class name. "bark" and "vits" are listed
# with their `...Model` classes; only "musicgen" has a task-suffixed class.
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
    [
        ("bark", "BarkModel"),
        ("musicgen", "MusicgenForConditionalGeneration"),
        ("vits", "VitsModel"),
    ]
)
|
|
# Zero-shot image classification: model type -> class name. Every entry
# points at the model's plain `...Model` class; no dedicated head class is listed.
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("blip", "BlipModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
    ]
)
|
|
# Backbones: model type -> name of the `...Backbone` class.
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        ("bit", "BitBackbone"),
        ("convnext", "ConvNextBackbone"),
        ("convnextv2", "ConvNextV2Backbone"),
        ("dinat", "DinatBackbone"),
        ("dinov2", "Dinov2Backbone"),
        ("focalnet", "FocalNetBackbone"),
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("nat", "NatBackbone"),
        ("resnet", "ResNetBackbone"),
        ("swin", "SwinBackbone"),
        ("timm_backbone", "TimmBackbone"),
        ("vitdet", "VitDetBackbone"),
    ]
)
|
|
# Mask generation: model type -> class name.
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        ("sam", "SamModel"),
    ]
)
|
|
# Text encoding: model type -> class name. Most entries use the plain
# `...Model` class; the T5 family ("mt5", "t5", "umt5") uses `...EncoderModel`.
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        ("albert", "AlbertModel"),
        ("bert", "BertModel"),
        ("big_bird", "BigBirdModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("distilbert", "DistilBertModel"),
        ("electra", "ElectraModel"),
        ("flaubert", "FlaubertModel"),
        ("ibert", "IBertModel"),
        ("longformer", "LongformerModel"),
        ("mobilebert", "MobileBertModel"),
        ("mt5", "MT5EncoderModel"),
        ("nystromformer", "NystromformerModel"),
        ("reformer", "ReformerModel"),
        ("rembert", "RemBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("t5", "T5EncoderModel"),
        ("umt5", "UMT5EncoderModel"),
        ("xlm", "XLMModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
    ]
)
|
|
# Image-to-image: model type -> class name.
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        ("swin2sr", "Swin2SRForImageSuperResolution"),
    ]
)
|
|
# Build one `_LazyAutoMapping` per task by pairing CONFIG_MAPPING_NAMES with the
# task's `*_MAPPING_NAMES` table. Both tables hold plain strings only.
# NOTE(review): judging by the class name, resolution to real classes is
# presumably deferred until lookup — confirm in `auto_factory`.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
|
|
# Remaining task mappings, built the same way as the ones above:
# CONFIG_MAPPING_NAMES paired with the task's name table.
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
)

MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)

MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)

MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)

MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)

MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
|
|
|
|
# Auto class backed by MODEL_FOR_MASK_GENERATION_MAPPING. Unlike most auto
# classes in this module it is not passed through `auto_class_update`.
class AutoModelForMaskGeneration(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
|
|
|
# Auto class backed by MODEL_FOR_TEXT_ENCODING_MAPPING. Unlike most auto
# classes in this module it is not passed through `auto_class_update`.
class AutoModelForTextEncoding(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
|
|
|
# Auto class backed by MODEL_FOR_IMAGE_TO_IMAGE_MAPPING. Unlike most auto
# classes in this module it is not passed through `auto_class_update`.
class AutoModelForImageToImage(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
|
|
|
|
# Generic auto class: backed by MODEL_MAPPING, the table of base `...Model` classes.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# Rebind the name to whatever `auto_class_update` returns (called with its defaults).
AutoModel = auto_class_update(AutoModel)
|
|
|
|
# Auto class backed by MODEL_FOR_PRETRAINING_MAPPING.
class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# `head_doc` names the head ("pretraining") for `auto_class_update`.
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
|
|
|
|
| |
# Private base for the deprecated public `AutoModelWithLMHead`, which subclasses
# this one and adds the deprecation warnings. Backed by MODEL_WITH_LM_HEAD_MAPPING.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
|
|
|
|
# Auto class backed by MODEL_FOR_CAUSAL_LM_MAPPING.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
|
|
|
# Auto class backed by MODEL_FOR_MASKED_LM_MAPPING.
class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
|
|
|
|
# Auto class backed by MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Passes an explicit example checkpoint ("t5-base") rather than relying on the
# default of `auto_class_update`.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
|
|
|
|
# Auto class backed by MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)
|
|
|
|
# Auto class backed by MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
|
|
|
|
# Auto class backed by MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Uses a TAPAS checkpoint as the example, matching the only model type in the table.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
|
|
|
|
# Auto class backed by MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


# Uses a ViLT checkpoint as the example.
AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)
|
|
|
|
# Auto class backed by MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING.
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# NOTE(review): the single-quoted checkpoint string below contains
# `", revision="52e01b3` on purpose-looking terms: presumably `auto_class_update`
# interpolates it inside a double-quoted example so the rendered call also pins
# a revision. Confirm in `auto_factory` before "fixing" the quoting.
AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
|
|
|
|
# Auto class backed by MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
|
|
|
|
# Auto class backed by MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
|
|
|
|
# Auto class backed by MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
|
|
|
|
# Auto class backed by MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
|
|
|
# Auto class backed by MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING.
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


AutoModelForZeroShotImageClassification = auto_class_update(
    AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
|
|
|
|
# Auto class backed by MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
|
|
|
|
# Auto class backed by MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)
|
|
|
|
# Auto class backed by MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


AutoModelForUniversalSegmentation = auto_class_update(
    AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
)
|
|
|
|
# Auto class backed by MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)
|
|
|
|
# Auto class backed by MODEL_FOR_OBJECT_DETECTION_MAPPING.
class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
|
|
|
|
# Auto class backed by MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING.
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)
|
|
|
|
# Auto class backed by MODEL_FOR_DEPTH_ESTIMATION_MAPPING.
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
|
|
|
|
# Auto class backed by MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
|
|
|
|
# Auto class backed by MODEL_FOR_VISION_2_SEQ_MAPPING.
class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
|
|
|
|
# Auto class backed by MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
|
|
|
|
# Auto class backed by MODEL_FOR_CTC_MAPPING.
class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
|
|
|
|
# Auto class backed by MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
|
|
|
|
# Auto class backed by MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING.
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)
|
|
|
|
# Auto class backed by MODEL_FOR_AUDIO_XVECTOR_MAPPING.
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


# Kept directly after its class, like every other class/update pair in this
# module; the call only takes the class itself and a description string.
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")


# The three classes below are not passed through `auto_class_update`.
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING


class AutoModelForTextToWaveform(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING


# Derives from `_BaseAutoBackboneClass` rather than `_BaseAutoModelClass`.
class AutoBackbone(_BaseAutoBackboneClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
|
|
|
|
# Auto class backed by MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
|
|
|
|
# Single copy of the deprecation notice shared by both constructors of
# `AutoModelWithLMHead`, so the two messages cannot drift apart.
_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE = (
    "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
    "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
    "`AutoModelForSeq2SeqLM` for encoder-decoder models."
)


# Deprecated public wrapper around `_AutoModelWithLMHead`: both constructors emit
# a `FutureWarning` and then delegate unchanged to the parent implementation.
# Comments are used instead of docstrings on purpose, so the class and its
# methods do not shadow documentation provided by the parent class.
class AutoModelWithLMHead(_AutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # stacklevel=2 attributes the warning to the caller's line instead of this module.
        warnings.warn(_AUTO_MODEL_WITH_LM_HEAD_DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
|