diff --git a/.gitattributes b/.gitattributes
index c78f1ab778a6c54558a8c3047cb60791e8c528aa..8a5dbf9b880f1b83d9812acbbd89ce7999b27b19 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -419,3 +419,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
.venv/lib/python3.11/site-packages/jiter/jiter.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/idna/__pycache__/uts46data.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/idna/__pycache__/idnadata.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c417fff35dbd7d116b3783bc2a5d35118ca4c950
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/pkg_resources/__pycache__/__init__.cpython-311.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae8fd4ca816177bb2c6f471e0b6c7334eb9caa4704a71b0935e8abf1ca1a36d2
+size 159566
diff --git a/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a15433d42fc8d1f144b6b9023bef629d853cc978
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/pkg_resources/_vendor/pyparsing/__pycache__/core.cpython-311.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c669d4b4c31b91773ae2bd09aa2fd0eb809698068c77850d9ca93283b9acc875
+size 277639
diff --git a/.venv/lib/python3.11/site-packages/transformers/__init__.py b/.venv/lib/python3.11/site-packages/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9134c146266c74f5b80d5f5a945a558edf49a6
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/__init__.py
@@ -0,0 +1,9436 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and
+# once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` should have import statements as usual, but they are
+# only there for type checking. The `_import_structure` is a dictionary submodule to list of object names, and is used
+# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
+# in the namespace without actually importing anything (and especially none of the backends).
+
+__version__ = "4.48.3"
+
+from typing import TYPE_CHECKING
+
+# Check the dependencies satisfy the minimal versions required.
+from . import dependency_versions_check
+from .utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_bitsandbytes_available,
+ is_essentia_available,
+ is_flax_available,
+ is_g2p_en_available,
+ is_keras_nlp_available,
+ is_librosa_available,
+ is_pretty_midi_available,
+ is_scipy_available,
+ is_sentencepiece_available,
+ is_speech_available,
+ is_tensorflow_text_available,
+ is_tf_available,
+ is_timm_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_torchaudio_available,
+ is_torchvision_available,
+ is_vision_available,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Base objects, independent of any specific backend
+_import_structure = {
+ "agents": [
+ "Agent",
+ "CodeAgent",
+ "HfApiEngine",
+ "ManagedAgent",
+ "PipelineTool",
+ "ReactAgent",
+ "ReactCodeAgent",
+ "ReactJsonAgent",
+ "Tool",
+ "Toolbox",
+ "ToolCollection",
+ "TransformersEngine",
+ "launch_gradio_demo",
+ "load_tool",
+ "stream_to_gradio",
+ "tool",
+ ],
+ "audio_utils": [],
+ "benchmark": [],
+ "commands": [],
+ "configuration_utils": ["PretrainedConfig"],
+ "convert_graph_to_onnx": [],
+ "convert_slow_tokenizers_checkpoints_to_fast": [],
+ "convert_tf_hub_seq_to_seq_bert_to_pytorch": [],
+ "data": [
+ "DataProcessor",
+ "InputExample",
+ "InputFeatures",
+ "SingleSentenceClassificationProcessor",
+ "SquadExample",
+ "SquadFeatures",
+ "SquadV1Processor",
+ "SquadV2Processor",
+ "glue_compute_metrics",
+ "glue_convert_examples_to_features",
+ "glue_output_modes",
+ "glue_processors",
+ "glue_tasks_num_labels",
+ "squad_convert_examples_to_features",
+ "xnli_compute_metrics",
+ "xnli_output_modes",
+ "xnli_processors",
+ "xnli_tasks_num_labels",
+ ],
+ "data.data_collator": [
+ "DataCollator",
+ "DataCollatorForLanguageModeling",
+ "DataCollatorForPermutationLanguageModeling",
+ "DataCollatorForSeq2Seq",
+ "DataCollatorForSOP",
+ "DataCollatorForTokenClassification",
+ "DataCollatorForWholeWordMask",
+ "DataCollatorWithFlattening",
+ "DataCollatorWithPadding",
+ "DefaultDataCollator",
+ "default_data_collator",
+ ],
+ "data.metrics": [],
+ "data.processors": [],
+ "debug_utils": [],
+ "dependency_versions_check": [],
+ "dependency_versions_table": [],
+ "dynamic_module_utils": [],
+ "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
+ "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
+ "file_utils": [],
+ "generation": [
+ "AsyncTextIteratorStreamer",
+ "CompileConfig",
+ "GenerationConfig",
+ "TextIteratorStreamer",
+ "TextStreamer",
+ "WatermarkingConfig",
+ ],
+ "hf_argparser": ["HfArgumentParser"],
+ "hyperparameter_search": [],
+ "image_transforms": [],
+ "integrations": [
+ "is_clearml_available",
+ "is_comet_available",
+ "is_dvclive_available",
+ "is_neptune_available",
+ "is_optuna_available",
+ "is_ray_available",
+ "is_ray_tune_available",
+ "is_sigopt_available",
+ "is_tensorboard_available",
+ "is_wandb_available",
+ ],
+ "loss": [],
+ "modelcard": ["ModelCard"],
+ # Losses
+ "modeling_tf_pytorch_utils": [
+ "convert_tf_weight_name_to_pt_weight_name",
+ "load_pytorch_checkpoint_in_tf2_model",
+ "load_pytorch_model_in_tf2_model",
+ "load_pytorch_weights_in_tf2_model",
+ "load_tf2_checkpoint_in_pytorch_model",
+ "load_tf2_model_in_pytorch_model",
+ "load_tf2_weights_in_pytorch_model",
+ ],
+ # Models
+ "models": [],
+ "models.albert": ["AlbertConfig"],
+ "models.align": [
+ "AlignConfig",
+ "AlignProcessor",
+ "AlignTextConfig",
+ "AlignVisionConfig",
+ ],
+ "models.altclip": [
+ "AltCLIPConfig",
+ "AltCLIPProcessor",
+ "AltCLIPTextConfig",
+ "AltCLIPVisionConfig",
+ ],
+ "models.aria": [
+ "AriaConfig",
+ "AriaProcessor",
+ "AriaTextConfig",
+ ],
+ "models.audio_spectrogram_transformer": [
+ "ASTConfig",
+ "ASTFeatureExtractor",
+ ],
+ "models.auto": [
+ "CONFIG_MAPPING",
+ "FEATURE_EXTRACTOR_MAPPING",
+ "IMAGE_PROCESSOR_MAPPING",
+ "MODEL_NAMES_MAPPING",
+ "PROCESSOR_MAPPING",
+ "TOKENIZER_MAPPING",
+ "AutoConfig",
+ "AutoFeatureExtractor",
+ "AutoImageProcessor",
+ "AutoProcessor",
+ "AutoTokenizer",
+ ],
+ "models.autoformer": ["AutoformerConfig"],
+ "models.bamba": ["BambaConfig"],
+ "models.bark": [
+ "BarkCoarseConfig",
+ "BarkConfig",
+ "BarkFineConfig",
+ "BarkProcessor",
+ "BarkSemanticConfig",
+ ],
+ "models.bart": ["BartConfig", "BartTokenizer"],
+ "models.barthez": [],
+ "models.bartpho": [],
+ "models.beit": ["BeitConfig"],
+ "models.bert": [
+ "BasicTokenizer",
+ "BertConfig",
+ "BertTokenizer",
+ "WordpieceTokenizer",
+ ],
+ "models.bert_generation": ["BertGenerationConfig"],
+ "models.bert_japanese": [
+ "BertJapaneseTokenizer",
+ "CharacterTokenizer",
+ "MecabTokenizer",
+ ],
+ "models.bertweet": ["BertweetTokenizer"],
+ "models.big_bird": ["BigBirdConfig"],
+ "models.bigbird_pegasus": ["BigBirdPegasusConfig"],
+ "models.biogpt": [
+ "BioGptConfig",
+ "BioGptTokenizer",
+ ],
+ "models.bit": ["BitConfig"],
+ "models.blenderbot": [
+ "BlenderbotConfig",
+ "BlenderbotTokenizer",
+ ],
+ "models.blenderbot_small": [
+ "BlenderbotSmallConfig",
+ "BlenderbotSmallTokenizer",
+ ],
+ "models.blip": [
+ "BlipConfig",
+ "BlipProcessor",
+ "BlipTextConfig",
+ "BlipVisionConfig",
+ ],
+ "models.blip_2": [
+ "Blip2Config",
+ "Blip2Processor",
+ "Blip2QFormerConfig",
+ "Blip2VisionConfig",
+ ],
+ "models.bloom": ["BloomConfig"],
+ "models.bridgetower": [
+ "BridgeTowerConfig",
+ "BridgeTowerProcessor",
+ "BridgeTowerTextConfig",
+ "BridgeTowerVisionConfig",
+ ],
+ "models.bros": [
+ "BrosConfig",
+ "BrosProcessor",
+ ],
+ "models.byt5": ["ByT5Tokenizer"],
+ "models.camembert": ["CamembertConfig"],
+ "models.canine": [
+ "CanineConfig",
+ "CanineTokenizer",
+ ],
+ "models.chameleon": [
+ "ChameleonConfig",
+ "ChameleonProcessor",
+ "ChameleonVQVAEConfig",
+ ],
+ "models.chinese_clip": [
+ "ChineseCLIPConfig",
+ "ChineseCLIPProcessor",
+ "ChineseCLIPTextConfig",
+ "ChineseCLIPVisionConfig",
+ ],
+ "models.clap": [
+ "ClapAudioConfig",
+ "ClapConfig",
+ "ClapProcessor",
+ "ClapTextConfig",
+ ],
+ "models.clip": [
+ "CLIPConfig",
+ "CLIPProcessor",
+ "CLIPTextConfig",
+ "CLIPTokenizer",
+ "CLIPVisionConfig",
+ ],
+ "models.clipseg": [
+ "CLIPSegConfig",
+ "CLIPSegProcessor",
+ "CLIPSegTextConfig",
+ "CLIPSegVisionConfig",
+ ],
+ "models.clvp": [
+ "ClvpConfig",
+ "ClvpDecoderConfig",
+ "ClvpEncoderConfig",
+ "ClvpFeatureExtractor",
+ "ClvpProcessor",
+ "ClvpTokenizer",
+ ],
+ "models.code_llama": [],
+ "models.codegen": [
+ "CodeGenConfig",
+ "CodeGenTokenizer",
+ ],
+ "models.cohere": ["CohereConfig"],
+ "models.cohere2": ["Cohere2Config"],
+ "models.colpali": [
+ "ColPaliConfig",
+ "ColPaliProcessor",
+ ],
+ "models.conditional_detr": ["ConditionalDetrConfig"],
+ "models.convbert": [
+ "ConvBertConfig",
+ "ConvBertTokenizer",
+ ],
+ "models.convnext": ["ConvNextConfig"],
+ "models.convnextv2": ["ConvNextV2Config"],
+ "models.cpm": [],
+ "models.cpmant": [
+ "CpmAntConfig",
+ "CpmAntTokenizer",
+ ],
+ "models.ctrl": [
+ "CTRLConfig",
+ "CTRLTokenizer",
+ ],
+ "models.cvt": ["CvtConfig"],
+ "models.dac": ["DacConfig", "DacFeatureExtractor"],
+ "models.data2vec": [
+ "Data2VecAudioConfig",
+ "Data2VecTextConfig",
+ "Data2VecVisionConfig",
+ ],
+ "models.dbrx": ["DbrxConfig"],
+ "models.deberta": [
+ "DebertaConfig",
+ "DebertaTokenizer",
+ ],
+ "models.deberta_v2": ["DebertaV2Config"],
+ "models.decision_transformer": ["DecisionTransformerConfig"],
+ "models.deformable_detr": ["DeformableDetrConfig"],
+ "models.deit": ["DeiTConfig"],
+ "models.deprecated": [],
+ "models.deprecated.bort": [],
+ "models.deprecated.deta": ["DetaConfig"],
+ "models.deprecated.efficientformer": ["EfficientFormerConfig"],
+ "models.deprecated.ernie_m": ["ErnieMConfig"],
+ "models.deprecated.gptsan_japanese": [
+ "GPTSanJapaneseConfig",
+ "GPTSanJapaneseTokenizer",
+ ],
+ "models.deprecated.graphormer": ["GraphormerConfig"],
+ "models.deprecated.jukebox": [
+ "JukeboxConfig",
+ "JukeboxPriorConfig",
+ "JukeboxTokenizer",
+ "JukeboxVQVAEConfig",
+ ],
+ "models.deprecated.mctct": [
+ "MCTCTConfig",
+ "MCTCTFeatureExtractor",
+ "MCTCTProcessor",
+ ],
+ "models.deprecated.mega": ["MegaConfig"],
+ "models.deprecated.mmbt": ["MMBTConfig"],
+ "models.deprecated.nat": ["NatConfig"],
+ "models.deprecated.nezha": ["NezhaConfig"],
+ "models.deprecated.open_llama": ["OpenLlamaConfig"],
+ "models.deprecated.qdqbert": ["QDQBertConfig"],
+ "models.deprecated.realm": [
+ "RealmConfig",
+ "RealmTokenizer",
+ ],
+ "models.deprecated.retribert": [
+ "RetriBertConfig",
+ "RetriBertTokenizer",
+ ],
+ "models.deprecated.speech_to_text_2": [
+ "Speech2Text2Config",
+ "Speech2Text2Processor",
+ "Speech2Text2Tokenizer",
+ ],
+ "models.deprecated.tapex": ["TapexTokenizer"],
+ "models.deprecated.trajectory_transformer": ["TrajectoryTransformerConfig"],
+ "models.deprecated.transfo_xl": [
+ "TransfoXLConfig",
+ "TransfoXLCorpus",
+ "TransfoXLTokenizer",
+ ],
+ "models.deprecated.tvlt": [
+ "TvltConfig",
+ "TvltFeatureExtractor",
+ "TvltProcessor",
+ ],
+ "models.deprecated.van": ["VanConfig"],
+ "models.deprecated.vit_hybrid": ["ViTHybridConfig"],
+ "models.deprecated.xlm_prophetnet": ["XLMProphetNetConfig"],
+ "models.depth_anything": ["DepthAnythingConfig"],
+ "models.detr": ["DetrConfig"],
+ "models.dialogpt": [],
+ "models.diffllama": ["DiffLlamaConfig"],
+ "models.dinat": ["DinatConfig"],
+ "models.dinov2": ["Dinov2Config"],
+ "models.dinov2_with_registers": ["Dinov2WithRegistersConfig"],
+ "models.distilbert": [
+ "DistilBertConfig",
+ "DistilBertTokenizer",
+ ],
+ "models.dit": [],
+ "models.donut": [
+ "DonutProcessor",
+ "DonutSwinConfig",
+ ],
+ "models.dpr": [
+ "DPRConfig",
+ "DPRContextEncoderTokenizer",
+ "DPRQuestionEncoderTokenizer",
+ "DPRReaderOutput",
+ "DPRReaderTokenizer",
+ ],
+ "models.dpt": ["DPTConfig"],
+ "models.efficientnet": ["EfficientNetConfig"],
+ "models.electra": [
+ "ElectraConfig",
+ "ElectraTokenizer",
+ ],
+ "models.emu3": [
+ "Emu3Config",
+ "Emu3Processor",
+ "Emu3TextConfig",
+ "Emu3VQVAEConfig",
+ ],
+ "models.encodec": [
+ "EncodecConfig",
+ "EncodecFeatureExtractor",
+ ],
+ "models.encoder_decoder": ["EncoderDecoderConfig"],
+ "models.ernie": ["ErnieConfig"],
+ "models.esm": ["EsmConfig", "EsmTokenizer"],
+ "models.falcon": ["FalconConfig"],
+ "models.falcon_mamba": ["FalconMambaConfig"],
+ "models.fastspeech2_conformer": [
+ "FastSpeech2ConformerConfig",
+ "FastSpeech2ConformerHifiGanConfig",
+ "FastSpeech2ConformerTokenizer",
+ "FastSpeech2ConformerWithHifiGanConfig",
+ ],
+ "models.flaubert": ["FlaubertConfig", "FlaubertTokenizer"],
+ "models.flava": [
+ "FlavaConfig",
+ "FlavaImageCodebookConfig",
+ "FlavaImageConfig",
+ "FlavaMultimodalConfig",
+ "FlavaTextConfig",
+ ],
+ "models.fnet": ["FNetConfig"],
+ "models.focalnet": ["FocalNetConfig"],
+ "models.fsmt": [
+ "FSMTConfig",
+ "FSMTTokenizer",
+ ],
+ "models.funnel": [
+ "FunnelConfig",
+ "FunnelTokenizer",
+ ],
+ "models.fuyu": ["FuyuConfig"],
+ "models.gemma": ["GemmaConfig"],
+ "models.gemma2": ["Gemma2Config"],
+ "models.git": [
+ "GitConfig",
+ "GitProcessor",
+ "GitVisionConfig",
+ ],
+ "models.glm": ["GlmConfig"],
+ "models.glpn": ["GLPNConfig"],
+ "models.gpt2": [
+ "GPT2Config",
+ "GPT2Tokenizer",
+ ],
+ "models.gpt_bigcode": ["GPTBigCodeConfig"],
+ "models.gpt_neo": ["GPTNeoConfig"],
+ "models.gpt_neox": ["GPTNeoXConfig"],
+ "models.gpt_neox_japanese": ["GPTNeoXJapaneseConfig"],
+ "models.gpt_sw3": [],
+ "models.gptj": ["GPTJConfig"],
+ "models.granite": ["GraniteConfig"],
+ "models.granitemoe": ["GraniteMoeConfig"],
+ "models.grounding_dino": [
+ "GroundingDinoConfig",
+ "GroundingDinoProcessor",
+ ],
+ "models.groupvit": [
+ "GroupViTConfig",
+ "GroupViTTextConfig",
+ "GroupViTVisionConfig",
+ ],
+ "models.herbert": ["HerbertTokenizer"],
+ "models.hiera": ["HieraConfig"],
+ "models.hubert": ["HubertConfig"],
+ "models.ibert": ["IBertConfig"],
+ "models.idefics": ["IdeficsConfig"],
+ "models.idefics2": ["Idefics2Config"],
+ "models.idefics3": ["Idefics3Config"],
+ "models.ijepa": ["IJepaConfig"],
+ "models.imagegpt": ["ImageGPTConfig"],
+ "models.informer": ["InformerConfig"],
+ "models.instructblip": [
+ "InstructBlipConfig",
+ "InstructBlipProcessor",
+ "InstructBlipQFormerConfig",
+ "InstructBlipVisionConfig",
+ ],
+ "models.instructblipvideo": [
+ "InstructBlipVideoConfig",
+ "InstructBlipVideoProcessor",
+ "InstructBlipVideoQFormerConfig",
+ "InstructBlipVideoVisionConfig",
+ ],
+ "models.jamba": ["JambaConfig"],
+ "models.jetmoe": ["JetMoeConfig"],
+ "models.kosmos2": [
+ "Kosmos2Config",
+ "Kosmos2Processor",
+ ],
+ "models.layoutlm": [
+ "LayoutLMConfig",
+ "LayoutLMTokenizer",
+ ],
+ "models.layoutlmv2": [
+ "LayoutLMv2Config",
+ "LayoutLMv2FeatureExtractor",
+ "LayoutLMv2ImageProcessor",
+ "LayoutLMv2Processor",
+ "LayoutLMv2Tokenizer",
+ ],
+ "models.layoutlmv3": [
+ "LayoutLMv3Config",
+ "LayoutLMv3FeatureExtractor",
+ "LayoutLMv3ImageProcessor",
+ "LayoutLMv3Processor",
+ "LayoutLMv3Tokenizer",
+ ],
+ "models.layoutxlm": ["LayoutXLMProcessor"],
+ "models.led": ["LEDConfig", "LEDTokenizer"],
+ "models.levit": ["LevitConfig"],
+ "models.lilt": ["LiltConfig"],
+ "models.llama": ["LlamaConfig"],
+ "models.llava": [
+ "LlavaConfig",
+ "LlavaProcessor",
+ ],
+ "models.llava_next": [
+ "LlavaNextConfig",
+ "LlavaNextProcessor",
+ ],
+ "models.llava_next_video": [
+ "LlavaNextVideoConfig",
+ "LlavaNextVideoProcessor",
+ ],
+ "models.llava_onevision": ["LlavaOnevisionConfig", "LlavaOnevisionProcessor"],
+ "models.longformer": [
+ "LongformerConfig",
+ "LongformerTokenizer",
+ ],
+ "models.longt5": ["LongT5Config"],
+ "models.luke": [
+ "LukeConfig",
+ "LukeTokenizer",
+ ],
+ "models.lxmert": [
+ "LxmertConfig",
+ "LxmertTokenizer",
+ ],
+ "models.m2m_100": ["M2M100Config"],
+ "models.mamba": ["MambaConfig"],
+ "models.mamba2": ["Mamba2Config"],
+ "models.marian": ["MarianConfig"],
+ "models.markuplm": [
+ "MarkupLMConfig",
+ "MarkupLMFeatureExtractor",
+ "MarkupLMProcessor",
+ "MarkupLMTokenizer",
+ ],
+ "models.mask2former": ["Mask2FormerConfig"],
+ "models.maskformer": [
+ "MaskFormerConfig",
+ "MaskFormerSwinConfig",
+ ],
+ "models.mbart": ["MBartConfig"],
+ "models.mbart50": [],
+ "models.megatron_bert": ["MegatronBertConfig"],
+ "models.megatron_gpt2": [],
+ "models.mgp_str": [
+ "MgpstrConfig",
+ "MgpstrProcessor",
+ "MgpstrTokenizer",
+ ],
+ "models.mimi": ["MimiConfig"],
+ "models.mistral": ["MistralConfig"],
+ "models.mixtral": ["MixtralConfig"],
+ "models.mllama": [
+ "MllamaConfig",
+ "MllamaProcessor",
+ ],
+ "models.mluke": [],
+ "models.mobilebert": [
+ "MobileBertConfig",
+ "MobileBertTokenizer",
+ ],
+ "models.mobilenet_v1": ["MobileNetV1Config"],
+ "models.mobilenet_v2": ["MobileNetV2Config"],
+ "models.mobilevit": ["MobileViTConfig"],
+ "models.mobilevitv2": ["MobileViTV2Config"],
+ "models.modernbert": ["ModernBertConfig"],
+ "models.moonshine": ["MoonshineConfig"],
+ "models.moshi": [
+ "MoshiConfig",
+ "MoshiDepthConfig",
+ ],
+ "models.mpnet": [
+ "MPNetConfig",
+ "MPNetTokenizer",
+ ],
+ "models.mpt": ["MptConfig"],
+ "models.mra": ["MraConfig"],
+ "models.mt5": ["MT5Config"],
+ "models.musicgen": [
+ "MusicgenConfig",
+ "MusicgenDecoderConfig",
+ ],
+ "models.musicgen_melody": [
+ "MusicgenMelodyConfig",
+ "MusicgenMelodyDecoderConfig",
+ ],
+ "models.mvp": ["MvpConfig", "MvpTokenizer"],
+ "models.myt5": ["MyT5Tokenizer"],
+ "models.nemotron": ["NemotronConfig"],
+ "models.nllb": [],
+ "models.nllb_moe": ["NllbMoeConfig"],
+ "models.nougat": ["NougatProcessor"],
+ "models.nystromformer": ["NystromformerConfig"],
+ "models.olmo": ["OlmoConfig"],
+ "models.olmo2": ["Olmo2Config"],
+ "models.olmoe": ["OlmoeConfig"],
+ "models.omdet_turbo": [
+ "OmDetTurboConfig",
+ "OmDetTurboProcessor",
+ ],
+ "models.oneformer": [
+ "OneFormerConfig",
+ "OneFormerProcessor",
+ ],
+ "models.openai": [
+ "OpenAIGPTConfig",
+ "OpenAIGPTTokenizer",
+ ],
+ "models.opt": ["OPTConfig"],
+ "models.owlv2": [
+ "Owlv2Config",
+ "Owlv2Processor",
+ "Owlv2TextConfig",
+ "Owlv2VisionConfig",
+ ],
+ "models.owlvit": [
+ "OwlViTConfig",
+ "OwlViTProcessor",
+ "OwlViTTextConfig",
+ "OwlViTVisionConfig",
+ ],
+ "models.paligemma": ["PaliGemmaConfig"],
+ "models.patchtsmixer": ["PatchTSMixerConfig"],
+ "models.patchtst": ["PatchTSTConfig"],
+ "models.pegasus": [
+ "PegasusConfig",
+ "PegasusTokenizer",
+ ],
+ "models.pegasus_x": ["PegasusXConfig"],
+ "models.perceiver": [
+ "PerceiverConfig",
+ "PerceiverTokenizer",
+ ],
+ "models.persimmon": ["PersimmonConfig"],
+ "models.phi": ["PhiConfig"],
+ "models.phi3": ["Phi3Config"],
+ "models.phimoe": ["PhimoeConfig"],
+ "models.phobert": ["PhobertTokenizer"],
+ "models.pix2struct": [
+ "Pix2StructConfig",
+ "Pix2StructProcessor",
+ "Pix2StructTextConfig",
+ "Pix2StructVisionConfig",
+ ],
+ "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
+ "models.plbart": ["PLBartConfig"],
+ "models.poolformer": ["PoolFormerConfig"],
+ "models.pop2piano": ["Pop2PianoConfig"],
+ "models.prophetnet": [
+ "ProphetNetConfig",
+ "ProphetNetTokenizer",
+ ],
+ "models.pvt": ["PvtConfig"],
+ "models.pvt_v2": ["PvtV2Config"],
+ "models.qwen2": [
+ "Qwen2Config",
+ "Qwen2Tokenizer",
+ ],
+ "models.qwen2_audio": [
+ "Qwen2AudioConfig",
+ "Qwen2AudioEncoderConfig",
+ "Qwen2AudioProcessor",
+ ],
+ "models.qwen2_moe": ["Qwen2MoeConfig"],
+ "models.qwen2_vl": [
+ "Qwen2VLConfig",
+ "Qwen2VLProcessor",
+ ],
+ "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"],
+ "models.recurrent_gemma": ["RecurrentGemmaConfig"],
+ "models.reformer": ["ReformerConfig"],
+ "models.regnet": ["RegNetConfig"],
+ "models.rembert": ["RemBertConfig"],
+ "models.resnet": ["ResNetConfig"],
+ "models.roberta": [
+ "RobertaConfig",
+ "RobertaTokenizer",
+ ],
+ "models.roberta_prelayernorm": ["RobertaPreLayerNormConfig"],
+ "models.roc_bert": [
+ "RoCBertConfig",
+ "RoCBertTokenizer",
+ ],
+ "models.roformer": [
+ "RoFormerConfig",
+ "RoFormerTokenizer",
+ ],
+ "models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
+ "models.rwkv": ["RwkvConfig"],
+ "models.sam": [
+ "SamConfig",
+ "SamMaskDecoderConfig",
+ "SamProcessor",
+ "SamPromptEncoderConfig",
+ "SamVisionConfig",
+ ],
+ "models.seamless_m4t": [
+ "SeamlessM4TConfig",
+ "SeamlessM4TFeatureExtractor",
+ "SeamlessM4TProcessor",
+ ],
+ "models.seamless_m4t_v2": ["SeamlessM4Tv2Config"],
+ "models.segformer": ["SegformerConfig"],
+ "models.seggpt": ["SegGptConfig"],
+ "models.sew": ["SEWConfig"],
+ "models.sew_d": ["SEWDConfig"],
+ "models.siglip": [
+ "SiglipConfig",
+ "SiglipProcessor",
+ "SiglipTextConfig",
+ "SiglipVisionConfig",
+ ],
+ "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
+ "models.speech_to_text": [
+ "Speech2TextConfig",
+ "Speech2TextFeatureExtractor",
+ "Speech2TextProcessor",
+ ],
+ "models.speecht5": [
+ "SpeechT5Config",
+ "SpeechT5FeatureExtractor",
+ "SpeechT5HifiGanConfig",
+ "SpeechT5Processor",
+ ],
+ "models.splinter": [
+ "SplinterConfig",
+ "SplinterTokenizer",
+ ],
+ "models.squeezebert": [
+ "SqueezeBertConfig",
+ "SqueezeBertTokenizer",
+ ],
+ "models.stablelm": ["StableLmConfig"],
+ "models.starcoder2": ["Starcoder2Config"],
+ "models.superpoint": ["SuperPointConfig"],
+ "models.swiftformer": ["SwiftFormerConfig"],
+ "models.swin": ["SwinConfig"],
+ "models.swin2sr": ["Swin2SRConfig"],
+ "models.swinv2": ["Swinv2Config"],
+ "models.switch_transformers": ["SwitchTransformersConfig"],
+ "models.t5": ["T5Config"],
+ "models.table_transformer": ["TableTransformerConfig"],
+ "models.tapas": [
+ "TapasConfig",
+ "TapasTokenizer",
+ ],
+ "models.textnet": ["TextNetConfig"],
+ "models.time_series_transformer": ["TimeSeriesTransformerConfig"],
+ "models.timesformer": ["TimesformerConfig"],
+ "models.timm_backbone": ["TimmBackboneConfig"],
+ "models.timm_wrapper": ["TimmWrapperConfig"],
+ "models.trocr": [
+ "TrOCRConfig",
+ "TrOCRProcessor",
+ ],
+ "models.tvp": [
+ "TvpConfig",
+ "TvpProcessor",
+ ],
+ "models.udop": [
+ "UdopConfig",
+ "UdopProcessor",
+ ],
+ "models.umt5": ["UMT5Config"],
+ "models.unispeech": ["UniSpeechConfig"],
+ "models.unispeech_sat": ["UniSpeechSatConfig"],
+ "models.univnet": [
+ "UnivNetConfig",
+ "UnivNetFeatureExtractor",
+ ],
+ "models.upernet": ["UperNetConfig"],
+ "models.video_llava": ["VideoLlavaConfig"],
+ "models.videomae": ["VideoMAEConfig"],
+ "models.vilt": [
+ "ViltConfig",
+ "ViltFeatureExtractor",
+ "ViltImageProcessor",
+ "ViltProcessor",
+ ],
+ "models.vipllava": ["VipLlavaConfig"],
+ "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
+ "models.vision_text_dual_encoder": [
+ "VisionTextDualEncoderConfig",
+ "VisionTextDualEncoderProcessor",
+ ],
+ "models.visual_bert": ["VisualBertConfig"],
+ "models.vit": ["ViTConfig"],
+ "models.vit_mae": ["ViTMAEConfig"],
+ "models.vit_msn": ["ViTMSNConfig"],
+ "models.vitdet": ["VitDetConfig"],
+ "models.vitmatte": ["VitMatteConfig"],
+ "models.vitpose": ["VitPoseConfig"],
+ "models.vitpose_backbone": ["VitPoseBackboneConfig"],
+ "models.vits": [
+ "VitsConfig",
+ "VitsTokenizer",
+ ],
+ "models.vivit": ["VivitConfig"],
+ "models.wav2vec2": [
+ "Wav2Vec2Config",
+ "Wav2Vec2CTCTokenizer",
+ "Wav2Vec2FeatureExtractor",
+ "Wav2Vec2Processor",
+ "Wav2Vec2Tokenizer",
+ ],
+ "models.wav2vec2_bert": [
+ "Wav2Vec2BertConfig",
+ "Wav2Vec2BertProcessor",
+ ],
+ "models.wav2vec2_conformer": ["Wav2Vec2ConformerConfig"],
+ "models.wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"],
+ "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
+ "models.wavlm": ["WavLMConfig"],
+ "models.whisper": [
+ "WhisperConfig",
+ "WhisperFeatureExtractor",
+ "WhisperProcessor",
+ "WhisperTokenizer",
+ ],
+ "models.x_clip": [
+ "XCLIPConfig",
+ "XCLIPProcessor",
+ "XCLIPTextConfig",
+ "XCLIPVisionConfig",
+ ],
+ "models.xglm": ["XGLMConfig"],
+ "models.xlm": ["XLMConfig", "XLMTokenizer"],
+ "models.xlm_roberta": ["XLMRobertaConfig"],
+ "models.xlm_roberta_xl": ["XLMRobertaXLConfig"],
+ "models.xlnet": ["XLNetConfig"],
+ "models.xmod": ["XmodConfig"],
+ "models.yolos": ["YolosConfig"],
+ "models.yoso": ["YosoConfig"],
+ "models.zamba": ["ZambaConfig"],
+ "models.zoedepth": ["ZoeDepthConfig"],
+ "onnx": [],
+ "pipelines": [
+ "AudioClassificationPipeline",
+ "AutomaticSpeechRecognitionPipeline",
+ "CsvPipelineDataFormat",
+ "DepthEstimationPipeline",
+ "DocumentQuestionAnsweringPipeline",
+ "FeatureExtractionPipeline",
+ "FillMaskPipeline",
+ "ImageClassificationPipeline",
+ "ImageFeatureExtractionPipeline",
+ "ImageSegmentationPipeline",
+ "ImageTextToTextPipeline",
+ "ImageToImagePipeline",
+ "ImageToTextPipeline",
+ "JsonPipelineDataFormat",
+ "MaskGenerationPipeline",
+ "NerPipeline",
+ "ObjectDetectionPipeline",
+ "PipedPipelineDataFormat",
+ "Pipeline",
+ "PipelineDataFormat",
+ "QuestionAnsweringPipeline",
+ "SummarizationPipeline",
+ "TableQuestionAnsweringPipeline",
+ "Text2TextGenerationPipeline",
+ "TextClassificationPipeline",
+ "TextGenerationPipeline",
+ "TextToAudioPipeline",
+ "TokenClassificationPipeline",
+ "TranslationPipeline",
+ "VideoClassificationPipeline",
+ "VisualQuestionAnsweringPipeline",
+ "ZeroShotAudioClassificationPipeline",
+ "ZeroShotClassificationPipeline",
+ "ZeroShotImageClassificationPipeline",
+ "ZeroShotObjectDetectionPipeline",
+ "pipeline",
+ ],
+ "processing_utils": ["ProcessorMixin"],
+ "quantizers": [],
+ "testing_utils": [],
+ "tokenization_utils": ["PreTrainedTokenizer"],
+ "tokenization_utils_base": [
+ "AddedToken",
+ "BatchEncoding",
+ "CharSpan",
+ "PreTrainedTokenizerBase",
+ "SpecialTokensMixin",
+ "TokenSpan",
+ ],
+ "trainer_callback": [
+ "DefaultFlowCallback",
+ "EarlyStoppingCallback",
+ "PrinterCallback",
+ "ProgressCallback",
+ "TrainerCallback",
+ "TrainerControl",
+ "TrainerState",
+ ],
+ "trainer_utils": [
+ "EvalPrediction",
+ "IntervalStrategy",
+ "SchedulerType",
+ "enable_full_determinism",
+ "set_seed",
+ ],
+ "training_args": ["TrainingArguments"],
+ "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
+ "training_args_tf": ["TFTrainingArguments"],
+ "utils": [
+ "CONFIG_NAME",
+ "MODEL_CARD_NAME",
+ "PYTORCH_PRETRAINED_BERT_CACHE",
+ "PYTORCH_TRANSFORMERS_CACHE",
+ "SPIECE_UNDERLINE",
+ "TF2_WEIGHTS_NAME",
+ "TF_WEIGHTS_NAME",
+ "TRANSFORMERS_CACHE",
+ "WEIGHTS_NAME",
+ "TensorType",
+ "add_end_docstrings",
+ "add_start_docstrings",
+ "is_apex_available",
+ "is_av_available",
+ "is_bitsandbytes_available",
+ "is_datasets_available",
+ "is_faiss_available",
+ "is_flax_available",
+ "is_keras_nlp_available",
+ "is_phonemizer_available",
+ "is_psutil_available",
+ "is_py3nvml_available",
+ "is_pyctcdecode_available",
+ "is_sacremoses_available",
+ "is_safetensors_available",
+ "is_scipy_available",
+ "is_sentencepiece_available",
+ "is_sklearn_available",
+ "is_speech_available",
+ "is_tensorflow_text_available",
+ "is_tf_available",
+ "is_timm_available",
+ "is_tokenizers_available",
+ "is_torch_available",
+ "is_torch_mlu_available",
+ "is_torch_musa_available",
+ "is_torch_neuroncore_available",
+ "is_torch_npu_available",
+ "is_torch_tpu_available",
+ "is_torchvision_available",
+ "is_torch_xla_available",
+ "is_torch_xpu_available",
+ "is_vision_available",
+ "logging",
+ ],
+ "utils.quantization_config": [
+ "AqlmConfig",
+ "AwqConfig",
+ "BitNetConfig",
+ "BitsAndBytesConfig",
+ "CompressedTensorsConfig",
+ "EetqConfig",
+ "FbgemmFp8Config",
+ "GPTQConfig",
+ "HiggsConfig",
+ "HqqConfig",
+ "QuantoConfig",
+ "TorchAoConfig",
+ "VptqConfig",
+ ],
+}
+
+# sentencepiece-backed objects
+try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_sentencepiece_objects
+
+ _import_structure["utils.dummy_sentencepiece_objects"] = [
+ name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["models.albert"].append("AlbertTokenizer")
+ _import_structure["models.barthez"].append("BarthezTokenizer")
+ _import_structure["models.bartpho"].append("BartphoTokenizer")
+ _import_structure["models.bert_generation"].append("BertGenerationTokenizer")
+ _import_structure["models.big_bird"].append("BigBirdTokenizer")
+ _import_structure["models.camembert"].append("CamembertTokenizer")
+ _import_structure["models.code_llama"].append("CodeLlamaTokenizer")
+ _import_structure["models.cpm"].append("CpmTokenizer")
+ _import_structure["models.deberta_v2"].append("DebertaV2Tokenizer")
+ _import_structure["models.deprecated.ernie_m"].append("ErnieMTokenizer")
+ _import_structure["models.deprecated.xlm_prophetnet"].append("XLMProphetNetTokenizer")
+ _import_structure["models.fnet"].append("FNetTokenizer")
+ _import_structure["models.gemma"].append("GemmaTokenizer")
+ _import_structure["models.gpt_sw3"].append("GPTSw3Tokenizer")
+ _import_structure["models.layoutxlm"].append("LayoutXLMTokenizer")
+ _import_structure["models.llama"].append("LlamaTokenizer")
+ _import_structure["models.m2m_100"].append("M2M100Tokenizer")
+ _import_structure["models.marian"].append("MarianTokenizer")
+ _import_structure["models.mbart"].append("MBartTokenizer")
+ _import_structure["models.mbart50"].append("MBart50Tokenizer")
+ _import_structure["models.mluke"].append("MLukeTokenizer")
+ _import_structure["models.mt5"].append("MT5Tokenizer")
+ _import_structure["models.nllb"].append("NllbTokenizer")
+ _import_structure["models.pegasus"].append("PegasusTokenizer")
+ _import_structure["models.plbart"].append("PLBartTokenizer")
+ _import_structure["models.reformer"].append("ReformerTokenizer")
+ _import_structure["models.rembert"].append("RemBertTokenizer")
+ _import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizer")
+ _import_structure["models.siglip"].append("SiglipTokenizer")
+ _import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
+ _import_structure["models.speecht5"].append("SpeechT5Tokenizer")
+ _import_structure["models.t5"].append("T5Tokenizer")
+ _import_structure["models.udop"].append("UdopTokenizer")
+ _import_structure["models.xglm"].append("XGLMTokenizer")
+ _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
+ _import_structure["models.xlnet"].append("XLNetTokenizer")
+
+# tokenizers-backed objects
+try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tokenizers_objects
+
+ _import_structure["utils.dummy_tokenizers_objects"] = [
+ name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
+ ]
+else:
+ # Fast tokenizers structure
+ _import_structure["models.albert"].append("AlbertTokenizerFast")
+ _import_structure["models.bart"].append("BartTokenizerFast")
+ _import_structure["models.barthez"].append("BarthezTokenizerFast")
+ _import_structure["models.bert"].append("BertTokenizerFast")
+ _import_structure["models.big_bird"].append("BigBirdTokenizerFast")
+ _import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
+ _import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
+ _import_structure["models.bloom"].append("BloomTokenizerFast")
+ _import_structure["models.camembert"].append("CamembertTokenizerFast")
+ _import_structure["models.clip"].append("CLIPTokenizerFast")
+ _import_structure["models.code_llama"].append("CodeLlamaTokenizerFast")
+ _import_structure["models.codegen"].append("CodeGenTokenizerFast")
+ _import_structure["models.cohere"].append("CohereTokenizerFast")
+ _import_structure["models.convbert"].append("ConvBertTokenizerFast")
+ _import_structure["models.cpm"].append("CpmTokenizerFast")
+ _import_structure["models.deberta"].append("DebertaTokenizerFast")
+ _import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast")
+ _import_structure["models.deprecated.realm"].append("RealmTokenizerFast")
+ _import_structure["models.deprecated.retribert"].append("RetriBertTokenizerFast")
+ _import_structure["models.distilbert"].append("DistilBertTokenizerFast")
+ _import_structure["models.dpr"].extend(
+ [
+ "DPRContextEncoderTokenizerFast",
+ "DPRQuestionEncoderTokenizerFast",
+ "DPRReaderTokenizerFast",
+ ]
+ )
+ _import_structure["models.electra"].append("ElectraTokenizerFast")
+ _import_structure["models.fnet"].append("FNetTokenizerFast")
+ _import_structure["models.funnel"].append("FunnelTokenizerFast")
+ _import_structure["models.gemma"].append("GemmaTokenizerFast")
+ _import_structure["models.gpt2"].append("GPT2TokenizerFast")
+ _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast")
+ _import_structure["models.gpt_neox_japanese"].append("GPTNeoXJapaneseTokenizer")
+ _import_structure["models.herbert"].append("HerbertTokenizerFast")
+ _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast")
+ _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast")
+ _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
+ _import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
+ _import_structure["models.led"].append("LEDTokenizerFast")
+ _import_structure["models.llama"].append("LlamaTokenizerFast")
+ _import_structure["models.longformer"].append("LongformerTokenizerFast")
+ _import_structure["models.lxmert"].append("LxmertTokenizerFast")
+ _import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
+ _import_structure["models.mbart"].append("MBartTokenizerFast")
+ _import_structure["models.mbart50"].append("MBart50TokenizerFast")
+ _import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
+ _import_structure["models.mpnet"].append("MPNetTokenizerFast")
+ _import_structure["models.mt5"].append("MT5TokenizerFast")
+ _import_structure["models.mvp"].append("MvpTokenizerFast")
+ _import_structure["models.nllb"].append("NllbTokenizerFast")
+ _import_structure["models.nougat"].append("NougatTokenizerFast")
+ _import_structure["models.openai"].append("OpenAIGPTTokenizerFast")
+ _import_structure["models.pegasus"].append("PegasusTokenizerFast")
+ _import_structure["models.qwen2"].append("Qwen2TokenizerFast")
+ _import_structure["models.reformer"].append("ReformerTokenizerFast")
+ _import_structure["models.rembert"].append("RemBertTokenizerFast")
+ _import_structure["models.roberta"].append("RobertaTokenizerFast")
+ _import_structure["models.roformer"].append("RoFormerTokenizerFast")
+ _import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizerFast")
+ _import_structure["models.splinter"].append("SplinterTokenizerFast")
+ _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
+ _import_structure["models.t5"].append("T5TokenizerFast")
+ _import_structure["models.udop"].append("UdopTokenizerFast")
+ _import_structure["models.whisper"].append("WhisperTokenizerFast")
+ _import_structure["models.xglm"].append("XGLMTokenizerFast")
+ _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast")
+ _import_structure["models.xlnet"].append("XLNetTokenizerFast")
+ _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
+
+
+try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_sentencepiece_and_tokenizers_objects
+
+ _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
+ name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["convert_slow_tokenizer"] = [
+ "SLOW_TO_FAST_CONVERTERS",
+ "convert_slow_tokenizer",
+ ]
+
+# Tensorflow-text-specific objects
+try:
+ if not is_tensorflow_text_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tensorflow_text_objects
+
+ _import_structure["utils.dummy_tensorflow_text_objects"] = [
+ name for name in dir(dummy_tensorflow_text_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["models.bert"].append("TFBertTokenizer")
+
+# keras-nlp-specific objects
+try:
+ if not is_keras_nlp_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_keras_nlp_objects
+
+ _import_structure["utils.dummy_keras_nlp_objects"] = [
+ name for name in dir(dummy_keras_nlp_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["models.gpt2"].append("TFGPT2Tokenizer")
+
+# Vision-specific objects
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_vision_objects
+
+ _import_structure["utils.dummy_vision_objects"] = [
+ name for name in dir(dummy_vision_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["image_processing_base"] = ["ImageProcessingMixin"]
+ _import_structure["image_processing_utils"] = ["BaseImageProcessor"]
+ _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
+ _import_structure["models.aria"].extend(["AriaImageProcessor"])
+ _import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
+ _import_structure["models.bit"].extend(["BitImageProcessor"])
+ _import_structure["models.blip"].extend(["BlipImageProcessor"])
+ _import_structure["models.bridgetower"].append("BridgeTowerImageProcessor")
+ _import_structure["models.chameleon"].append("ChameleonImageProcessor")
+ _import_structure["models.chinese_clip"].extend(["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"])
+ _import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"])
+ _import_structure["models.conditional_detr"].extend(
+ ["ConditionalDetrFeatureExtractor", "ConditionalDetrImageProcessor"]
+ )
+ _import_structure["models.convnext"].extend(["ConvNextFeatureExtractor", "ConvNextImageProcessor"])
+ _import_structure["models.deformable_detr"].extend(
+ ["DeformableDetrFeatureExtractor", "DeformableDetrImageProcessor"]
+ )
+ _import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"])
+ _import_structure["models.deprecated.deta"].append("DetaImageProcessor")
+ _import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor")
+ _import_structure["models.deprecated.tvlt"].append("TvltImageProcessor")
+ _import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"])
+ _import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor"])
+ _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"])
+ _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"])
+ _import_structure["models.efficientnet"].append("EfficientNetImageProcessor")
+ _import_structure["models.emu3"].append("Emu3ImageProcessor")
+ _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
+ _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
+ _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
+ _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
+ _import_structure["models.idefics"].extend(["IdeficsImageProcessor"])
+ _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
+ _import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"])
+ _import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"])
+ _import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"])
+ _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
+ _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
+ _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
+ _import_structure["models.llava_next"].append("LlavaNextImageProcessor")
+ _import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor")
+ _import_structure["models.llava_onevision"].extend(
+ ["LlavaOnevisionImageProcessor", "LlavaOnevisionVideoProcessor"]
+ )
+ _import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
+ _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
+ _import_structure["models.mllama"].extend(["MllamaImageProcessor"])
+ _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
+ _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
+ _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
+ _import_structure["models.nougat"].append("NougatImageProcessor")
+ _import_structure["models.oneformer"].extend(["OneFormerImageProcessor"])
+ _import_structure["models.owlv2"].append("Owlv2ImageProcessor")
+ _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"])
+ _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"])
+ _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
+ _import_structure["models.pixtral"].append("PixtralImageProcessor")
+ _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
+ _import_structure["models.pvt"].extend(["PvtImageProcessor"])
+ _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
+ _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
+ _import_structure["models.sam"].extend(["SamImageProcessor"])
+ _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
+ _import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
+ _import_structure["models.siglip"].append("SiglipImageProcessor")
+ _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
+ _import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
+ _import_structure["models.textnet"].extend(["TextNetImageProcessor"])
+ _import_structure["models.tvp"].append("TvpImageProcessor")
+ _import_structure["models.video_llava"].append("VideoLlavaImageProcessor")
+ _import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
+ _import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"])
+ _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"])
+ _import_structure["models.vitmatte"].append("VitMatteImageProcessor")
+ _import_structure["models.vitpose"].append("VitPoseImageProcessor")
+ _import_structure["models.vivit"].append("VivitImageProcessor")
+ _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
+ _import_structure["models.zoedepth"].append("ZoeDepthImageProcessor")
+
+try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_torchvision_objects
+
+ _import_structure["utils.dummy_torchvision_objects"] = [
+ name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
+ _import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast")
+ _import_structure["models.detr"].append("DetrImageProcessorFast")
+ _import_structure["models.pixtral"].append("PixtralImageProcessorFast")
+ _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
+ _import_structure["models.vit"].append("ViTImageProcessorFast")
+
+try:
+ if not is_torchvision_available() and not is_timm_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_timm_and_torchvision_objects
+
+ _import_structure["utils.dummy_timm_and_torchvision_objects"] = [
+ name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"])
+
+# PyTorch-backed objects
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_pt_objects
+
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
+else:
+ _import_structure["activations"] = []
+ _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
+ _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
+ _import_structure["cache_utils"] = [
+ "Cache",
+ "CacheConfig",
+ "DynamicCache",
+ "EncoderDecoderCache",
+ "HQQQuantizedCache",
+ "HybridCache",
+ "MambaCache",
+ "OffloadedCache",
+ "OffloadedStaticCache",
+ "QuantizedCache",
+ "QuantizedCacheConfig",
+ "QuantoQuantizedCache",
+ "SinkCache",
+ "SlidingWindowCache",
+ "StaticCache",
+ ]
+ _import_structure["data.datasets"] = [
+ "GlueDataset",
+ "GlueDataTrainingArguments",
+ "LineByLineTextDataset",
+ "LineByLineWithRefDataset",
+ "LineByLineWithSOPTextDataset",
+ "SquadDataset",
+ "SquadDataTrainingArguments",
+ "TextDataset",
+ "TextDatasetForNextSentencePrediction",
+ ]
+ _import_structure["generation"].extend(
+ [
+ "AlternatingCodebooksLogitsProcessor",
+ "BayesianDetectorConfig",
+ "BayesianDetectorModel",
+ "BeamScorer",
+ "BeamSearchScorer",
+ "ClassifierFreeGuidanceLogitsProcessor",
+ "ConstrainedBeamSearchScorer",
+ "Constraint",
+ "ConstraintListState",
+ "DisjunctiveConstraint",
+ "EncoderNoRepeatNGramLogitsProcessor",
+ "EncoderRepetitionPenaltyLogitsProcessor",
+ "EosTokenCriteria",
+ "EpsilonLogitsWarper",
+ "EtaLogitsWarper",
+ "ExponentialDecayLengthPenalty",
+ "ForcedBOSTokenLogitsProcessor",
+ "ForcedEOSTokenLogitsProcessor",
+ "GenerationMixin",
+ "HammingDiversityLogitsProcessor",
+ "InfNanRemoveLogitsProcessor",
+ "LogitNormalization",
+ "LogitsProcessor",
+ "LogitsProcessorList",
+ "LogitsWarper",
+ "MaxLengthCriteria",
+ "MaxTimeCriteria",
+ "MinLengthLogitsProcessor",
+ "MinNewTokensLengthLogitsProcessor",
+ "MinPLogitsWarper",
+ "NoBadWordsLogitsProcessor",
+ "NoRepeatNGramLogitsProcessor",
+ "PhrasalConstraint",
+ "PrefixConstrainedLogitsProcessor",
+ "RepetitionPenaltyLogitsProcessor",
+ "SequenceBiasLogitsProcessor",
+ "StoppingCriteria",
+ "StoppingCriteriaList",
+ "StopStringCriteria",
+ "SuppressTokensAtBeginLogitsProcessor",
+ "SuppressTokensLogitsProcessor",
+ "SynthIDTextWatermarkDetector",
+ "SynthIDTextWatermarkingConfig",
+ "SynthIDTextWatermarkLogitsProcessor",
+ "TemperatureLogitsWarper",
+ "TopKLogitsWarper",
+ "TopPLogitsWarper",
+ "TypicalLogitsWarper",
+ "UnbatchedClassifierFreeGuidanceLogitsProcessor",
+ "WatermarkDetector",
+ "WatermarkLogitsProcessor",
+ "WhisperTimeStampLogitsProcessor",
+ ]
+ )
+
+ # PyTorch domain libraries integration
+ _import_structure["integrations.executorch"] = [
+ "TorchExportableModuleWithStaticCache",
+ "convert_and_export_with_cache",
+ ]
+
+ _import_structure["modeling_flash_attention_utils"] = []
+ _import_structure["modeling_outputs"] = []
+ _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
+ _import_structure["modeling_utils"] = ["PreTrainedModel"]
+
+ # PyTorch models structure
+
+ _import_structure["models.albert"].extend(
+ [
+ "AlbertForMaskedLM",
+ "AlbertForMultipleChoice",
+ "AlbertForPreTraining",
+ "AlbertForQuestionAnswering",
+ "AlbertForSequenceClassification",
+ "AlbertForTokenClassification",
+ "AlbertModel",
+ "AlbertPreTrainedModel",
+ "load_tf_weights_in_albert",
+ ]
+ )
+
+ _import_structure["models.align"].extend(
+ [
+ "AlignModel",
+ "AlignPreTrainedModel",
+ "AlignTextModel",
+ "AlignVisionModel",
+ ]
+ )
+ _import_structure["models.altclip"].extend(
+ [
+ "AltCLIPModel",
+ "AltCLIPPreTrainedModel",
+ "AltCLIPTextModel",
+ "AltCLIPVisionModel",
+ ]
+ )
+ _import_structure["models.aria"].extend(
+ [
+ "AriaForConditionalGeneration",
+ "AriaPreTrainedModel",
+ "AriaTextForCausalLM",
+ "AriaTextModel",
+ "AriaTextPreTrainedModel",
+ ]
+ )
+ _import_structure["models.audio_spectrogram_transformer"].extend(
+ [
+ "ASTForAudioClassification",
+ "ASTModel",
+ "ASTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.auto"].extend(
+ [
+ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_AUDIO_XVECTOR_MAPPING",
+ "MODEL_FOR_BACKBONE_MAPPING",
+ "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
+ "MODEL_FOR_CAUSAL_LM_MAPPING",
+ "MODEL_FOR_CTC_MAPPING",
+ "MODEL_FOR_DEPTH_ESTIMATION_MAPPING",
+ "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_IMAGE_MAPPING",
+ "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
+ "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
+ "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
+ "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
+ "MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
+ "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
+ "MODEL_FOR_MASKED_LM_MAPPING",
+ "MODEL_FOR_MASK_GENERATION_MAPPING",
+ "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+ "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+ "MODEL_FOR_OBJECT_DETECTION_MAPPING",
+ "MODEL_FOR_PRETRAINING_MAPPING",
+ "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_RETRIEVAL_MAPPING",
+ "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
+ "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+ "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+ "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_TEXT_ENCODING_MAPPING",
+ "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
+ "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
+ "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING",
+ "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
+ "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
+ "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
+ "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
+ "MODEL_MAPPING",
+ "MODEL_WITH_LM_HEAD_MAPPING",
+ "AutoBackbone",
+ "AutoModel",
+ "AutoModelForAudioClassification",
+ "AutoModelForAudioFrameClassification",
+ "AutoModelForAudioXVector",
+ "AutoModelForCausalLM",
+ "AutoModelForCTC",
+ "AutoModelForDepthEstimation",
+ "AutoModelForDocumentQuestionAnswering",
+ "AutoModelForImageClassification",
+ "AutoModelForImageSegmentation",
+ "AutoModelForImageTextToText",
+ "AutoModelForImageToImage",
+ "AutoModelForInstanceSegmentation",
+ "AutoModelForKeypointDetection",
+ "AutoModelForMaskedImageModeling",
+ "AutoModelForMaskedLM",
+ "AutoModelForMaskGeneration",
+ "AutoModelForMultipleChoice",
+ "AutoModelForNextSentencePrediction",
+ "AutoModelForObjectDetection",
+ "AutoModelForPreTraining",
+ "AutoModelForQuestionAnswering",
+ "AutoModelForSemanticSegmentation",
+ "AutoModelForSeq2SeqLM",
+ "AutoModelForSequenceClassification",
+ "AutoModelForSpeechSeq2Seq",
+ "AutoModelForTableQuestionAnswering",
+ "AutoModelForTextEncoding",
+ "AutoModelForTextToSpectrogram",
+ "AutoModelForTextToWaveform",
+ "AutoModelForTokenClassification",
+ "AutoModelForUniversalSegmentation",
+ "AutoModelForVideoClassification",
+ "AutoModelForVision2Seq",
+ "AutoModelForVisualQuestionAnswering",
+ "AutoModelForZeroShotImageClassification",
+ "AutoModelForZeroShotObjectDetection",
+ "AutoModelWithLMHead",
+ ]
+ )
+ _import_structure["models.autoformer"].extend(
+ [
+ "AutoformerForPrediction",
+ "AutoformerModel",
+ "AutoformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bamba"].extend(
+ [
+ "BambaForCausalLM",
+ "BambaModel",
+ "BambaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bark"].extend(
+ [
+ "BarkCausalModel",
+ "BarkCoarseModel",
+ "BarkFineModel",
+ "BarkModel",
+ "BarkPreTrainedModel",
+ "BarkSemanticModel",
+ ]
+ )
+ _import_structure["models.bart"].extend(
+ [
+ "BartForCausalLM",
+ "BartForConditionalGeneration",
+ "BartForQuestionAnswering",
+ "BartForSequenceClassification",
+ "BartModel",
+ "BartPretrainedModel",
+ "BartPreTrainedModel",
+ "PretrainedBartModel",
+ ]
+ )
+ _import_structure["models.beit"].extend(
+ [
+ "BeitBackbone",
+ "BeitForImageClassification",
+ "BeitForMaskedImageModeling",
+ "BeitForSemanticSegmentation",
+ "BeitModel",
+ "BeitPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bert"].extend(
+ [
+ "BertForMaskedLM",
+ "BertForMultipleChoice",
+ "BertForNextSentencePrediction",
+ "BertForPreTraining",
+ "BertForQuestionAnswering",
+ "BertForSequenceClassification",
+ "BertForTokenClassification",
+ "BertLMHeadModel",
+ "BertModel",
+ "BertPreTrainedModel",
+ "load_tf_weights_in_bert",
+ ]
+ )
+ _import_structure["models.bert_generation"].extend(
+ [
+ "BertGenerationDecoder",
+ "BertGenerationEncoder",
+ "BertGenerationPreTrainedModel",
+ "load_tf_weights_in_bert_generation",
+ ]
+ )
+ _import_structure["models.big_bird"].extend(
+ [
+ "BigBirdForCausalLM",
+ "BigBirdForMaskedLM",
+ "BigBirdForMultipleChoice",
+ "BigBirdForPreTraining",
+ "BigBirdForQuestionAnswering",
+ "BigBirdForSequenceClassification",
+ "BigBirdForTokenClassification",
+ "BigBirdModel",
+ "BigBirdPreTrainedModel",
+ "load_tf_weights_in_big_bird",
+ ]
+ )
+ _import_structure["models.bigbird_pegasus"].extend(
+ [
+ "BigBirdPegasusForCausalLM",
+ "BigBirdPegasusForConditionalGeneration",
+ "BigBirdPegasusForQuestionAnswering",
+ "BigBirdPegasusForSequenceClassification",
+ "BigBirdPegasusModel",
+ "BigBirdPegasusPreTrainedModel",
+ ]
+ )
+ _import_structure["models.biogpt"].extend(
+ [
+ "BioGptForCausalLM",
+ "BioGptForSequenceClassification",
+ "BioGptForTokenClassification",
+ "BioGptModel",
+ "BioGptPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bit"].extend(
+ [
+ "BitBackbone",
+ "BitForImageClassification",
+ "BitModel",
+ "BitPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blenderbot"].extend(
+ [
+ "BlenderbotForCausalLM",
+ "BlenderbotForConditionalGeneration",
+ "BlenderbotModel",
+ "BlenderbotPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blenderbot_small"].extend(
+ [
+ "BlenderbotSmallForCausalLM",
+ "BlenderbotSmallForConditionalGeneration",
+ "BlenderbotSmallModel",
+ "BlenderbotSmallPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blip"].extend(
+ [
+ "BlipForConditionalGeneration",
+ "BlipForImageTextRetrieval",
+ "BlipForQuestionAnswering",
+ "BlipModel",
+ "BlipPreTrainedModel",
+ "BlipTextModel",
+ "BlipVisionModel",
+ ]
+ )
+ _import_structure["models.blip_2"].extend(
+ [
+ "Blip2ForConditionalGeneration",
+ "Blip2ForImageTextRetrieval",
+ "Blip2Model",
+ "Blip2PreTrainedModel",
+ "Blip2QFormerModel",
+ "Blip2TextModelWithProjection",
+ "Blip2VisionModel",
+ "Blip2VisionModelWithProjection",
+ ]
+ )
+ _import_structure["models.bloom"].extend(
+ [
+ "BloomForCausalLM",
+ "BloomForQuestionAnswering",
+ "BloomForSequenceClassification",
+ "BloomForTokenClassification",
+ "BloomModel",
+ "BloomPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bridgetower"].extend(
+ [
+ "BridgeTowerForContrastiveLearning",
+ "BridgeTowerForImageAndTextRetrieval",
+ "BridgeTowerForMaskedLM",
+ "BridgeTowerModel",
+ "BridgeTowerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bros"].extend(
+ [
+ "BrosForTokenClassification",
+ "BrosModel",
+ "BrosPreTrainedModel",
+ "BrosProcessor",
+ "BrosSpadeEEForTokenClassification",
+ "BrosSpadeELForTokenClassification",
+ ]
+ )
+ _import_structure["models.camembert"].extend(
+ [
+ "CamembertForCausalLM",
+ "CamembertForMaskedLM",
+ "CamembertForMultipleChoice",
+ "CamembertForQuestionAnswering",
+ "CamembertForSequenceClassification",
+ "CamembertForTokenClassification",
+ "CamembertModel",
+ "CamembertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.canine"].extend(
+ [
+ "CanineForMultipleChoice",
+ "CanineForQuestionAnswering",
+ "CanineForSequenceClassification",
+ "CanineForTokenClassification",
+ "CanineModel",
+ "CaninePreTrainedModel",
+ "load_tf_weights_in_canine",
+ ]
+ )
+ _import_structure["models.chameleon"].extend(
+ [
+ "ChameleonForConditionalGeneration",
+ "ChameleonModel",
+ "ChameleonPreTrainedModel",
+ "ChameleonProcessor",
+ "ChameleonVQVAE",
+ ]
+ )
+ _import_structure["models.chinese_clip"].extend(
+ [
+ "ChineseCLIPModel",
+ "ChineseCLIPPreTrainedModel",
+ "ChineseCLIPTextModel",
+ "ChineseCLIPVisionModel",
+ ]
+ )
+ _import_structure["models.clap"].extend(
+ [
+ "ClapAudioModel",
+ "ClapAudioModelWithProjection",
+ "ClapFeatureExtractor",
+ "ClapModel",
+ "ClapPreTrainedModel",
+ "ClapTextModel",
+ "ClapTextModelWithProjection",
+ ]
+ )
+ _import_structure["models.clip"].extend(
+ [
+ "CLIPForImageClassification",
+ "CLIPModel",
+ "CLIPPreTrainedModel",
+ "CLIPTextModel",
+ "CLIPTextModelWithProjection",
+ "CLIPVisionModel",
+ "CLIPVisionModelWithProjection",
+ ]
+ )
+ _import_structure["models.clipseg"].extend(
+ [
+ "CLIPSegForImageSegmentation",
+ "CLIPSegModel",
+ "CLIPSegPreTrainedModel",
+ "CLIPSegTextModel",
+ "CLIPSegVisionModel",
+ ]
+ )
+ _import_structure["models.clvp"].extend(
+ [
+ "ClvpDecoder",
+ "ClvpEncoder",
+ "ClvpForCausalLM",
+ "ClvpModel",
+ "ClvpModelForConditionalGeneration",
+ "ClvpPreTrainedModel",
+ ]
+ )
+ _import_structure["models.codegen"].extend(
+ [
+ "CodeGenForCausalLM",
+ "CodeGenModel",
+ "CodeGenPreTrainedModel",
+ ]
+ )
+ _import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"])
+ _import_structure["models.cohere2"].extend(["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"])
+ _import_structure["models.colpali"].extend(
+ [
+ "ColPaliForRetrieval",
+ "ColPaliPreTrainedModel",
+ ]
+ )
+ _import_structure["models.conditional_detr"].extend(
+ [
+ "ConditionalDetrForObjectDetection",
+ "ConditionalDetrForSegmentation",
+ "ConditionalDetrModel",
+ "ConditionalDetrPreTrainedModel",
+ ]
+ )
+ _import_structure["models.convbert"].extend(
+ [
+ "ConvBertForMaskedLM",
+ "ConvBertForMultipleChoice",
+ "ConvBertForQuestionAnswering",
+ "ConvBertForSequenceClassification",
+ "ConvBertForTokenClassification",
+ "ConvBertModel",
+ "ConvBertPreTrainedModel",
+ "load_tf_weights_in_convbert",
+ ]
+ )
+ _import_structure["models.convnext"].extend(
+ [
+ "ConvNextBackbone",
+ "ConvNextForImageClassification",
+ "ConvNextModel",
+ "ConvNextPreTrainedModel",
+ ]
+ )
+ _import_structure["models.convnextv2"].extend(
+ [
+ "ConvNextV2Backbone",
+ "ConvNextV2ForImageClassification",
+ "ConvNextV2Model",
+ "ConvNextV2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.cpmant"].extend(
+ [
+ "CpmAntForCausalLM",
+ "CpmAntModel",
+ "CpmAntPreTrainedModel",
+ ]
+ )
+ _import_structure["models.ctrl"].extend(
+ [
+ "CTRLForSequenceClassification",
+ "CTRLLMHeadModel",
+ "CTRLModel",
+ "CTRLPreTrainedModel",
+ ]
+ )
+ _import_structure["models.cvt"].extend(
+ [
+ "CvtForImageClassification",
+ "CvtModel",
+ "CvtPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dac"].extend(
+ [
+ "DacModel",
+ "DacPreTrainedModel",
+ ]
+ )
+ _import_structure["models.data2vec"].extend(
+ [
+ "Data2VecAudioForAudioFrameClassification",
+ "Data2VecAudioForCTC",
+ "Data2VecAudioForSequenceClassification",
+ "Data2VecAudioForXVector",
+ "Data2VecAudioModel",
+ "Data2VecAudioPreTrainedModel",
+ "Data2VecTextForCausalLM",
+ "Data2VecTextForMaskedLM",
+ "Data2VecTextForMultipleChoice",
+ "Data2VecTextForQuestionAnswering",
+ "Data2VecTextForSequenceClassification",
+ "Data2VecTextForTokenClassification",
+ "Data2VecTextModel",
+ "Data2VecTextPreTrainedModel",
+ "Data2VecVisionForImageClassification",
+ "Data2VecVisionForSemanticSegmentation",
+ "Data2VecVisionModel",
+ "Data2VecVisionPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dbrx"].extend(
+ [
+ "DbrxForCausalLM",
+ "DbrxModel",
+ "DbrxPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deberta"].extend(
+ [
+ "DebertaForMaskedLM",
+ "DebertaForQuestionAnswering",
+ "DebertaForSequenceClassification",
+ "DebertaForTokenClassification",
+ "DebertaModel",
+ "DebertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deberta_v2"].extend(
+ [
+ "DebertaV2ForMaskedLM",
+ "DebertaV2ForMultipleChoice",
+ "DebertaV2ForQuestionAnswering",
+ "DebertaV2ForSequenceClassification",
+ "DebertaV2ForTokenClassification",
+ "DebertaV2Model",
+ "DebertaV2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.decision_transformer"].extend(
+ [
+ "DecisionTransformerGPT2Model",
+ "DecisionTransformerGPT2PreTrainedModel",
+ "DecisionTransformerModel",
+ "DecisionTransformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deformable_detr"].extend(
+ [
+ "DeformableDetrForObjectDetection",
+ "DeformableDetrModel",
+ "DeformableDetrPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deit"].extend(
+ [
+ "DeiTForImageClassification",
+ "DeiTForImageClassificationWithTeacher",
+ "DeiTForMaskedImageModeling",
+ "DeiTModel",
+ "DeiTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.deta"].extend(
+ [
+ "DetaForObjectDetection",
+ "DetaModel",
+ "DetaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.efficientformer"].extend(
+ [
+ "EfficientFormerForImageClassification",
+ "EfficientFormerForImageClassificationWithTeacher",
+ "EfficientFormerModel",
+ "EfficientFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.ernie_m"].extend(
+ [
+ "ErnieMForInformationExtraction",
+ "ErnieMForMultipleChoice",
+ "ErnieMForQuestionAnswering",
+ "ErnieMForSequenceClassification",
+ "ErnieMForTokenClassification",
+ "ErnieMModel",
+ "ErnieMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.gptsan_japanese"].extend(
+ [
+ "GPTSanJapaneseForConditionalGeneration",
+ "GPTSanJapaneseModel",
+ "GPTSanJapanesePreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.graphormer"].extend(
+ [
+ "GraphormerForGraphClassification",
+ "GraphormerModel",
+ "GraphormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.jukebox"].extend(
+ [
+ "JukeboxModel",
+ "JukeboxPreTrainedModel",
+ "JukeboxPrior",
+ "JukeboxVQVAE",
+ ]
+ )
+ _import_structure["models.deprecated.mctct"].extend(
+ [
+ "MCTCTForCTC",
+ "MCTCTModel",
+ "MCTCTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.mega"].extend(
+ [
+ "MegaForCausalLM",
+ "MegaForMaskedLM",
+ "MegaForMultipleChoice",
+ "MegaForQuestionAnswering",
+ "MegaForSequenceClassification",
+ "MegaForTokenClassification",
+ "MegaModel",
+ "MegaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
+ _import_structure["models.deprecated.nat"].extend(
+ [
+ "NatBackbone",
+ "NatForImageClassification",
+ "NatModel",
+ "NatPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.nezha"].extend(
+ [
+ "NezhaForMaskedLM",
+ "NezhaForMultipleChoice",
+ "NezhaForNextSentencePrediction",
+ "NezhaForPreTraining",
+ "NezhaForQuestionAnswering",
+ "NezhaForSequenceClassification",
+ "NezhaForTokenClassification",
+ "NezhaModel",
+ "NezhaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.open_llama"].extend(
+ [
+ "OpenLlamaForCausalLM",
+ "OpenLlamaForSequenceClassification",
+ "OpenLlamaModel",
+ "OpenLlamaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.qdqbert"].extend(
+ [
+ "QDQBertForMaskedLM",
+ "QDQBertForMultipleChoice",
+ "QDQBertForNextSentencePrediction",
+ "QDQBertForQuestionAnswering",
+ "QDQBertForSequenceClassification",
+ "QDQBertForTokenClassification",
+ "QDQBertLMHeadModel",
+ "QDQBertModel",
+ "QDQBertPreTrainedModel",
+ "load_tf_weights_in_qdqbert",
+ ]
+ )
+ _import_structure["models.deprecated.realm"].extend(
+ [
+ "RealmEmbedder",
+ "RealmForOpenQA",
+ "RealmKnowledgeAugEncoder",
+ "RealmPreTrainedModel",
+ "RealmReader",
+ "RealmRetriever",
+ "RealmScorer",
+ "load_tf_weights_in_realm",
+ ]
+ )
+ _import_structure["models.deprecated.retribert"].extend(
+ [
+ "RetriBertModel",
+ "RetriBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.speech_to_text_2"].extend(
+ ["Speech2Text2ForCausalLM", "Speech2Text2PreTrainedModel"]
+ )
+ _import_structure["models.deprecated.trajectory_transformer"].extend(
+ [
+ "TrajectoryTransformerModel",
+ "TrajectoryTransformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.transfo_xl"].extend(
+ [
+ "AdaptiveEmbedding",
+ "TransfoXLForSequenceClassification",
+ "TransfoXLLMHeadModel",
+ "TransfoXLModel",
+ "TransfoXLPreTrainedModel",
+ "load_tf_weights_in_transfo_xl",
+ ]
+ )
+ _import_structure["models.deprecated.tvlt"].extend(
+ [
+ "TvltForAudioVisualClassification",
+ "TvltForPreTraining",
+ "TvltModel",
+ "TvltPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.van"].extend(
+ [
+ "VanForImageClassification",
+ "VanModel",
+ "VanPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.vit_hybrid"].extend(
+ [
+ "ViTHybridForImageClassification",
+ "ViTHybridModel",
+ "ViTHybridPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.xlm_prophetnet"].extend(
+ [
+ "XLMProphetNetDecoder",
+ "XLMProphetNetEncoder",
+ "XLMProphetNetForCausalLM",
+ "XLMProphetNetForConditionalGeneration",
+ "XLMProphetNetModel",
+ "XLMProphetNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.depth_anything"].extend(
+ [
+ "DepthAnythingForDepthEstimation",
+ "DepthAnythingPreTrainedModel",
+ ]
+ )
+ _import_structure["models.detr"].extend(
+ [
+ "DetrForObjectDetection",
+ "DetrForSegmentation",
+ "DetrModel",
+ "DetrPreTrainedModel",
+ ]
+ )
+ _import_structure["models.diffllama"].extend(
+ [
+ "DiffLlamaForCausalLM",
+ "DiffLlamaForQuestionAnswering",
+ "DiffLlamaForSequenceClassification",
+ "DiffLlamaForTokenClassification",
+ "DiffLlamaModel",
+ "DiffLlamaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dinat"].extend(
+ [
+ "DinatBackbone",
+ "DinatForImageClassification",
+ "DinatModel",
+ "DinatPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dinov2"].extend(
+ [
+ "Dinov2Backbone",
+ "Dinov2ForImageClassification",
+ "Dinov2Model",
+ "Dinov2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.dinov2_with_registers"].extend(
+ [
+ "Dinov2WithRegistersBackbone",
+ "Dinov2WithRegistersForImageClassification",
+ "Dinov2WithRegistersModel",
+ "Dinov2WithRegistersPreTrainedModel",
+ ]
+ )
+ _import_structure["models.distilbert"].extend(
+ [
+ "DistilBertForMaskedLM",
+ "DistilBertForMultipleChoice",
+ "DistilBertForQuestionAnswering",
+ "DistilBertForSequenceClassification",
+ "DistilBertForTokenClassification",
+ "DistilBertModel",
+ "DistilBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.donut"].extend(
+ [
+ "DonutSwinModel",
+ "DonutSwinPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dpr"].extend(
+ [
+ "DPRContextEncoder",
+ "DPRPretrainedContextEncoder",
+ "DPRPreTrainedModel",
+ "DPRPretrainedQuestionEncoder",
+ "DPRPretrainedReader",
+ "DPRQuestionEncoder",
+ "DPRReader",
+ ]
+ )
+ _import_structure["models.dpt"].extend(
+ [
+ "DPTForDepthEstimation",
+ "DPTForSemanticSegmentation",
+ "DPTModel",
+ "DPTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.efficientnet"].extend(
+ [
+ "EfficientNetForImageClassification",
+ "EfficientNetModel",
+ "EfficientNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.electra"].extend(
+ [
+ "ElectraForCausalLM",
+ "ElectraForMaskedLM",
+ "ElectraForMultipleChoice",
+ "ElectraForPreTraining",
+ "ElectraForQuestionAnswering",
+ "ElectraForSequenceClassification",
+ "ElectraForTokenClassification",
+ "ElectraModel",
+ "ElectraPreTrainedModel",
+ "load_tf_weights_in_electra",
+ ]
+ )
+ _import_structure["models.emu3"].extend(
+ [
+ "Emu3ForCausalLM",
+ "Emu3ForConditionalGeneration",
+ "Emu3PreTrainedModel",
+ "Emu3TextModel",
+ "Emu3VQVAE",
+ ]
+ )
+ _import_structure["models.encodec"].extend(
+ [
+ "EncodecModel",
+ "EncodecPreTrainedModel",
+ ]
+ )
+ _import_structure["models.encoder_decoder"].append("EncoderDecoderModel")
+ _import_structure["models.ernie"].extend(
+ [
+ "ErnieForCausalLM",
+ "ErnieForMaskedLM",
+ "ErnieForMultipleChoice",
+ "ErnieForNextSentencePrediction",
+ "ErnieForPreTraining",
+ "ErnieForQuestionAnswering",
+ "ErnieForSequenceClassification",
+ "ErnieForTokenClassification",
+ "ErnieModel",
+ "ErniePreTrainedModel",
+ ]
+ )
+ _import_structure["models.esm"].extend(
+ [
+ "EsmFoldPreTrainedModel",
+ "EsmForMaskedLM",
+ "EsmForProteinFolding",
+ "EsmForSequenceClassification",
+ "EsmForTokenClassification",
+ "EsmModel",
+ "EsmPreTrainedModel",
+ ]
+ )
+ _import_structure["models.falcon"].extend(
+ [
+ "FalconForCausalLM",
+ "FalconForQuestionAnswering",
+ "FalconForSequenceClassification",
+ "FalconForTokenClassification",
+ "FalconModel",
+ "FalconPreTrainedModel",
+ ]
+ )
+ _import_structure["models.falcon_mamba"].extend(
+ [
+ "FalconMambaForCausalLM",
+ "FalconMambaModel",
+ "FalconMambaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.fastspeech2_conformer"].extend(
+ [
+ "FastSpeech2ConformerHifiGan",
+ "FastSpeech2ConformerModel",
+ "FastSpeech2ConformerPreTrainedModel",
+ "FastSpeech2ConformerWithHifiGan",
+ ]
+ )
+ _import_structure["models.flaubert"].extend(
+ [
+ "FlaubertForMultipleChoice",
+ "FlaubertForQuestionAnswering",
+ "FlaubertForQuestionAnsweringSimple",
+ "FlaubertForSequenceClassification",
+ "FlaubertForTokenClassification",
+ "FlaubertModel",
+ "FlaubertPreTrainedModel",
+ "FlaubertWithLMHeadModel",
+ ]
+ )
+ _import_structure["models.flava"].extend(
+ [
+ "FlavaForPreTraining",
+ "FlavaImageCodebook",
+ "FlavaImageModel",
+ "FlavaModel",
+ "FlavaMultimodalModel",
+ "FlavaPreTrainedModel",
+ "FlavaTextModel",
+ ]
+ )
+ _import_structure["models.fnet"].extend(
+ [
+ "FNetForMaskedLM",
+ "FNetForMultipleChoice",
+ "FNetForNextSentencePrediction",
+ "FNetForPreTraining",
+ "FNetForQuestionAnswering",
+ "FNetForSequenceClassification",
+ "FNetForTokenClassification",
+ "FNetModel",
+ "FNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.focalnet"].extend(
+ [
+ "FocalNetBackbone",
+ "FocalNetForImageClassification",
+ "FocalNetForMaskedImageModeling",
+ "FocalNetModel",
+ "FocalNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.fsmt"].extend(["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"])
+ _import_structure["models.funnel"].extend(
+ [
+ "FunnelBaseModel",
+ "FunnelForMaskedLM",
+ "FunnelForMultipleChoice",
+ "FunnelForPreTraining",
+ "FunnelForQuestionAnswering",
+ "FunnelForSequenceClassification",
+ "FunnelForTokenClassification",
+ "FunnelModel",
+ "FunnelPreTrainedModel",
+ "load_tf_weights_in_funnel",
+ ]
+ )
+ _import_structure["models.fuyu"].extend(["FuyuForCausalLM", "FuyuPreTrainedModel"])
+ _import_structure["models.gemma"].extend(
+ [
+ "GemmaForCausalLM",
+ "GemmaForSequenceClassification",
+ "GemmaForTokenClassification",
+ "GemmaModel",
+ "GemmaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.gemma2"].extend(
+ [
+ "Gemma2ForCausalLM",
+ "Gemma2ForSequenceClassification",
+ "Gemma2ForTokenClassification",
+ "Gemma2Model",
+ "Gemma2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.git"].extend(
+ [
+ "GitForCausalLM",
+ "GitModel",
+ "GitPreTrainedModel",
+ "GitVisionModel",
+ ]
+ )
+ _import_structure["models.glm"].extend(
+ [
+ "GlmForCausalLM",
+ "GlmForSequenceClassification",
+ "GlmForTokenClassification",
+ "GlmModel",
+ "GlmPreTrainedModel",
+ ]
+ )
+ _import_structure["models.glpn"].extend(
+ [
+ "GLPNForDepthEstimation",
+ "GLPNModel",
+ "GLPNPreTrainedModel",
+ ]
+ )
+ _import_structure["models.gpt2"].extend(
+ [
+ "GPT2DoubleHeadsModel",
+ "GPT2ForQuestionAnswering",
+ "GPT2ForSequenceClassification",
+ "GPT2ForTokenClassification",
+ "GPT2LMHeadModel",
+ "GPT2Model",
+ "GPT2PreTrainedModel",
+ "load_tf_weights_in_gpt2",
+ ]
+ )
+ _import_structure["models.gpt_bigcode"].extend(
+ [
+ "GPTBigCodeForCausalLM",
+ "GPTBigCodeForSequenceClassification",
+ "GPTBigCodeForTokenClassification",
+ "GPTBigCodeModel",
+ "GPTBigCodePreTrainedModel",
+ ]
+ )
+ _import_structure["models.gpt_neo"].extend(
+ [
+ "GPTNeoForCausalLM",
+ "GPTNeoForQuestionAnswering",
+ "GPTNeoForSequenceClassification",
+ "GPTNeoForTokenClassification",
+ "GPTNeoModel",
+ "GPTNeoPreTrainedModel",
+ "load_tf_weights_in_gpt_neo",
+ ]
+ )
+ _import_structure["models.gpt_neox"].extend(
+ [
+ "GPTNeoXForCausalLM",
+ "GPTNeoXForQuestionAnswering",
+ "GPTNeoXForSequenceClassification",
+ "GPTNeoXForTokenClassification",
+ "GPTNeoXModel",
+ "GPTNeoXPreTrainedModel",
+ ]
+ )
+ _import_structure["models.gpt_neox_japanese"].extend(
+ [
+ "GPTNeoXJapaneseForCausalLM",
+ "GPTNeoXJapaneseModel",
+ "GPTNeoXJapanesePreTrainedModel",
+ ]
+ )
+ _import_structure["models.gptj"].extend(
+ [
+ "GPTJForCausalLM",
+ "GPTJForQuestionAnswering",
+ "GPTJForSequenceClassification",
+ "GPTJModel",
+ "GPTJPreTrainedModel",
+ ]
+ )
+ _import_structure["models.granite"].extend(
+ [
+ "GraniteForCausalLM",
+ "GraniteModel",
+ "GranitePreTrainedModel",
+ ]
+ )
+ _import_structure["models.granitemoe"].extend(
+ [
+ "GraniteMoeForCausalLM",
+ "GraniteMoeModel",
+ "GraniteMoePreTrainedModel",
+ ]
+ )
+ _import_structure["models.grounding_dino"].extend(
+ [
+ "GroundingDinoForObjectDetection",
+ "GroundingDinoModel",
+ "GroundingDinoPreTrainedModel",
+ ]
+ )
+ _import_structure["models.groupvit"].extend(
+ [
+ "GroupViTModel",
+ "GroupViTPreTrainedModel",
+ "GroupViTTextModel",
+ "GroupViTVisionModel",
+ ]
+ )
+ _import_structure["models.hiera"].extend(
+ [
+ "HieraBackbone",
+ "HieraForImageClassification",
+ "HieraForPreTraining",
+ "HieraModel",
+ "HieraPreTrainedModel",
+ ]
+ )
+ _import_structure["models.hubert"].extend(
+ [
+ "HubertForCTC",
+ "HubertForSequenceClassification",
+ "HubertModel",
+ "HubertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.ibert"].extend(
+ [
+ "IBertForMaskedLM",
+ "IBertForMultipleChoice",
+ "IBertForQuestionAnswering",
+ "IBertForSequenceClassification",
+ "IBertForTokenClassification",
+ "IBertModel",
+ "IBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.idefics"].extend(
+ [
+ "IdeficsForVisionText2Text",
+ "IdeficsModel",
+ "IdeficsPreTrainedModel",
+ "IdeficsProcessor",
+ ]
+ )
+ _import_structure["models.idefics2"].extend(
+ [
+ "Idefics2ForConditionalGeneration",
+ "Idefics2Model",
+ "Idefics2PreTrainedModel",
+ "Idefics2Processor",
+ ]
+ )
+ _import_structure["models.idefics3"].extend(
+ [
+ "Idefics3ForConditionalGeneration",
+ "Idefics3Model",
+ "Idefics3PreTrainedModel",
+ "Idefics3Processor",
+ "Idefics3VisionConfig",
+ "Idefics3VisionTransformer",
+ ]
+ )
+ _import_structure["models.ijepa"].extend(
+ [
+ "IJepaForImageClassification",
+ "IJepaModel",
+ "IJepaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.imagegpt"].extend(
+ [
+ "ImageGPTForCausalImageModeling",
+ "ImageGPTForImageClassification",
+ "ImageGPTModel",
+ "ImageGPTPreTrainedModel",
+ "load_tf_weights_in_imagegpt",
+ ]
+ )
+ _import_structure["models.informer"].extend(
+ [
+ "InformerForPrediction",
+ "InformerModel",
+ "InformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.instructblip"].extend(
+ [
+ "InstructBlipForConditionalGeneration",
+ "InstructBlipPreTrainedModel",
+ "InstructBlipQFormerModel",
+ "InstructBlipVisionModel",
+ ]
+ )
+ _import_structure["models.instructblipvideo"].extend(
+ [
+ "InstructBlipVideoForConditionalGeneration",
+ "InstructBlipVideoPreTrainedModel",
+ "InstructBlipVideoQFormerModel",
+ "InstructBlipVideoVisionModel",
+ ]
+ )
+ _import_structure["models.jamba"].extend(
+ [
+ "JambaForCausalLM",
+ "JambaForSequenceClassification",
+ "JambaModel",
+ "JambaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.jetmoe"].extend(
+ [
+ "JetMoeForCausalLM",
+ "JetMoeForSequenceClassification",
+ "JetMoeModel",
+ "JetMoePreTrainedModel",
+ ]
+ )
+ _import_structure["models.kosmos2"].extend(
+ [
+ "Kosmos2ForConditionalGeneration",
+ "Kosmos2Model",
+ "Kosmos2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.layoutlm"].extend(
+ [
+ "LayoutLMForMaskedLM",
+ "LayoutLMForQuestionAnswering",
+ "LayoutLMForSequenceClassification",
+ "LayoutLMForTokenClassification",
+ "LayoutLMModel",
+ "LayoutLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.layoutlmv2"].extend(
+ [
+ "LayoutLMv2ForQuestionAnswering",
+ "LayoutLMv2ForSequenceClassification",
+ "LayoutLMv2ForTokenClassification",
+ "LayoutLMv2Model",
+ "LayoutLMv2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.layoutlmv3"].extend(
+ [
+ "LayoutLMv3ForQuestionAnswering",
+ "LayoutLMv3ForSequenceClassification",
+ "LayoutLMv3ForTokenClassification",
+ "LayoutLMv3Model",
+ "LayoutLMv3PreTrainedModel",
+ ]
+ )
+ _import_structure["models.led"].extend(
+ [
+ "LEDForConditionalGeneration",
+ "LEDForQuestionAnswering",
+ "LEDForSequenceClassification",
+ "LEDModel",
+ "LEDPreTrainedModel",
+ ]
+ )
+ _import_structure["models.levit"].extend(
+ [
+ "LevitForImageClassification",
+ "LevitForImageClassificationWithTeacher",
+ "LevitModel",
+ "LevitPreTrainedModel",
+ ]
+ )
+ _import_structure["models.lilt"].extend(
+ [
+ "LiltForQuestionAnswering",
+ "LiltForSequenceClassification",
+ "LiltForTokenClassification",
+ "LiltModel",
+ "LiltPreTrainedModel",
+ ]
+ )
+ _import_structure["models.llama"].extend(
+ [
+ "LlamaForCausalLM",
+ "LlamaForQuestionAnswering",
+ "LlamaForSequenceClassification",
+ "LlamaForTokenClassification",
+ "LlamaModel",
+ "LlamaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.llava"].extend(
+ [
+ "LlavaForConditionalGeneration",
+ "LlavaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.llava_next"].extend(
+ [
+ "LlavaNextForConditionalGeneration",
+ "LlavaNextPreTrainedModel",
+ ]
+ )
+ _import_structure["models.llava_next_video"].extend(
+ [
+ "LlavaNextVideoForConditionalGeneration",
+ "LlavaNextVideoPreTrainedModel",
+ ]
+ )
+ _import_structure["models.llava_onevision"].extend(
+ [
+ "LlavaOnevisionForConditionalGeneration",
+ "LlavaOnevisionPreTrainedModel",
+ ]
+ )
+ _import_structure["models.longformer"].extend(
+ [
+ "LongformerForMaskedLM",
+ "LongformerForMultipleChoice",
+ "LongformerForQuestionAnswering",
+ "LongformerForSequenceClassification",
+ "LongformerForTokenClassification",
+ "LongformerModel",
+ "LongformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.longt5"].extend(
+ [
+ "LongT5EncoderModel",
+ "LongT5ForConditionalGeneration",
+ "LongT5Model",
+ "LongT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.luke"].extend(
+ [
+ "LukeForEntityClassification",
+ "LukeForEntityPairClassification",
+ "LukeForEntitySpanClassification",
+ "LukeForMaskedLM",
+ "LukeForMultipleChoice",
+ "LukeForQuestionAnswering",
+ "LukeForSequenceClassification",
+ "LukeForTokenClassification",
+ "LukeModel",
+ "LukePreTrainedModel",
+ ]
+ )
+ _import_structure["models.lxmert"].extend(
+ [
+ "LxmertEncoder",
+ "LxmertForPreTraining",
+ "LxmertForQuestionAnswering",
+ "LxmertModel",
+ "LxmertPreTrainedModel",
+ "LxmertVisualFeatureEncoder",
+ ]
+ )
+ _import_structure["models.m2m_100"].extend(
+ [
+ "M2M100ForConditionalGeneration",
+ "M2M100Model",
+ "M2M100PreTrainedModel",
+ ]
+ )
+ _import_structure["models.mamba"].extend(
+ [
+ "MambaForCausalLM",
+ "MambaModel",
+ "MambaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mamba2"].extend(
+ [
+ "Mamba2ForCausalLM",
+ "Mamba2Model",
+ "Mamba2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.marian"].extend(
+ ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"]
+ )
+ _import_structure["models.markuplm"].extend(
+ [
+ "MarkupLMForQuestionAnswering",
+ "MarkupLMForSequenceClassification",
+ "MarkupLMForTokenClassification",
+ "MarkupLMModel",
+ "MarkupLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mask2former"].extend(
+ [
+ "Mask2FormerForUniversalSegmentation",
+ "Mask2FormerModel",
+ "Mask2FormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.maskformer"].extend(
+ [
+ "MaskFormerForInstanceSegmentation",
+ "MaskFormerModel",
+ "MaskFormerPreTrainedModel",
+ "MaskFormerSwinBackbone",
+ ]
+ )
+ _import_structure["models.mbart"].extend(
+ [
+ "MBartForCausalLM",
+ "MBartForConditionalGeneration",
+ "MBartForQuestionAnswering",
+ "MBartForSequenceClassification",
+ "MBartModel",
+ "MBartPreTrainedModel",
+ ]
+ )
+ _import_structure["models.megatron_bert"].extend(
+ [
+ "MegatronBertForCausalLM",
+ "MegatronBertForMaskedLM",
+ "MegatronBertForMultipleChoice",
+ "MegatronBertForNextSentencePrediction",
+ "MegatronBertForPreTraining",
+ "MegatronBertForQuestionAnswering",
+ "MegatronBertForSequenceClassification",
+ "MegatronBertForTokenClassification",
+ "MegatronBertModel",
+ "MegatronBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mgp_str"].extend(
+ [
+ "MgpstrForSceneTextRecognition",
+ "MgpstrModel",
+ "MgpstrPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mimi"].extend(
+ [
+ "MimiModel",
+ "MimiPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mistral"].extend(
+ [
+ "MistralForCausalLM",
+ "MistralForQuestionAnswering",
+ "MistralForSequenceClassification",
+ "MistralForTokenClassification",
+ "MistralModel",
+ "MistralPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mixtral"].extend(
+ [
+ "MixtralForCausalLM",
+ "MixtralForQuestionAnswering",
+ "MixtralForSequenceClassification",
+ "MixtralForTokenClassification",
+ "MixtralModel",
+ "MixtralPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mllama"].extend(
+ [
+ "MllamaForCausalLM",
+ "MllamaForConditionalGeneration",
+ "MllamaPreTrainedModel",
+ "MllamaProcessor",
+ "MllamaTextModel",
+ "MllamaVisionModel",
+ ]
+ )
+ _import_structure["models.mobilebert"].extend(
+ [
+ "MobileBertForMaskedLM",
+ "MobileBertForMultipleChoice",
+ "MobileBertForNextSentencePrediction",
+ "MobileBertForPreTraining",
+ "MobileBertForQuestionAnswering",
+ "MobileBertForSequenceClassification",
+ "MobileBertForTokenClassification",
+ "MobileBertModel",
+ "MobileBertPreTrainedModel",
+ "load_tf_weights_in_mobilebert",
+ ]
+ )
+ _import_structure["models.mobilenet_v1"].extend(
+ [
+ "MobileNetV1ForImageClassification",
+ "MobileNetV1Model",
+ "MobileNetV1PreTrainedModel",
+ "load_tf_weights_in_mobilenet_v1",
+ ]
+ )
+ _import_structure["models.mobilenet_v2"].extend(
+ [
+ "MobileNetV2ForImageClassification",
+ "MobileNetV2ForSemanticSegmentation",
+ "MobileNetV2Model",
+ "MobileNetV2PreTrainedModel",
+ "load_tf_weights_in_mobilenet_v2",
+ ]
+ )
+ _import_structure["models.mobilevit"].extend(
+ [
+ "MobileViTForImageClassification",
+ "MobileViTForSemanticSegmentation",
+ "MobileViTModel",
+ "MobileViTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mobilevitv2"].extend(
+ [
+ "MobileViTV2ForImageClassification",
+ "MobileViTV2ForSemanticSegmentation",
+ "MobileViTV2Model",
+ "MobileViTV2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.modernbert"].extend(
+ [
+ "ModernBertForMaskedLM",
+ "ModernBertForSequenceClassification",
+ "ModernBertForTokenClassification",
+ "ModernBertModel",
+ "ModernBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.moonshine"].extend(
+ [
+ "MoonshineForConditionalGeneration",
+ "MoonshineModel",
+ "MoonshinePreTrainedModel",
+ ]
+ )
+ _import_structure["models.moshi"].extend(
+ [
+ "MoshiForCausalLM",
+ "MoshiForConditionalGeneration",
+ "MoshiModel",
+ "MoshiPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mpnet"].extend(
+ [
+ "MPNetForMaskedLM",
+ "MPNetForMultipleChoice",
+ "MPNetForQuestionAnswering",
+ "MPNetForSequenceClassification",
+ "MPNetForTokenClassification",
+ "MPNetModel",
+ "MPNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mpt"].extend(
+ [
+ "MptForCausalLM",
+ "MptForQuestionAnswering",
+ "MptForSequenceClassification",
+ "MptForTokenClassification",
+ "MptModel",
+ "MptPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mra"].extend(
+ [
+ "MraForMaskedLM",
+ "MraForMultipleChoice",
+ "MraForQuestionAnswering",
+ "MraForSequenceClassification",
+ "MraForTokenClassification",
+ "MraModel",
+ "MraPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mt5"].extend(
+ [
+ "MT5EncoderModel",
+ "MT5ForConditionalGeneration",
+ "MT5ForQuestionAnswering",
+ "MT5ForSequenceClassification",
+ "MT5ForTokenClassification",
+ "MT5Model",
+ "MT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.musicgen"].extend(
+ [
+ "MusicgenForCausalLM",
+ "MusicgenForConditionalGeneration",
+ "MusicgenModel",
+ "MusicgenPreTrainedModel",
+ "MusicgenProcessor",
+ ]
+ )
+ _import_structure["models.musicgen_melody"].extend(
+ [
+ "MusicgenMelodyForCausalLM",
+ "MusicgenMelodyForConditionalGeneration",
+ "MusicgenMelodyModel",
+ "MusicgenMelodyPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mvp"].extend(
+ [
+ "MvpForCausalLM",
+ "MvpForConditionalGeneration",
+ "MvpForQuestionAnswering",
+ "MvpForSequenceClassification",
+ "MvpModel",
+ "MvpPreTrainedModel",
+ ]
+ )
+ _import_structure["models.nemotron"].extend(
+ [
+ "NemotronForCausalLM",
+ "NemotronForQuestionAnswering",
+ "NemotronForSequenceClassification",
+ "NemotronForTokenClassification",
+ "NemotronModel",
+ "NemotronPreTrainedModel",
+ ]
+ )
+ _import_structure["models.nllb_moe"].extend(
+ [
+ "NllbMoeForConditionalGeneration",
+ "NllbMoeModel",
+ "NllbMoePreTrainedModel",
+ "NllbMoeSparseMLP",
+ "NllbMoeTop2Router",
+ ]
+ )
+ _import_structure["models.nystromformer"].extend(
+ [
+ "NystromformerForMaskedLM",
+ "NystromformerForMultipleChoice",
+ "NystromformerForQuestionAnswering",
+ "NystromformerForSequenceClassification",
+ "NystromformerForTokenClassification",
+ "NystromformerModel",
+ "NystromformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.olmo"].extend(
+ [
+ "OlmoForCausalLM",
+ "OlmoModel",
+ "OlmoPreTrainedModel",
+ ]
+ )
+ _import_structure["models.olmo2"].extend(
+ [
+ "Olmo2ForCausalLM",
+ "Olmo2Model",
+ "Olmo2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.olmoe"].extend(
+ [
+ "OlmoeForCausalLM",
+ "OlmoeModel",
+ "OlmoePreTrainedModel",
+ ]
+ )
+ _import_structure["models.omdet_turbo"].extend(
+ [
+ "OmDetTurboForObjectDetection",
+ "OmDetTurboPreTrainedModel",
+ ]
+ )
+ _import_structure["models.oneformer"].extend(
+ [
+ "OneFormerForUniversalSegmentation",
+ "OneFormerModel",
+ "OneFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.openai"].extend(
+ [
+ "OpenAIGPTDoubleHeadsModel",
+ "OpenAIGPTForSequenceClassification",
+ "OpenAIGPTLMHeadModel",
+ "OpenAIGPTModel",
+ "OpenAIGPTPreTrainedModel",
+ "load_tf_weights_in_openai_gpt",
+ ]
+ )
+ _import_structure["models.opt"].extend(
+ [
+ "OPTForCausalLM",
+ "OPTForQuestionAnswering",
+ "OPTForSequenceClassification",
+ "OPTModel",
+ "OPTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.owlv2"].extend(
+ [
+ "Owlv2ForObjectDetection",
+ "Owlv2Model",
+ "Owlv2PreTrainedModel",
+ "Owlv2TextModel",
+ "Owlv2VisionModel",
+ ]
+ )
+ _import_structure["models.owlvit"].extend(
+ [
+ "OwlViTForObjectDetection",
+ "OwlViTModel",
+ "OwlViTPreTrainedModel",
+ "OwlViTTextModel",
+ "OwlViTVisionModel",
+ ]
+ )
+ _import_structure["models.paligemma"].extend(
+ [
+ "PaliGemmaForConditionalGeneration",
+ "PaliGemmaPreTrainedModel",
+ "PaliGemmaProcessor",
+ ]
+ )
+ _import_structure["models.patchtsmixer"].extend(
+ [
+ "PatchTSMixerForPrediction",
+ "PatchTSMixerForPretraining",
+ "PatchTSMixerForRegression",
+ "PatchTSMixerForTimeSeriesClassification",
+ "PatchTSMixerModel",
+ "PatchTSMixerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.patchtst"].extend(
+ [
+ "PatchTSTForClassification",
+ "PatchTSTForPrediction",
+ "PatchTSTForPretraining",
+ "PatchTSTForRegression",
+ "PatchTSTModel",
+ "PatchTSTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pegasus"].extend(
+ [
+ "PegasusForCausalLM",
+ "PegasusForConditionalGeneration",
+ "PegasusModel",
+ "PegasusPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pegasus_x"].extend(
+ [
+ "PegasusXForConditionalGeneration",
+ "PegasusXModel",
+ "PegasusXPreTrainedModel",
+ ]
+ )
+ _import_structure["models.perceiver"].extend(
+ [
+ "PerceiverForImageClassificationConvProcessing",
+ "PerceiverForImageClassificationFourier",
+ "PerceiverForImageClassificationLearned",
+ "PerceiverForMaskedLM",
+ "PerceiverForMultimodalAutoencoding",
+ "PerceiverForOpticalFlow",
+ "PerceiverForSequenceClassification",
+ "PerceiverModel",
+ "PerceiverPreTrainedModel",
+ ]
+ )
+ _import_structure["models.persimmon"].extend(
+ [
+ "PersimmonForCausalLM",
+ "PersimmonForSequenceClassification",
+ "PersimmonForTokenClassification",
+ "PersimmonModel",
+ "PersimmonPreTrainedModel",
+ ]
+ )
+ _import_structure["models.phi"].extend(
+ [
+ "PhiForCausalLM",
+ "PhiForSequenceClassification",
+ "PhiForTokenClassification",
+ "PhiModel",
+ "PhiPreTrainedModel",
+ ]
+ )
+ _import_structure["models.phi3"].extend(
+ [
+ "Phi3ForCausalLM",
+ "Phi3ForSequenceClassification",
+ "Phi3ForTokenClassification",
+ "Phi3Model",
+ "Phi3PreTrainedModel",
+ ]
+ )
+ _import_structure["models.phimoe"].extend(
+ [
+ "PhimoeForCausalLM",
+ "PhimoeForSequenceClassification",
+ "PhimoeModel",
+ "PhimoePreTrainedModel",
+ ]
+ )
+ _import_structure["models.pix2struct"].extend(
+ [
+ "Pix2StructForConditionalGeneration",
+ "Pix2StructPreTrainedModel",
+ "Pix2StructTextModel",
+ "Pix2StructVisionModel",
+ ]
+ )
+ _import_structure["models.pixtral"].extend(["PixtralPreTrainedModel", "PixtralVisionModel"])
+ _import_structure["models.plbart"].extend(
+ [
+ "PLBartForCausalLM",
+ "PLBartForConditionalGeneration",
+ "PLBartForSequenceClassification",
+ "PLBartModel",
+ "PLBartPreTrainedModel",
+ ]
+ )
+ _import_structure["models.poolformer"].extend(
+ [
+ "PoolFormerForImageClassification",
+ "PoolFormerModel",
+ "PoolFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pop2piano"].extend(
+ [
+ "Pop2PianoForConditionalGeneration",
+ "Pop2PianoPreTrainedModel",
+ ]
+ )
+ _import_structure["models.prophetnet"].extend(
+ [
+ "ProphetNetDecoder",
+ "ProphetNetEncoder",
+ "ProphetNetForCausalLM",
+ "ProphetNetForConditionalGeneration",
+ "ProphetNetModel",
+ "ProphetNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pvt"].extend(
+ [
+ "PvtForImageClassification",
+ "PvtModel",
+ "PvtPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pvt_v2"].extend(
+ [
+ "PvtV2Backbone",
+ "PvtV2ForImageClassification",
+ "PvtV2Model",
+ "PvtV2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.qwen2"].extend(
+ [
+ "Qwen2ForCausalLM",
+ "Qwen2ForQuestionAnswering",
+ "Qwen2ForSequenceClassification",
+ "Qwen2ForTokenClassification",
+ "Qwen2Model",
+ "Qwen2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.qwen2_audio"].extend(
+ [
+ "Qwen2AudioEncoder",
+ "Qwen2AudioForConditionalGeneration",
+ "Qwen2AudioPreTrainedModel",
+ ]
+ )
+ _import_structure["models.qwen2_moe"].extend(
+ [
+ "Qwen2MoeForCausalLM",
+ "Qwen2MoeForQuestionAnswering",
+ "Qwen2MoeForSequenceClassification",
+ "Qwen2MoeForTokenClassification",
+ "Qwen2MoeModel",
+ "Qwen2MoePreTrainedModel",
+ ]
+ )
+ _import_structure["models.qwen2_vl"].extend(
+ [
+ "Qwen2VLForConditionalGeneration",
+ "Qwen2VLModel",
+ "Qwen2VLPreTrainedModel",
+ ]
+ )
+ _import_structure["models.rag"].extend(
+ [
+ "RagModel",
+ "RagPreTrainedModel",
+ "RagSequenceForGeneration",
+ "RagTokenForGeneration",
+ ]
+ )
+ _import_structure["models.recurrent_gemma"].extend(
+ [
+ "RecurrentGemmaForCausalLM",
+ "RecurrentGemmaModel",
+ "RecurrentGemmaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.reformer"].extend(
+ [
+ "ReformerForMaskedLM",
+ "ReformerForQuestionAnswering",
+ "ReformerForSequenceClassification",
+ "ReformerModel",
+ "ReformerModelWithLMHead",
+ "ReformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.regnet"].extend(
+ [
+ "RegNetForImageClassification",
+ "RegNetModel",
+ "RegNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.rembert"].extend(
+ [
+ "RemBertForCausalLM",
+ "RemBertForMaskedLM",
+ "RemBertForMultipleChoice",
+ "RemBertForQuestionAnswering",
+ "RemBertForSequenceClassification",
+ "RemBertForTokenClassification",
+ "RemBertModel",
+ "RemBertPreTrainedModel",
+ "load_tf_weights_in_rembert",
+ ]
+ )
+ _import_structure["models.resnet"].extend(
+ [
+ "ResNetBackbone",
+ "ResNetForImageClassification",
+ "ResNetModel",
+ "ResNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roberta"].extend(
+ [
+ "RobertaForCausalLM",
+ "RobertaForMaskedLM",
+ "RobertaForMultipleChoice",
+ "RobertaForQuestionAnswering",
+ "RobertaForSequenceClassification",
+ "RobertaForTokenClassification",
+ "RobertaModel",
+ "RobertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roberta_prelayernorm"].extend(
+ [
+ "RobertaPreLayerNormForCausalLM",
+ "RobertaPreLayerNormForMaskedLM",
+ "RobertaPreLayerNormForMultipleChoice",
+ "RobertaPreLayerNormForQuestionAnswering",
+ "RobertaPreLayerNormForSequenceClassification",
+ "RobertaPreLayerNormForTokenClassification",
+ "RobertaPreLayerNormModel",
+ "RobertaPreLayerNormPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roc_bert"].extend(
+ [
+ "RoCBertForCausalLM",
+ "RoCBertForMaskedLM",
+ "RoCBertForMultipleChoice",
+ "RoCBertForPreTraining",
+ "RoCBertForQuestionAnswering",
+ "RoCBertForSequenceClassification",
+ "RoCBertForTokenClassification",
+ "RoCBertModel",
+ "RoCBertPreTrainedModel",
+ "load_tf_weights_in_roc_bert",
+ ]
+ )
+ _import_structure["models.roformer"].extend(
+ [
+ "RoFormerForCausalLM",
+ "RoFormerForMaskedLM",
+ "RoFormerForMultipleChoice",
+ "RoFormerForQuestionAnswering",
+ "RoFormerForSequenceClassification",
+ "RoFormerForTokenClassification",
+ "RoFormerModel",
+ "RoFormerPreTrainedModel",
+ "load_tf_weights_in_roformer",
+ ]
+ )
+ _import_structure["models.rt_detr"].extend(
+ [
+ "RTDetrForObjectDetection",
+ "RTDetrModel",
+ "RTDetrPreTrainedModel",
+ "RTDetrResNetBackbone",
+ "RTDetrResNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.rwkv"].extend(
+ [
+ "RwkvForCausalLM",
+ "RwkvModel",
+ "RwkvPreTrainedModel",
+ ]
+ )
+ _import_structure["models.sam"].extend(
+ [
+ "SamModel",
+ "SamPreTrainedModel",
+ ]
+ )
+ _import_structure["models.seamless_m4t"].extend(
+ [
+ "SeamlessM4TCodeHifiGan",
+ "SeamlessM4TForSpeechToSpeech",
+ "SeamlessM4TForSpeechToText",
+ "SeamlessM4TForTextToSpeech",
+ "SeamlessM4TForTextToText",
+ "SeamlessM4THifiGan",
+ "SeamlessM4TModel",
+ "SeamlessM4TPreTrainedModel",
+ "SeamlessM4TTextToUnitForConditionalGeneration",
+ "SeamlessM4TTextToUnitModel",
+ ]
+ )
+ _import_structure["models.seamless_m4t_v2"].extend(
+ [
+ "SeamlessM4Tv2ForSpeechToSpeech",
+ "SeamlessM4Tv2ForSpeechToText",
+ "SeamlessM4Tv2ForTextToSpeech",
+ "SeamlessM4Tv2ForTextToText",
+ "SeamlessM4Tv2Model",
+ "SeamlessM4Tv2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.segformer"].extend(
+ [
+ "SegformerDecodeHead",
+ "SegformerForImageClassification",
+ "SegformerForSemanticSegmentation",
+ "SegformerModel",
+ "SegformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.seggpt"].extend(
+ [
+ "SegGptForImageSegmentation",
+ "SegGptModel",
+ "SegGptPreTrainedModel",
+ ]
+ )
+ _import_structure["models.sew"].extend(
+ [
+ "SEWForCTC",
+ "SEWForSequenceClassification",
+ "SEWModel",
+ "SEWPreTrainedModel",
+ ]
+ )
+ _import_structure["models.sew_d"].extend(
+ [
+ "SEWDForCTC",
+ "SEWDForSequenceClassification",
+ "SEWDModel",
+ "SEWDPreTrainedModel",
+ ]
+ )
+ _import_structure["models.siglip"].extend(
+ [
+ "SiglipForImageClassification",
+ "SiglipModel",
+ "SiglipPreTrainedModel",
+ "SiglipTextModel",
+ "SiglipVisionModel",
+ ]
+ )
+ _import_structure["models.speech_encoder_decoder"].extend(["SpeechEncoderDecoderModel"])
+ _import_structure["models.speech_to_text"].extend(
+ [
+ "Speech2TextForConditionalGeneration",
+ "Speech2TextModel",
+ "Speech2TextPreTrainedModel",
+ ]
+ )
+ _import_structure["models.speecht5"].extend(
+ [
+ "SpeechT5ForSpeechToSpeech",
+ "SpeechT5ForSpeechToText",
+ "SpeechT5ForTextToSpeech",
+ "SpeechT5HifiGan",
+ "SpeechT5Model",
+ "SpeechT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.splinter"].extend(
+ [
+ "SplinterForPreTraining",
+ "SplinterForQuestionAnswering",
+ "SplinterModel",
+ "SplinterPreTrainedModel",
+ ]
+ )
+ _import_structure["models.squeezebert"].extend(
+ [
+ "SqueezeBertForMaskedLM",
+ "SqueezeBertForMultipleChoice",
+ "SqueezeBertForQuestionAnswering",
+ "SqueezeBertForSequenceClassification",
+ "SqueezeBertForTokenClassification",
+ "SqueezeBertModel",
+ "SqueezeBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.stablelm"].extend(
+ [
+ "StableLmForCausalLM",
+ "StableLmForSequenceClassification",
+ "StableLmForTokenClassification",
+ "StableLmModel",
+ "StableLmPreTrainedModel",
+ ]
+ )
+ _import_structure["models.starcoder2"].extend(
+ [
+ "Starcoder2ForCausalLM",
+ "Starcoder2ForSequenceClassification",
+ "Starcoder2ForTokenClassification",
+ "Starcoder2Model",
+ "Starcoder2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.superpoint"].extend(
+ [
+ "SuperPointForKeypointDetection",
+ "SuperPointPreTrainedModel",
+ ]
+ )
+ _import_structure["models.swiftformer"].extend(
+ [
+ "SwiftFormerForImageClassification",
+ "SwiftFormerModel",
+ "SwiftFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.swin"].extend(
+ [
+ "SwinBackbone",
+ "SwinForImageClassification",
+ "SwinForMaskedImageModeling",
+ "SwinModel",
+ "SwinPreTrainedModel",
+ ]
+ )
+ _import_structure["models.swin2sr"].extend(
+ [
+ "Swin2SRForImageSuperResolution",
+ "Swin2SRModel",
+ "Swin2SRPreTrainedModel",
+ ]
+ )
+ _import_structure["models.swinv2"].extend(
+ [
+ "Swinv2Backbone",
+ "Swinv2ForImageClassification",
+ "Swinv2ForMaskedImageModeling",
+ "Swinv2Model",
+ "Swinv2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.switch_transformers"].extend(
+ [
+ "SwitchTransformersEncoderModel",
+ "SwitchTransformersForConditionalGeneration",
+ "SwitchTransformersModel",
+ "SwitchTransformersPreTrainedModel",
+ "SwitchTransformersSparseMLP",
+ "SwitchTransformersTop1Router",
+ ]
+ )
+ _import_structure["models.t5"].extend(
+ [
+ "T5EncoderModel",
+ "T5ForConditionalGeneration",
+ "T5ForQuestionAnswering",
+ "T5ForSequenceClassification",
+ "T5ForTokenClassification",
+ "T5Model",
+ "T5PreTrainedModel",
+ "load_tf_weights_in_t5",
+ ]
+ )
+ _import_structure["models.table_transformer"].extend(
+ [
+ "TableTransformerForObjectDetection",
+ "TableTransformerModel",
+ "TableTransformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.tapas"].extend(
+ [
+ "TapasForMaskedLM",
+ "TapasForQuestionAnswering",
+ "TapasForSequenceClassification",
+ "TapasModel",
+ "TapasPreTrainedModel",
+ "load_tf_weights_in_tapas",
+ ]
+ )
+ _import_structure["models.textnet"].extend(
+ [
+ "TextNetBackbone",
+ "TextNetForImageClassification",
+ "TextNetModel",
+ "TextNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.time_series_transformer"].extend(
+ [
+ "TimeSeriesTransformerForPrediction",
+ "TimeSeriesTransformerModel",
+ "TimeSeriesTransformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.timesformer"].extend(
+ [
+ "TimesformerForVideoClassification",
+ "TimesformerModel",
+ "TimesformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.timm_backbone"].extend(["TimmBackbone"])
+ _import_structure["models.timm_wrapper"].extend(
+ ["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"]
+ )
+ _import_structure["models.trocr"].extend(
+ [
+ "TrOCRForCausalLM",
+ "TrOCRPreTrainedModel",
+ ]
+ )
+ _import_structure["models.tvp"].extend(
+ [
+ "TvpForVideoGrounding",
+ "TvpModel",
+ "TvpPreTrainedModel",
+ ]
+ )
+ _import_structure["models.udop"].extend(
+ [
+ "UdopEncoderModel",
+ "UdopForConditionalGeneration",
+ "UdopModel",
+ "UdopPreTrainedModel",
+ ],
+ )
+ _import_structure["models.umt5"].extend(
+ [
+ "UMT5EncoderModel",
+ "UMT5ForConditionalGeneration",
+ "UMT5ForQuestionAnswering",
+ "UMT5ForSequenceClassification",
+ "UMT5ForTokenClassification",
+ "UMT5Model",
+ "UMT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.unispeech"].extend(
+ [
+ "UniSpeechForCTC",
+ "UniSpeechForPreTraining",
+ "UniSpeechForSequenceClassification",
+ "UniSpeechModel",
+ "UniSpeechPreTrainedModel",
+ ]
+ )
+ _import_structure["models.unispeech_sat"].extend(
+ [
+ "UniSpeechSatForAudioFrameClassification",
+ "UniSpeechSatForCTC",
+ "UniSpeechSatForPreTraining",
+ "UniSpeechSatForSequenceClassification",
+ "UniSpeechSatForXVector",
+ "UniSpeechSatModel",
+ "UniSpeechSatPreTrainedModel",
+ ]
+ )
+ _import_structure["models.univnet"].extend(
+ [
+ "UnivNetModel",
+ ]
+ )
+ _import_structure["models.upernet"].extend(
+ [
+ "UperNetForSemanticSegmentation",
+ "UperNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.video_llava"].extend(
+ [
+ "VideoLlavaForConditionalGeneration",
+ "VideoLlavaPreTrainedModel",
+ "VideoLlavaProcessor",
+ ]
+ )
+ _import_structure["models.videomae"].extend(
+ [
+ "VideoMAEForPreTraining",
+ "VideoMAEForVideoClassification",
+ "VideoMAEModel",
+ "VideoMAEPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vilt"].extend(
+ [
+ "ViltForImageAndTextRetrieval",
+ "ViltForImagesAndTextClassification",
+ "ViltForMaskedLM",
+ "ViltForQuestionAnswering",
+ "ViltForTokenClassification",
+ "ViltModel",
+ "ViltPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vipllava"].extend(
+ [
+ "VipLlavaForConditionalGeneration",
+ "VipLlavaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
+ _import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"])
+ _import_structure["models.visual_bert"].extend(
+ [
+ "VisualBertForMultipleChoice",
+ "VisualBertForPreTraining",
+ "VisualBertForQuestionAnswering",
+ "VisualBertForRegionToPhraseAlignment",
+ "VisualBertForVisualReasoning",
+ "VisualBertModel",
+ "VisualBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vit"].extend(
+ [
+ "ViTForImageClassification",
+ "ViTForMaskedImageModeling",
+ "ViTModel",
+ "ViTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vit_mae"].extend(
+ [
+ "ViTMAEForPreTraining",
+ "ViTMAEModel",
+ "ViTMAEPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vit_msn"].extend(
+ [
+ "ViTMSNForImageClassification",
+ "ViTMSNModel",
+ "ViTMSNPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vitdet"].extend(
+ [
+ "VitDetBackbone",
+ "VitDetModel",
+ "VitDetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vitmatte"].extend(
+ [
+ "VitMatteForImageMatting",
+ "VitMattePreTrainedModel",
+ ]
+ )
+ _import_structure["models.vitpose"].extend(
+ [
+ "VitPoseForPoseEstimation",
+ "VitPosePreTrainedModel",
+ ]
+ )
+ _import_structure["models.vitpose_backbone"].extend(
+ [
+ "VitPoseBackbone",
+ "VitPoseBackbonePreTrainedModel",
+ ]
+ )
+ _import_structure["models.vits"].extend(
+ [
+ "VitsModel",
+ "VitsPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vivit"].extend(
+ [
+ "VivitForVideoClassification",
+ "VivitModel",
+ "VivitPreTrainedModel",
+ ]
+ )
+ _import_structure["models.wav2vec2"].extend(
+ [
+ "Wav2Vec2ForAudioFrameClassification",
+ "Wav2Vec2ForCTC",
+ "Wav2Vec2ForMaskedLM",
+ "Wav2Vec2ForPreTraining",
+ "Wav2Vec2ForSequenceClassification",
+ "Wav2Vec2ForXVector",
+ "Wav2Vec2Model",
+ "Wav2Vec2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.wav2vec2_bert"].extend(
+ [
+ "Wav2Vec2BertForAudioFrameClassification",
+ "Wav2Vec2BertForCTC",
+ "Wav2Vec2BertForSequenceClassification",
+ "Wav2Vec2BertForXVector",
+ "Wav2Vec2BertModel",
+ "Wav2Vec2BertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.wav2vec2_conformer"].extend(
+ [
+ "Wav2Vec2ConformerForAudioFrameClassification",
+ "Wav2Vec2ConformerForCTC",
+ "Wav2Vec2ConformerForPreTraining",
+ "Wav2Vec2ConformerForSequenceClassification",
+ "Wav2Vec2ConformerForXVector",
+ "Wav2Vec2ConformerModel",
+ "Wav2Vec2ConformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.wavlm"].extend(
+ [
+ "WavLMForAudioFrameClassification",
+ "WavLMForCTC",
+ "WavLMForSequenceClassification",
+ "WavLMForXVector",
+ "WavLMModel",
+ "WavLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.whisper"].extend(
+ [
+ "WhisperForAudioClassification",
+ "WhisperForCausalLM",
+ "WhisperForConditionalGeneration",
+ "WhisperModel",
+ "WhisperPreTrainedModel",
+ ]
+ )
+ _import_structure["models.x_clip"].extend(
+ [
+ "XCLIPModel",
+ "XCLIPPreTrainedModel",
+ "XCLIPTextModel",
+ "XCLIPVisionModel",
+ ]
+ )
+ _import_structure["models.xglm"].extend(
+ [
+ "XGLMForCausalLM",
+ "XGLMModel",
+ "XGLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xlm"].extend(
+ [
+ "XLMForMultipleChoice",
+ "XLMForQuestionAnswering",
+ "XLMForQuestionAnsweringSimple",
+ "XLMForSequenceClassification",
+ "XLMForTokenClassification",
+ "XLMModel",
+ "XLMPreTrainedModel",
+ "XLMWithLMHeadModel",
+ ]
+ )
+ _import_structure["models.xlm_roberta"].extend(
+ [
+ "XLMRobertaForCausalLM",
+ "XLMRobertaForMaskedLM",
+ "XLMRobertaForMultipleChoice",
+ "XLMRobertaForQuestionAnswering",
+ "XLMRobertaForSequenceClassification",
+ "XLMRobertaForTokenClassification",
+ "XLMRobertaModel",
+ "XLMRobertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xlm_roberta_xl"].extend(
+ [
+ "XLMRobertaXLForCausalLM",
+ "XLMRobertaXLForMaskedLM",
+ "XLMRobertaXLForMultipleChoice",
+ "XLMRobertaXLForQuestionAnswering",
+ "XLMRobertaXLForSequenceClassification",
+ "XLMRobertaXLForTokenClassification",
+ "XLMRobertaXLModel",
+ "XLMRobertaXLPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xlnet"].extend(
+ [
+ "XLNetForMultipleChoice",
+ "XLNetForQuestionAnswering",
+ "XLNetForQuestionAnsweringSimple",
+ "XLNetForSequenceClassification",
+ "XLNetForTokenClassification",
+ "XLNetLMHeadModel",
+ "XLNetModel",
+ "XLNetPreTrainedModel",
+ "load_tf_weights_in_xlnet",
+ ]
+ )
+ _import_structure["models.xmod"].extend(
+ [
+ "XmodForCausalLM",
+ "XmodForMaskedLM",
+ "XmodForMultipleChoice",
+ "XmodForQuestionAnswering",
+ "XmodForSequenceClassification",
+ "XmodForTokenClassification",
+ "XmodModel",
+ "XmodPreTrainedModel",
+ ]
+ )
+ _import_structure["models.yolos"].extend(
+ [
+ "YolosForObjectDetection",
+ "YolosModel",
+ "YolosPreTrainedModel",
+ ]
+ )
+ _import_structure["models.yoso"].extend(
+ [
+ "YosoForMaskedLM",
+ "YosoForMultipleChoice",
+ "YosoForQuestionAnswering",
+ "YosoForSequenceClassification",
+ "YosoForTokenClassification",
+ "YosoModel",
+ "YosoPreTrainedModel",
+ ]
+ )
+ _import_structure["models.zamba"].extend(
+ [
+ "ZambaForCausalLM",
+ "ZambaForSequenceClassification",
+ "ZambaModel",
+ "ZambaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.zoedepth"].extend(
+ [
+ "ZoeDepthForDepthEstimation",
+ "ZoeDepthPreTrainedModel",
+ ]
+ )
+ _import_structure["optimization"] = [
+ "Adafactor",
+ "AdamW",
+ "get_constant_schedule",
+ "get_constant_schedule_with_warmup",
+ "get_cosine_schedule_with_warmup",
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
+ "get_inverse_sqrt_schedule",
+ "get_linear_schedule_with_warmup",
+ "get_polynomial_decay_schedule_with_warmup",
+ "get_scheduler",
+ "get_wsd_schedule",
+ ]
+ _import_structure["pytorch_utils"] = [
+ "Conv1D",
+ "apply_chunking_to_forward",
+ "prune_layer",
+ ]
+ _import_structure["sagemaker"] = []
+ _import_structure["time_series_utils"] = []
+ _import_structure["trainer"] = ["Trainer"]
+ _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
+ _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
+
+# TensorFlow-backed objects
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_tf_objects
+
+ _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
+else:
+ _import_structure["activations_tf"] = []
+ _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
+ _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
+ _import_structure["generation"].extend(
+ [
+ "TFForcedBOSTokenLogitsProcessor",
+ "TFForcedEOSTokenLogitsProcessor",
+ "TFForceTokensLogitsProcessor",
+ "TFGenerationMixin",
+ "TFLogitsProcessor",
+ "TFLogitsProcessorList",
+ "TFLogitsWarper",
+ "TFMinLengthLogitsProcessor",
+ "TFNoBadWordsLogitsProcessor",
+ "TFNoRepeatNGramLogitsProcessor",
+ "TFRepetitionPenaltyLogitsProcessor",
+ "TFSuppressTokensAtBeginLogitsProcessor",
+ "TFSuppressTokensLogitsProcessor",
+ "TFTemperatureLogitsWarper",
+ "TFTopKLogitsWarper",
+ "TFTopPLogitsWarper",
+ ]
+ )
+ _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
+ _import_structure["modeling_tf_outputs"] = []
+ _import_structure["modeling_tf_utils"] = [
+ "TFPreTrainedModel",
+ "TFSequenceSummary",
+ "TFSharedEmbeddings",
+ "shape_list",
+ ]
+ # TensorFlow models structure
+ _import_structure["models.albert"].extend(
+ [
+ "TFAlbertForMaskedLM",
+ "TFAlbertForMultipleChoice",
+ "TFAlbertForPreTraining",
+ "TFAlbertForQuestionAnswering",
+ "TFAlbertForSequenceClassification",
+ "TFAlbertForTokenClassification",
+ "TFAlbertMainLayer",
+ "TFAlbertModel",
+ "TFAlbertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.auto"].extend(
+ [
+ "TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_CAUSAL_LM_MAPPING",
+ "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
+ "TF_MODEL_FOR_MASKED_LM_MAPPING",
+ "TF_MODEL_FOR_MASK_GENERATION_MAPPING",
+ "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+ "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+ "TF_MODEL_FOR_PRETRAINING_MAPPING",
+ "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
+ "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+ "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+ "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
+ "TF_MODEL_FOR_TEXT_ENCODING_MAPPING",
+ "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING",
+ "TF_MODEL_MAPPING",
+ "TF_MODEL_WITH_LM_HEAD_MAPPING",
+ "TFAutoModel",
+ "TFAutoModelForAudioClassification",
+ "TFAutoModelForCausalLM",
+ "TFAutoModelForDocumentQuestionAnswering",
+ "TFAutoModelForImageClassification",
+ "TFAutoModelForMaskedImageModeling",
+ "TFAutoModelForMaskedLM",
+ "TFAutoModelForMaskGeneration",
+ "TFAutoModelForMultipleChoice",
+ "TFAutoModelForNextSentencePrediction",
+ "TFAutoModelForPreTraining",
+ "TFAutoModelForQuestionAnswering",
+ "TFAutoModelForSemanticSegmentation",
+ "TFAutoModelForSeq2SeqLM",
+ "TFAutoModelForSequenceClassification",
+ "TFAutoModelForSpeechSeq2Seq",
+ "TFAutoModelForTableQuestionAnswering",
+ "TFAutoModelForTextEncoding",
+ "TFAutoModelForTokenClassification",
+ "TFAutoModelForVision2Seq",
+ "TFAutoModelForZeroShotImageClassification",
+ "TFAutoModelWithLMHead",
+ ]
+ )
+ _import_structure["models.bart"].extend(
+ [
+ "TFBartForConditionalGeneration",
+ "TFBartForSequenceClassification",
+ "TFBartModel",
+ "TFBartPretrainedModel",
+ ]
+ )
+ _import_structure["models.bert"].extend(
+ [
+ "TFBertForMaskedLM",
+ "TFBertForMultipleChoice",
+ "TFBertForNextSentencePrediction",
+ "TFBertForPreTraining",
+ "TFBertForQuestionAnswering",
+ "TFBertForSequenceClassification",
+ "TFBertForTokenClassification",
+ "TFBertLMHeadModel",
+ "TFBertMainLayer",
+ "TFBertModel",
+ "TFBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blenderbot"].extend(
+ [
+ "TFBlenderbotForConditionalGeneration",
+ "TFBlenderbotModel",
+ "TFBlenderbotPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blenderbot_small"].extend(
+ [
+ "TFBlenderbotSmallForConditionalGeneration",
+ "TFBlenderbotSmallModel",
+ "TFBlenderbotSmallPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blip"].extend(
+ [
+ "TFBlipForConditionalGeneration",
+ "TFBlipForImageTextRetrieval",
+ "TFBlipForQuestionAnswering",
+ "TFBlipModel",
+ "TFBlipPreTrainedModel",
+ "TFBlipTextModel",
+ "TFBlipVisionModel",
+ ]
+ )
+ _import_structure["models.camembert"].extend(
+ [
+ "TFCamembertForCausalLM",
+ "TFCamembertForMaskedLM",
+ "TFCamembertForMultipleChoice",
+ "TFCamembertForQuestionAnswering",
+ "TFCamembertForSequenceClassification",
+ "TFCamembertForTokenClassification",
+ "TFCamembertModel",
+ "TFCamembertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.clip"].extend(
+ [
+ "TFCLIPModel",
+ "TFCLIPPreTrainedModel",
+ "TFCLIPTextModel",
+ "TFCLIPVisionModel",
+ ]
+ )
+ _import_structure["models.convbert"].extend(
+ [
+ "TFConvBertForMaskedLM",
+ "TFConvBertForMultipleChoice",
+ "TFConvBertForQuestionAnswering",
+ "TFConvBertForSequenceClassification",
+ "TFConvBertForTokenClassification",
+ "TFConvBertModel",
+ "TFConvBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.convnext"].extend(
+ [
+ "TFConvNextForImageClassification",
+ "TFConvNextModel",
+ "TFConvNextPreTrainedModel",
+ ]
+ )
+ _import_structure["models.convnextv2"].extend(
+ [
+ "TFConvNextV2ForImageClassification",
+ "TFConvNextV2Model",
+ "TFConvNextV2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.ctrl"].extend(
+ [
+ "TFCTRLForSequenceClassification",
+ "TFCTRLLMHeadModel",
+ "TFCTRLModel",
+ "TFCTRLPreTrainedModel",
+ ]
+ )
+ _import_structure["models.cvt"].extend(
+ [
+ "TFCvtForImageClassification",
+ "TFCvtModel",
+ "TFCvtPreTrainedModel",
+ ]
+ )
+ _import_structure["models.data2vec"].extend(
+ [
+ "TFData2VecVisionForImageClassification",
+ "TFData2VecVisionForSemanticSegmentation",
+ "TFData2VecVisionModel",
+ "TFData2VecVisionPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deberta"].extend(
+ [
+ "TFDebertaForMaskedLM",
+ "TFDebertaForQuestionAnswering",
+ "TFDebertaForSequenceClassification",
+ "TFDebertaForTokenClassification",
+ "TFDebertaModel",
+ "TFDebertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deberta_v2"].extend(
+ [
+ "TFDebertaV2ForMaskedLM",
+ "TFDebertaV2ForMultipleChoice",
+ "TFDebertaV2ForQuestionAnswering",
+ "TFDebertaV2ForSequenceClassification",
+ "TFDebertaV2ForTokenClassification",
+ "TFDebertaV2Model",
+ "TFDebertaV2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.deit"].extend(
+ [
+ "TFDeiTForImageClassification",
+ "TFDeiTForImageClassificationWithTeacher",
+ "TFDeiTForMaskedImageModeling",
+ "TFDeiTModel",
+ "TFDeiTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.efficientformer"].extend(
+ [
+ "TFEfficientFormerForImageClassification",
+ "TFEfficientFormerForImageClassificationWithTeacher",
+ "TFEfficientFormerModel",
+ "TFEfficientFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.deprecated.transfo_xl"].extend(
+ [
+ "TFAdaptiveEmbedding",
+ "TFTransfoXLForSequenceClassification",
+ "TFTransfoXLLMHeadModel",
+ "TFTransfoXLMainLayer",
+ "TFTransfoXLModel",
+ "TFTransfoXLPreTrainedModel",
+ ]
+ )
+ _import_structure["models.distilbert"].extend(
+ [
+ "TFDistilBertForMaskedLM",
+ "TFDistilBertForMultipleChoice",
+ "TFDistilBertForQuestionAnswering",
+ "TFDistilBertForSequenceClassification",
+ "TFDistilBertForTokenClassification",
+ "TFDistilBertMainLayer",
+ "TFDistilBertModel",
+ "TFDistilBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dpr"].extend(
+ [
+ "TFDPRContextEncoder",
+ "TFDPRPretrainedContextEncoder",
+ "TFDPRPretrainedQuestionEncoder",
+ "TFDPRPretrainedReader",
+ "TFDPRQuestionEncoder",
+ "TFDPRReader",
+ ]
+ )
+ _import_structure["models.electra"].extend(
+ [
+ "TFElectraForMaskedLM",
+ "TFElectraForMultipleChoice",
+ "TFElectraForPreTraining",
+ "TFElectraForQuestionAnswering",
+ "TFElectraForSequenceClassification",
+ "TFElectraForTokenClassification",
+ "TFElectraModel",
+ "TFElectraPreTrainedModel",
+ ]
+ )
+ _import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
+ _import_structure["models.esm"].extend(
+ [
+ "TFEsmForMaskedLM",
+ "TFEsmForSequenceClassification",
+ "TFEsmForTokenClassification",
+ "TFEsmModel",
+ "TFEsmPreTrainedModel",
+ ]
+ )
+ _import_structure["models.flaubert"].extend(
+ [
+ "TFFlaubertForMultipleChoice",
+ "TFFlaubertForQuestionAnsweringSimple",
+ "TFFlaubertForSequenceClassification",
+ "TFFlaubertForTokenClassification",
+ "TFFlaubertModel",
+ "TFFlaubertPreTrainedModel",
+ "TFFlaubertWithLMHeadModel",
+ ]
+ )
+ _import_structure["models.funnel"].extend(
+ [
+ "TFFunnelBaseModel",
+ "TFFunnelForMaskedLM",
+ "TFFunnelForMultipleChoice",
+ "TFFunnelForPreTraining",
+ "TFFunnelForQuestionAnswering",
+ "TFFunnelForSequenceClassification",
+ "TFFunnelForTokenClassification",
+ "TFFunnelModel",
+ "TFFunnelPreTrainedModel",
+ ]
+ )
+ _import_structure["models.gpt2"].extend(
+ [
+ "TFGPT2DoubleHeadsModel",
+ "TFGPT2ForSequenceClassification",
+ "TFGPT2LMHeadModel",
+ "TFGPT2MainLayer",
+ "TFGPT2Model",
+ "TFGPT2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.gptj"].extend(
+ [
+ "TFGPTJForCausalLM",
+ "TFGPTJForQuestionAnswering",
+ "TFGPTJForSequenceClassification",
+ "TFGPTJModel",
+ "TFGPTJPreTrainedModel",
+ ]
+ )
+ _import_structure["models.groupvit"].extend(
+ [
+ "TFGroupViTModel",
+ "TFGroupViTPreTrainedModel",
+ "TFGroupViTTextModel",
+ "TFGroupViTVisionModel",
+ ]
+ )
+ _import_structure["models.hubert"].extend(
+ [
+ "TFHubertForCTC",
+ "TFHubertModel",
+ "TFHubertPreTrainedModel",
+ ]
+ )
+
+ _import_structure["models.idefics"].extend(
+ [
+ "TFIdeficsForVisionText2Text",
+ "TFIdeficsModel",
+ "TFIdeficsPreTrainedModel",
+ ]
+ )
+
+ _import_structure["models.layoutlm"].extend(
+ [
+ "TFLayoutLMForMaskedLM",
+ "TFLayoutLMForQuestionAnswering",
+ "TFLayoutLMForSequenceClassification",
+ "TFLayoutLMForTokenClassification",
+ "TFLayoutLMMainLayer",
+ "TFLayoutLMModel",
+ "TFLayoutLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.layoutlmv3"].extend(
+ [
+ "TFLayoutLMv3ForQuestionAnswering",
+ "TFLayoutLMv3ForSequenceClassification",
+ "TFLayoutLMv3ForTokenClassification",
+ "TFLayoutLMv3Model",
+ "TFLayoutLMv3PreTrainedModel",
+ ]
+ )
+ _import_structure["models.led"].extend(["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"])
+ _import_structure["models.longformer"].extend(
+ [
+ "TFLongformerForMaskedLM",
+ "TFLongformerForMultipleChoice",
+ "TFLongformerForQuestionAnswering",
+ "TFLongformerForSequenceClassification",
+ "TFLongformerForTokenClassification",
+ "TFLongformerModel",
+ "TFLongformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.lxmert"].extend(
+ [
+ "TFLxmertForPreTraining",
+ "TFLxmertMainLayer",
+ "TFLxmertModel",
+ "TFLxmertPreTrainedModel",
+ "TFLxmertVisualFeatureEncoder",
+ ]
+ )
+ _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
+ _import_structure["models.mbart"].extend(
+ ["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
+ )
+ _import_structure["models.mistral"].extend(
+ ["TFMistralForCausalLM", "TFMistralForSequenceClassification", "TFMistralModel", "TFMistralPreTrainedModel"]
+ )
+ _import_structure["models.mobilebert"].extend(
+ [
+ "TFMobileBertForMaskedLM",
+ "TFMobileBertForMultipleChoice",
+ "TFMobileBertForNextSentencePrediction",
+ "TFMobileBertForPreTraining",
+ "TFMobileBertForQuestionAnswering",
+ "TFMobileBertForSequenceClassification",
+ "TFMobileBertForTokenClassification",
+ "TFMobileBertMainLayer",
+ "TFMobileBertModel",
+ "TFMobileBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mobilevit"].extend(
+ [
+ "TFMobileViTForImageClassification",
+ "TFMobileViTForSemanticSegmentation",
+ "TFMobileViTModel",
+ "TFMobileViTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mpnet"].extend(
+ [
+ "TFMPNetForMaskedLM",
+ "TFMPNetForMultipleChoice",
+ "TFMPNetForQuestionAnswering",
+ "TFMPNetForSequenceClassification",
+ "TFMPNetForTokenClassification",
+ "TFMPNetMainLayer",
+ "TFMPNetModel",
+ "TFMPNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mt5"].extend(["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"])
+ _import_structure["models.openai"].extend(
+ [
+ "TFOpenAIGPTDoubleHeadsModel",
+ "TFOpenAIGPTForSequenceClassification",
+ "TFOpenAIGPTLMHeadModel",
+ "TFOpenAIGPTMainLayer",
+ "TFOpenAIGPTModel",
+ "TFOpenAIGPTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.opt"].extend(
+ [
+ "TFOPTForCausalLM",
+ "TFOPTModel",
+ "TFOPTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pegasus"].extend(
+ [
+ "TFPegasusForConditionalGeneration",
+ "TFPegasusModel",
+ "TFPegasusPreTrainedModel",
+ ]
+ )
+ _import_structure["models.rag"].extend(
+ [
+ "TFRagModel",
+ "TFRagPreTrainedModel",
+ "TFRagSequenceForGeneration",
+ "TFRagTokenForGeneration",
+ ]
+ )
+ _import_structure["models.regnet"].extend(
+ [
+ "TFRegNetForImageClassification",
+ "TFRegNetModel",
+ "TFRegNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.rembert"].extend(
+ [
+ "TFRemBertForCausalLM",
+ "TFRemBertForMaskedLM",
+ "TFRemBertForMultipleChoice",
+ "TFRemBertForQuestionAnswering",
+ "TFRemBertForSequenceClassification",
+ "TFRemBertForTokenClassification",
+ "TFRemBertModel",
+ "TFRemBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.resnet"].extend(
+ [
+ "TFResNetForImageClassification",
+ "TFResNetModel",
+ "TFResNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roberta"].extend(
+ [
+ "TFRobertaForCausalLM",
+ "TFRobertaForMaskedLM",
+ "TFRobertaForMultipleChoice",
+ "TFRobertaForQuestionAnswering",
+ "TFRobertaForSequenceClassification",
+ "TFRobertaForTokenClassification",
+ "TFRobertaMainLayer",
+ "TFRobertaModel",
+ "TFRobertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roberta_prelayernorm"].extend(
+ [
+ "TFRobertaPreLayerNormForCausalLM",
+ "TFRobertaPreLayerNormForMaskedLM",
+ "TFRobertaPreLayerNormForMultipleChoice",
+ "TFRobertaPreLayerNormForQuestionAnswering",
+ "TFRobertaPreLayerNormForSequenceClassification",
+ "TFRobertaPreLayerNormForTokenClassification",
+ "TFRobertaPreLayerNormMainLayer",
+ "TFRobertaPreLayerNormModel",
+ "TFRobertaPreLayerNormPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roformer"].extend(
+ [
+ "TFRoFormerForCausalLM",
+ "TFRoFormerForMaskedLM",
+ "TFRoFormerForMultipleChoice",
+ "TFRoFormerForQuestionAnswering",
+ "TFRoFormerForSequenceClassification",
+ "TFRoFormerForTokenClassification",
+ "TFRoFormerModel",
+ "TFRoFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.sam"].extend(
+ [
+ "TFSamModel",
+ "TFSamPreTrainedModel",
+ ]
+ )
+ _import_structure["models.segformer"].extend(
+ [
+ "TFSegformerDecodeHead",
+ "TFSegformerForImageClassification",
+ "TFSegformerForSemanticSegmentation",
+ "TFSegformerModel",
+ "TFSegformerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.speech_to_text"].extend(
+ [
+ "TFSpeech2TextForConditionalGeneration",
+ "TFSpeech2TextModel",
+ "TFSpeech2TextPreTrainedModel",
+ ]
+ )
+ _import_structure["models.swiftformer"].extend(
+ [
+ "TFSwiftFormerForImageClassification",
+ "TFSwiftFormerModel",
+ "TFSwiftFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.swin"].extend(
+ [
+ "TFSwinForImageClassification",
+ "TFSwinForMaskedImageModeling",
+ "TFSwinModel",
+ "TFSwinPreTrainedModel",
+ ]
+ )
+ _import_structure["models.t5"].extend(
+ [
+ "TFT5EncoderModel",
+ "TFT5ForConditionalGeneration",
+ "TFT5Model",
+ "TFT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.tapas"].extend(
+ [
+ "TFTapasForMaskedLM",
+ "TFTapasForQuestionAnswering",
+ "TFTapasForSequenceClassification",
+ "TFTapasModel",
+ "TFTapasPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"])
+ _import_structure["models.vision_text_dual_encoder"].extend(["TFVisionTextDualEncoderModel"])
+ _import_structure["models.vit"].extend(
+ [
+ "TFViTForImageClassification",
+ "TFViTModel",
+ "TFViTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.vit_mae"].extend(
+ [
+ "TFViTMAEForPreTraining",
+ "TFViTMAEModel",
+ "TFViTMAEPreTrainedModel",
+ ]
+ )
+ _import_structure["models.wav2vec2"].extend(
+ [
+ "TFWav2Vec2ForCTC",
+ "TFWav2Vec2ForSequenceClassification",
+ "TFWav2Vec2Model",
+ "TFWav2Vec2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.whisper"].extend(
+ [
+ "TFWhisperForConditionalGeneration",
+ "TFWhisperModel",
+ "TFWhisperPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xglm"].extend(
+ [
+ "TFXGLMForCausalLM",
+ "TFXGLMModel",
+ "TFXGLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xlm"].extend(
+ [
+ "TFXLMForMultipleChoice",
+ "TFXLMForQuestionAnsweringSimple",
+ "TFXLMForSequenceClassification",
+ "TFXLMForTokenClassification",
+ "TFXLMMainLayer",
+ "TFXLMModel",
+ "TFXLMPreTrainedModel",
+ "TFXLMWithLMHeadModel",
+ ]
+ )
+ _import_structure["models.xlm_roberta"].extend(
+ [
+ "TFXLMRobertaForCausalLM",
+ "TFXLMRobertaForMaskedLM",
+ "TFXLMRobertaForMultipleChoice",
+ "TFXLMRobertaForQuestionAnswering",
+ "TFXLMRobertaForSequenceClassification",
+ "TFXLMRobertaForTokenClassification",
+ "TFXLMRobertaModel",
+ "TFXLMRobertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xlnet"].extend(
+ [
+ "TFXLNetForMultipleChoice",
+ "TFXLNetForQuestionAnsweringSimple",
+ "TFXLNetForSequenceClassification",
+ "TFXLNetForTokenClassification",
+ "TFXLNetLMHeadModel",
+ "TFXLNetMainLayer",
+ "TFXLNetModel",
+ "TFXLNetPreTrainedModel",
+ ]
+ )
+ _import_structure["optimization_tf"] = [
+ "AdamWeightDecay",
+ "GradientAccumulator",
+ "WarmUp",
+ "create_optimizer",
+ ]
+ _import_structure["tf_utils"] = []
+
+
+try:
+ if not (
+ is_librosa_available()
+ and is_essentia_available()
+ and is_scipy_available()
+ and is_torch_available()
+ and is_pretty_midi_available()
+ ):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import (
+ dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects,
+ )
+
+ _import_structure["utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects"] = [
+ name
+ for name in dir(dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects)
+ if not name.startswith("_")
+ ]
+else:
+ _import_structure["models.pop2piano"].append("Pop2PianoFeatureExtractor")
+ _import_structure["models.pop2piano"].append("Pop2PianoTokenizer")
+ _import_structure["models.pop2piano"].append("Pop2PianoProcessor")
+
+try:
+ if not is_torchaudio_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import (
+ dummy_torchaudio_objects,
+ )
+
+ _import_structure["utils.dummy_torchaudio_objects"] = [
+ name for name in dir(dummy_torchaudio_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["models.musicgen_melody"].append("MusicgenMelodyFeatureExtractor")
+ _import_structure["models.musicgen_melody"].append("MusicgenMelodyProcessor")
+
+
+# FLAX-backed objects
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_flax_objects
+
+ _import_structure["utils.dummy_flax_objects"] = [
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["generation"].extend(
+ [
+ "FlaxForcedBOSTokenLogitsProcessor",
+ "FlaxForcedEOSTokenLogitsProcessor",
+ "FlaxForceTokensLogitsProcessor",
+ "FlaxGenerationMixin",
+ "FlaxLogitsProcessor",
+ "FlaxLogitsProcessorList",
+ "FlaxLogitsWarper",
+ "FlaxMinLengthLogitsProcessor",
+ "FlaxTemperatureLogitsWarper",
+ "FlaxSuppressTokensAtBeginLogitsProcessor",
+ "FlaxSuppressTokensLogitsProcessor",
+ "FlaxTopKLogitsWarper",
+ "FlaxTopPLogitsWarper",
+ "FlaxWhisperTimeStampLogitsProcessor",
+ ]
+ )
+ _import_structure["modeling_flax_outputs"] = []
+ _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
+ _import_structure["models.albert"].extend(
+ [
+ "FlaxAlbertForMaskedLM",
+ "FlaxAlbertForMultipleChoice",
+ "FlaxAlbertForPreTraining",
+ "FlaxAlbertForQuestionAnswering",
+ "FlaxAlbertForSequenceClassification",
+ "FlaxAlbertForTokenClassification",
+ "FlaxAlbertModel",
+ "FlaxAlbertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.auto"].extend(
+ [
+ "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
+ "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
+ "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
+ "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
+ "FLAX_MODEL_FOR_PRETRAINING_MAPPING",
+ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+ "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
+ "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
+ "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
+ "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING",
+ "FLAX_MODEL_MAPPING",
+ "FlaxAutoModel",
+ "FlaxAutoModelForCausalLM",
+ "FlaxAutoModelForImageClassification",
+ "FlaxAutoModelForMaskedLM",
+ "FlaxAutoModelForMultipleChoice",
+ "FlaxAutoModelForNextSentencePrediction",
+ "FlaxAutoModelForPreTraining",
+ "FlaxAutoModelForQuestionAnswering",
+ "FlaxAutoModelForSeq2SeqLM",
+ "FlaxAutoModelForSequenceClassification",
+ "FlaxAutoModelForSpeechSeq2Seq",
+ "FlaxAutoModelForTokenClassification",
+ "FlaxAutoModelForVision2Seq",
+ ]
+ )
+
+ # Flax models structure
+
+ _import_structure["models.bart"].extend(
+ [
+ "FlaxBartDecoderPreTrainedModel",
+ "FlaxBartForCausalLM",
+ "FlaxBartForConditionalGeneration",
+ "FlaxBartForQuestionAnswering",
+ "FlaxBartForSequenceClassification",
+ "FlaxBartModel",
+ "FlaxBartPreTrainedModel",
+ ]
+ )
+ _import_structure["models.beit"].extend(
+ [
+ "FlaxBeitForImageClassification",
+ "FlaxBeitForMaskedImageModeling",
+ "FlaxBeitModel",
+ "FlaxBeitPreTrainedModel",
+ ]
+ )
+
+ _import_structure["models.bert"].extend(
+ [
+ "FlaxBertForCausalLM",
+ "FlaxBertForMaskedLM",
+ "FlaxBertForMultipleChoice",
+ "FlaxBertForNextSentencePrediction",
+ "FlaxBertForPreTraining",
+ "FlaxBertForQuestionAnswering",
+ "FlaxBertForSequenceClassification",
+ "FlaxBertForTokenClassification",
+ "FlaxBertModel",
+ "FlaxBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.big_bird"].extend(
+ [
+ "FlaxBigBirdForCausalLM",
+ "FlaxBigBirdForMaskedLM",
+ "FlaxBigBirdForMultipleChoice",
+ "FlaxBigBirdForPreTraining",
+ "FlaxBigBirdForQuestionAnswering",
+ "FlaxBigBirdForSequenceClassification",
+ "FlaxBigBirdForTokenClassification",
+ "FlaxBigBirdModel",
+ "FlaxBigBirdPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blenderbot"].extend(
+ [
+ "FlaxBlenderbotForConditionalGeneration",
+ "FlaxBlenderbotModel",
+ "FlaxBlenderbotPreTrainedModel",
+ ]
+ )
+ _import_structure["models.blenderbot_small"].extend(
+ [
+ "FlaxBlenderbotSmallForConditionalGeneration",
+ "FlaxBlenderbotSmallModel",
+ "FlaxBlenderbotSmallPreTrainedModel",
+ ]
+ )
+ _import_structure["models.bloom"].extend(
+ [
+ "FlaxBloomForCausalLM",
+ "FlaxBloomModel",
+ "FlaxBloomPreTrainedModel",
+ ]
+ )
+ _import_structure["models.clip"].extend(
+ [
+ "FlaxCLIPModel",
+ "FlaxCLIPPreTrainedModel",
+ "FlaxCLIPTextModel",
+ "FlaxCLIPTextPreTrainedModel",
+ "FlaxCLIPTextModelWithProjection",
+ "FlaxCLIPVisionModel",
+ "FlaxCLIPVisionPreTrainedModel",
+ ]
+ )
+ _import_structure["models.dinov2"].extend(
+ [
+ "FlaxDinov2Model",
+ "FlaxDinov2ForImageClassification",
+ "FlaxDinov2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.distilbert"].extend(
+ [
+ "FlaxDistilBertForMaskedLM",
+ "FlaxDistilBertForMultipleChoice",
+ "FlaxDistilBertForQuestionAnswering",
+ "FlaxDistilBertForSequenceClassification",
+ "FlaxDistilBertForTokenClassification",
+ "FlaxDistilBertModel",
+ "FlaxDistilBertPreTrainedModel",
+ ]
+ )
+ _import_structure["models.electra"].extend(
+ [
+ "FlaxElectraForCausalLM",
+ "FlaxElectraForMaskedLM",
+ "FlaxElectraForMultipleChoice",
+ "FlaxElectraForPreTraining",
+ "FlaxElectraForQuestionAnswering",
+ "FlaxElectraForSequenceClassification",
+ "FlaxElectraForTokenClassification",
+ "FlaxElectraModel",
+ "FlaxElectraPreTrainedModel",
+ ]
+ )
+ _import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel")
+ _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
+ _import_structure["models.gpt_neo"].extend(
+ ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
+ )
+ _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
+ _import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"])
+ _import_structure["models.gemma"].extend(["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"])
+ _import_structure["models.longt5"].extend(
+ [
+ "FlaxLongT5ForConditionalGeneration",
+ "FlaxLongT5Model",
+ "FlaxLongT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.marian"].extend(
+ [
+ "FlaxMarianModel",
+ "FlaxMarianMTModel",
+ "FlaxMarianPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mbart"].extend(
+ [
+ "FlaxMBartForConditionalGeneration",
+ "FlaxMBartForQuestionAnswering",
+ "FlaxMBartForSequenceClassification",
+ "FlaxMBartModel",
+ "FlaxMBartPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mistral"].extend(
+ [
+ "FlaxMistralForCausalLM",
+ "FlaxMistralModel",
+ "FlaxMistralPreTrainedModel",
+ ]
+ )
+ _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+ _import_structure["models.opt"].extend(
+ [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+ )
+ _import_structure["models.pegasus"].extend(
+ [
+ "FlaxPegasusForConditionalGeneration",
+ "FlaxPegasusModel",
+ "FlaxPegasusPreTrainedModel",
+ ]
+ )
+ _import_structure["models.regnet"].extend(
+ [
+ "FlaxRegNetForImageClassification",
+ "FlaxRegNetModel",
+ "FlaxRegNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.resnet"].extend(
+ [
+ "FlaxResNetForImageClassification",
+ "FlaxResNetModel",
+ "FlaxResNetPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roberta"].extend(
+ [
+ "FlaxRobertaForCausalLM",
+ "FlaxRobertaForMaskedLM",
+ "FlaxRobertaForMultipleChoice",
+ "FlaxRobertaForQuestionAnswering",
+ "FlaxRobertaForSequenceClassification",
+ "FlaxRobertaForTokenClassification",
+ "FlaxRobertaModel",
+ "FlaxRobertaPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roberta_prelayernorm"].extend(
+ [
+ "FlaxRobertaPreLayerNormForCausalLM",
+ "FlaxRobertaPreLayerNormForMaskedLM",
+ "FlaxRobertaPreLayerNormForMultipleChoice",
+ "FlaxRobertaPreLayerNormForQuestionAnswering",
+ "FlaxRobertaPreLayerNormForSequenceClassification",
+ "FlaxRobertaPreLayerNormForTokenClassification",
+ "FlaxRobertaPreLayerNormModel",
+ "FlaxRobertaPreLayerNormPreTrainedModel",
+ ]
+ )
+ _import_structure["models.roformer"].extend(
+ [
+ "FlaxRoFormerForMaskedLM",
+ "FlaxRoFormerForMultipleChoice",
+ "FlaxRoFormerForQuestionAnswering",
+ "FlaxRoFormerForSequenceClassification",
+ "FlaxRoFormerForTokenClassification",
+ "FlaxRoFormerModel",
+ "FlaxRoFormerPreTrainedModel",
+ ]
+ )
+ _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel")
+ _import_structure["models.t5"].extend(
+ [
+ "FlaxT5EncoderModel",
+ "FlaxT5ForConditionalGeneration",
+ "FlaxT5Model",
+ "FlaxT5PreTrainedModel",
+ ]
+ )
+ _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
+ _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
+ _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
+ _import_structure["models.wav2vec2"].extend(
+ [
+ "FlaxWav2Vec2ForCTC",
+ "FlaxWav2Vec2ForPreTraining",
+ "FlaxWav2Vec2Model",
+ "FlaxWav2Vec2PreTrainedModel",
+ ]
+ )
+ _import_structure["models.whisper"].extend(
+ [
+ "FlaxWhisperForConditionalGeneration",
+ "FlaxWhisperModel",
+ "FlaxWhisperPreTrainedModel",
+ "FlaxWhisperForAudioClassification",
+ ]
+ )
+ _import_structure["models.xglm"].extend(
+ [
+ "FlaxXGLMForCausalLM",
+ "FlaxXGLMModel",
+ "FlaxXGLMPreTrainedModel",
+ ]
+ )
+ _import_structure["models.xlm_roberta"].extend(
+ [
+ "FlaxXLMRobertaForMaskedLM",
+ "FlaxXLMRobertaForMultipleChoice",
+ "FlaxXLMRobertaForQuestionAnswering",
+ "FlaxXLMRobertaForSequenceClassification",
+ "FlaxXLMRobertaForTokenClassification",
+ "FlaxXLMRobertaModel",
+ "FlaxXLMRobertaForCausalLM",
+ "FlaxXLMRobertaPreTrainedModel",
+ ]
+ )
+
+
+# Direct imports for type-checking
+if TYPE_CHECKING:
+ # Configuration
+ # Agents
+ from .agents import (
+ Agent,
+ CodeAgent,
+ HfApiEngine,
+ ManagedAgent,
+ PipelineTool,
+ ReactAgent,
+ ReactCodeAgent,
+ ReactJsonAgent,
+ Tool,
+ Toolbox,
+ ToolCollection,
+ TransformersEngine,
+ launch_gradio_demo,
+ load_tool,
+ stream_to_gradio,
+ tool,
+ )
+ from .configuration_utils import PretrainedConfig
+
+ # Data
+ from .data import (
+ DataProcessor,
+ InputExample,
+ InputFeatures,
+ SingleSentenceClassificationProcessor,
+ SquadExample,
+ SquadFeatures,
+ SquadV1Processor,
+ SquadV2Processor,
+ glue_compute_metrics,
+ glue_convert_examples_to_features,
+ glue_output_modes,
+ glue_processors,
+ glue_tasks_num_labels,
+ squad_convert_examples_to_features,
+ xnli_compute_metrics,
+ xnli_output_modes,
+ xnli_processors,
+ xnli_tasks_num_labels,
+ )
+ from .data.data_collator import (
+ DataCollator,
+ DataCollatorForLanguageModeling,
+ DataCollatorForPermutationLanguageModeling,
+ DataCollatorForSeq2Seq,
+ DataCollatorForSOP,
+ DataCollatorForTokenClassification,
+ DataCollatorForWholeWordMask,
+ DataCollatorWithFlattening,
+ DataCollatorWithPadding,
+ DefaultDataCollator,
+ default_data_collator,
+ )
+ from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+
+ # Feature Extractor
+ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+
+ # Generation
+ from .generation import (
+ AsyncTextIteratorStreamer,
+ CompileConfig,
+ GenerationConfig,
+ TextIteratorStreamer,
+ TextStreamer,
+ WatermarkingConfig,
+ )
+ from .hf_argparser import HfArgumentParser
+
+ # Integrations
+ from .integrations import (
+ is_clearml_available,
+ is_comet_available,
+ is_dvclive_available,
+ is_neptune_available,
+ is_optuna_available,
+ is_ray_available,
+ is_ray_tune_available,
+ is_sigopt_available,
+ is_tensorboard_available,
+ is_wandb_available,
+ )
+
+ # Model Cards
+ from .modelcard import ModelCard
+
+ # TF 2.0 <=> PyTorch conversion utilities
+ from .modeling_tf_pytorch_utils import (
+ convert_tf_weight_name_to_pt_weight_name,
+ load_pytorch_checkpoint_in_tf2_model,
+ load_pytorch_model_in_tf2_model,
+ load_pytorch_weights_in_tf2_model,
+ load_tf2_checkpoint_in_pytorch_model,
+ load_tf2_model_in_pytorch_model,
+ load_tf2_weights_in_pytorch_model,
+ )
+ from .models.albert import AlbertConfig
+ from .models.align import (
+ AlignConfig,
+ AlignProcessor,
+ AlignTextConfig,
+ AlignVisionConfig,
+ )
+ from .models.altclip import (
+ AltCLIPConfig,
+ AltCLIPProcessor,
+ AltCLIPTextConfig,
+ AltCLIPVisionConfig,
+ )
+ from .models.aria import (
+ AriaConfig,
+ AriaProcessor,
+ AriaTextConfig,
+ )
+ from .models.audio_spectrogram_transformer import (
+ ASTConfig,
+ ASTFeatureExtractor,
+ )
+ from .models.auto import (
+ CONFIG_MAPPING,
+ FEATURE_EXTRACTOR_MAPPING,
+ IMAGE_PROCESSOR_MAPPING,
+ MODEL_NAMES_MAPPING,
+ PROCESSOR_MAPPING,
+ TOKENIZER_MAPPING,
+ AutoConfig,
+ AutoFeatureExtractor,
+ AutoImageProcessor,
+ AutoProcessor,
+ AutoTokenizer,
+ )
+ from .models.autoformer import (
+ AutoformerConfig,
+ )
+ from .models.bamba import BambaConfig
+ from .models.bark import (
+ BarkCoarseConfig,
+ BarkConfig,
+ BarkFineConfig,
+ BarkProcessor,
+ BarkSemanticConfig,
+ )
+ from .models.bart import BartConfig, BartTokenizer
+ from .models.beit import BeitConfig
+ from .models.bert import (
+ BasicTokenizer,
+ BertConfig,
+ BertTokenizer,
+ WordpieceTokenizer,
+ )
+ from .models.bert_generation import BertGenerationConfig
+ from .models.bert_japanese import (
+ BertJapaneseTokenizer,
+ CharacterTokenizer,
+ MecabTokenizer,
+ )
+ from .models.bertweet import BertweetTokenizer
+ from .models.big_bird import BigBirdConfig
+ from .models.bigbird_pegasus import (
+ BigBirdPegasusConfig,
+ )
+ from .models.biogpt import (
+ BioGptConfig,
+ BioGptTokenizer,
+ )
+ from .models.bit import BitConfig
+ from .models.blenderbot import (
+ BlenderbotConfig,
+ BlenderbotTokenizer,
+ )
+ from .models.blenderbot_small import (
+ BlenderbotSmallConfig,
+ BlenderbotSmallTokenizer,
+ )
+ from .models.blip import (
+ BlipConfig,
+ BlipProcessor,
+ BlipTextConfig,
+ BlipVisionConfig,
+ )
+ from .models.blip_2 import (
+ Blip2Config,
+ Blip2Processor,
+ Blip2QFormerConfig,
+ Blip2VisionConfig,
+ )
+ from .models.bloom import BloomConfig
+ from .models.bridgetower import (
+ BridgeTowerConfig,
+ BridgeTowerProcessor,
+ BridgeTowerTextConfig,
+ BridgeTowerVisionConfig,
+ )
+ from .models.bros import (
+ BrosConfig,
+ BrosProcessor,
+ )
+ from .models.byt5 import ByT5Tokenizer
+ from .models.camembert import (
+ CamembertConfig,
+ )
+ from .models.canine import (
+ CanineConfig,
+ CanineTokenizer,
+ )
+ from .models.chameleon import (
+ ChameleonConfig,
+ ChameleonProcessor,
+ ChameleonVQVAEConfig,
+ )
+ from .models.chinese_clip import (
+ ChineseCLIPConfig,
+ ChineseCLIPProcessor,
+ ChineseCLIPTextConfig,
+ ChineseCLIPVisionConfig,
+ )
+ from .models.clap import (
+ ClapAudioConfig,
+ ClapConfig,
+ ClapProcessor,
+ ClapTextConfig,
+ )
+ from .models.clip import (
+ CLIPConfig,
+ CLIPProcessor,
+ CLIPTextConfig,
+ CLIPTokenizer,
+ CLIPVisionConfig,
+ )
+ from .models.clipseg import (
+ CLIPSegConfig,
+ CLIPSegProcessor,
+ CLIPSegTextConfig,
+ CLIPSegVisionConfig,
+ )
+ from .models.clvp import (
+ ClvpConfig,
+ ClvpDecoderConfig,
+ ClvpEncoderConfig,
+ ClvpFeatureExtractor,
+ ClvpProcessor,
+ ClvpTokenizer,
+ )
+ from .models.codegen import (
+ CodeGenConfig,
+ CodeGenTokenizer,
+ )
+ from .models.cohere import CohereConfig
+ from .models.cohere2 import Cohere2Config
+ from .models.colpali import (
+ ColPaliConfig,
+ ColPaliProcessor,
+ )
+ from .models.conditional_detr import (
+ ConditionalDetrConfig,
+ )
+ from .models.convbert import (
+ ConvBertConfig,
+ ConvBertTokenizer,
+ )
+ from .models.convnext import ConvNextConfig
+ from .models.convnextv2 import (
+ ConvNextV2Config,
+ )
+ from .models.cpmant import (
+ CpmAntConfig,
+ CpmAntTokenizer,
+ )
+ from .models.ctrl import (
+ CTRLConfig,
+ CTRLTokenizer,
+ )
+ from .models.cvt import CvtConfig
+ from .models.dac import (
+ DacConfig,
+ DacFeatureExtractor,
+ )
+ from .models.data2vec import (
+ Data2VecAudioConfig,
+ Data2VecTextConfig,
+ Data2VecVisionConfig,
+ )
+ from .models.dbrx import DbrxConfig
+ from .models.deberta import (
+ DebertaConfig,
+ DebertaTokenizer,
+ )
+ from .models.deberta_v2 import (
+ DebertaV2Config,
+ )
+ from .models.decision_transformer import (
+ DecisionTransformerConfig,
+ )
+ from .models.deformable_detr import (
+ DeformableDetrConfig,
+ )
+ from .models.deit import DeiTConfig
+ from .models.deprecated.deta import DetaConfig
+ from .models.deprecated.efficientformer import (
+ EfficientFormerConfig,
+ )
+ from .models.deprecated.ernie_m import ErnieMConfig
+ from .models.deprecated.gptsan_japanese import (
+ GPTSanJapaneseConfig,
+ GPTSanJapaneseTokenizer,
+ )
+ from .models.deprecated.graphormer import GraphormerConfig
+ from .models.deprecated.jukebox import (
+ JukeboxConfig,
+ JukeboxPriorConfig,
+ JukeboxTokenizer,
+ JukeboxVQVAEConfig,
+ )
+ from .models.deprecated.mctct import (
+ MCTCTConfig,
+ MCTCTFeatureExtractor,
+ MCTCTProcessor,
+ )
+ from .models.deprecated.mega import MegaConfig
+ from .models.deprecated.mmbt import MMBTConfig
+ from .models.deprecated.nat import NatConfig
+ from .models.deprecated.nezha import NezhaConfig
+ from .models.deprecated.open_llama import (
+ OpenLlamaConfig,
+ )
+ from .models.deprecated.qdqbert import QDQBertConfig
+ from .models.deprecated.realm import (
+ RealmConfig,
+ RealmTokenizer,
+ )
+ from .models.deprecated.retribert import (
+ RetriBertConfig,
+ RetriBertTokenizer,
+ )
+ from .models.deprecated.speech_to_text_2 import (
+ Speech2Text2Config,
+ Speech2Text2Processor,
+ Speech2Text2Tokenizer,
+ )
+ from .models.deprecated.tapex import TapexTokenizer
+ from .models.deprecated.trajectory_transformer import (
+ TrajectoryTransformerConfig,
+ )
+ from .models.deprecated.transfo_xl import (
+ TransfoXLConfig,
+ TransfoXLCorpus,
+ TransfoXLTokenizer,
+ )
+ from .models.deprecated.tvlt import (
+ TvltConfig,
+ TvltFeatureExtractor,
+ TvltProcessor,
+ )
+ from .models.deprecated.van import VanConfig
+ from .models.deprecated.vit_hybrid import (
+ ViTHybridConfig,
+ )
+ from .models.deprecated.xlm_prophetnet import (
+ XLMProphetNetConfig,
+ )
+ from .models.depth_anything import DepthAnythingConfig
+ from .models.detr import DetrConfig
+ from .models.diffllama import DiffLlamaConfig
+ from .models.dinat import DinatConfig
+ from .models.dinov2 import Dinov2Config
+ from .models.dinov2_with_registers import Dinov2WithRegistersConfig
+ from .models.distilbert import (
+ DistilBertConfig,
+ DistilBertTokenizer,
+ )
+ from .models.donut import (
+ DonutProcessor,
+ DonutSwinConfig,
+ )
+ from .models.dpr import (
+ DPRConfig,
+ DPRContextEncoderTokenizer,
+ DPRQuestionEncoderTokenizer,
+ DPRReaderOutput,
+ DPRReaderTokenizer,
+ )
+ from .models.dpt import DPTConfig
+ from .models.efficientnet import (
+ EfficientNetConfig,
+ )
+ from .models.electra import (
+ ElectraConfig,
+ ElectraTokenizer,
+ )
+ from .models.emu3 import (
+ Emu3Config,
+ Emu3Processor,
+ Emu3TextConfig,
+ Emu3VQVAEConfig,
+ )
+ from .models.encodec import (
+ EncodecConfig,
+ EncodecFeatureExtractor,
+ )
+ from .models.encoder_decoder import EncoderDecoderConfig
+ from .models.ernie import ErnieConfig
+ from .models.esm import EsmConfig, EsmTokenizer
+ from .models.falcon import FalconConfig
+ from .models.falcon_mamba import FalconMambaConfig
+ from .models.fastspeech2_conformer import (
+ FastSpeech2ConformerConfig,
+ FastSpeech2ConformerHifiGanConfig,
+ FastSpeech2ConformerTokenizer,
+ FastSpeech2ConformerWithHifiGanConfig,
+ )
+ from .models.flaubert import FlaubertConfig, FlaubertTokenizer
+ from .models.flava import (
+ FlavaConfig,
+ FlavaImageCodebookConfig,
+ FlavaImageConfig,
+ FlavaMultimodalConfig,
+ FlavaTextConfig,
+ )
+ from .models.fnet import FNetConfig
+ from .models.focalnet import FocalNetConfig
+ from .models.fsmt import (
+ FSMTConfig,
+ FSMTTokenizer,
+ )
+ from .models.funnel import (
+ FunnelConfig,
+ FunnelTokenizer,
+ )
+ from .models.fuyu import FuyuConfig
+ from .models.gemma import GemmaConfig
+ from .models.gemma2 import Gemma2Config
+ from .models.git import (
+ GitConfig,
+ GitProcessor,
+ GitVisionConfig,
+ )
+ from .models.glm import GlmConfig
+ from .models.glpn import GLPNConfig
+ from .models.gpt2 import (
+ GPT2Config,
+ GPT2Tokenizer,
+ )
+ from .models.gpt_bigcode import (
+ GPTBigCodeConfig,
+ )
+ from .models.gpt_neo import GPTNeoConfig
+ from .models.gpt_neox import GPTNeoXConfig
+ from .models.gpt_neox_japanese import (
+ GPTNeoXJapaneseConfig,
+ )
+ from .models.gptj import GPTJConfig
+ from .models.granite import GraniteConfig
+ from .models.granitemoe import GraniteMoeConfig
+ from .models.grounding_dino import (
+ GroundingDinoConfig,
+ GroundingDinoProcessor,
+ )
+ from .models.groupvit import (
+ GroupViTConfig,
+ GroupViTTextConfig,
+ GroupViTVisionConfig,
+ )
+ from .models.herbert import HerbertTokenizer
+ from .models.hiera import HieraConfig
+ from .models.hubert import HubertConfig
+ from .models.ibert import IBertConfig
+ from .models.idefics import (
+ IdeficsConfig,
+ )
+ from .models.idefics2 import Idefics2Config
+ from .models.idefics3 import Idefics3Config
+ from .models.ijepa import IJepaConfig
+ from .models.imagegpt import ImageGPTConfig
+ from .models.informer import InformerConfig
+ from .models.instructblip import (
+ InstructBlipConfig,
+ InstructBlipProcessor,
+ InstructBlipQFormerConfig,
+ InstructBlipVisionConfig,
+ )
+ from .models.instructblipvideo import (
+ InstructBlipVideoConfig,
+ InstructBlipVideoProcessor,
+ InstructBlipVideoQFormerConfig,
+ InstructBlipVideoVisionConfig,
+ )
+ from .models.jamba import JambaConfig
+ from .models.jetmoe import JetMoeConfig
+ from .models.kosmos2 import (
+ Kosmos2Config,
+ Kosmos2Processor,
+ )
+ from .models.layoutlm import (
+ LayoutLMConfig,
+ LayoutLMTokenizer,
+ )
+ from .models.layoutlmv2 import (
+ LayoutLMv2Config,
+ LayoutLMv2FeatureExtractor,
+ LayoutLMv2ImageProcessor,
+ LayoutLMv2Processor,
+ LayoutLMv2Tokenizer,
+ )
+ from .models.layoutlmv3 import (
+ LayoutLMv3Config,
+ LayoutLMv3FeatureExtractor,
+ LayoutLMv3ImageProcessor,
+ LayoutLMv3Processor,
+ LayoutLMv3Tokenizer,
+ )
+ from .models.layoutxlm import LayoutXLMProcessor
+ from .models.led import LEDConfig, LEDTokenizer
+ from .models.levit import LevitConfig
+ from .models.lilt import LiltConfig
+ from .models.llama import LlamaConfig
+ from .models.llava import (
+ LlavaConfig,
+ LlavaProcessor,
+ )
+ from .models.llava_next import (
+ LlavaNextConfig,
+ LlavaNextProcessor,
+ )
+ from .models.llava_next_video import (
+ LlavaNextVideoConfig,
+ LlavaNextVideoProcessor,
+ )
+ from .models.llava_onevision import (
+ LlavaOnevisionConfig,
+ LlavaOnevisionProcessor,
+ )
+ from .models.longformer import (
+ LongformerConfig,
+ LongformerTokenizer,
+ )
+ from .models.longt5 import LongT5Config
+ from .models.luke import (
+ LukeConfig,
+ LukeTokenizer,
+ )
+ from .models.lxmert import (
+ LxmertConfig,
+ LxmertTokenizer,
+ )
+ from .models.m2m_100 import M2M100Config
+ from .models.mamba import MambaConfig
+ from .models.mamba2 import Mamba2Config
+ from .models.marian import MarianConfig
+ from .models.markuplm import (
+ MarkupLMConfig,
+ MarkupLMFeatureExtractor,
+ MarkupLMProcessor,
+ MarkupLMTokenizer,
+ )
+ from .models.mask2former import (
+ Mask2FormerConfig,
+ )
+ from .models.maskformer import (
+ MaskFormerConfig,
+ MaskFormerSwinConfig,
+ )
+ from .models.mbart import MBartConfig
+ from .models.megatron_bert import (
+ MegatronBertConfig,
+ )
+ from .models.mgp_str import (
+ MgpstrConfig,
+ MgpstrProcessor,
+ MgpstrTokenizer,
+ )
+ from .models.mimi import (
+ MimiConfig,
+ )
+ from .models.mistral import MistralConfig
+ from .models.mixtral import MixtralConfig
+ from .models.mllama import (
+ MllamaConfig,
+ MllamaProcessor,
+ )
+ from .models.mobilebert import (
+ MobileBertConfig,
+ MobileBertTokenizer,
+ )
+ from .models.mobilenet_v1 import (
+ MobileNetV1Config,
+ )
+ from .models.mobilenet_v2 import (
+ MobileNetV2Config,
+ )
+ from .models.mobilevit import (
+ MobileViTConfig,
+ )
+ from .models.mobilevitv2 import (
+ MobileViTV2Config,
+ )
+ from .models.modernbert import ModernBertConfig
+ from .models.moonshine import MoonshineConfig
+ from .models.moshi import (
+ MoshiConfig,
+ MoshiDepthConfig,
+ )
+ from .models.mpnet import (
+ MPNetConfig,
+ MPNetTokenizer,
+ )
+ from .models.mpt import MptConfig
+ from .models.mra import MraConfig
+ from .models.mt5 import MT5Config
+ from .models.musicgen import (
+ MusicgenConfig,
+ MusicgenDecoderConfig,
+ )
+ from .models.musicgen_melody import (
+ MusicgenMelodyConfig,
+ MusicgenMelodyDecoderConfig,
+ )
+ from .models.mvp import MvpConfig, MvpTokenizer
+ from .models.myt5 import MyT5Tokenizer
+ from .models.nemotron import NemotronConfig
+ from .models.nllb_moe import NllbMoeConfig
+ from .models.nougat import NougatProcessor
+ from .models.nystromformer import (
+ NystromformerConfig,
+ )
+ from .models.olmo import OlmoConfig
+ from .models.olmo2 import Olmo2Config
+ from .models.olmoe import OlmoeConfig
+ from .models.omdet_turbo import (
+ OmDetTurboConfig,
+ OmDetTurboProcessor,
+ )
+ from .models.oneformer import (
+ OneFormerConfig,
+ OneFormerProcessor,
+ )
+ from .models.openai import (
+ OpenAIGPTConfig,
+ OpenAIGPTTokenizer,
+ )
+ from .models.opt import OPTConfig
+ from .models.owlv2 import (
+ Owlv2Config,
+ Owlv2Processor,
+ Owlv2TextConfig,
+ Owlv2VisionConfig,
+ )
+ from .models.owlvit import (
+ OwlViTConfig,
+ OwlViTProcessor,
+ OwlViTTextConfig,
+ OwlViTVisionConfig,
+ )
+ from .models.paligemma import (
+ PaliGemmaConfig,
+ )
+ from .models.patchtsmixer import (
+ PatchTSMixerConfig,
+ )
+ from .models.patchtst import PatchTSTConfig
+ from .models.pegasus import (
+ PegasusConfig,
+ PegasusTokenizer,
+ )
+ from .models.pegasus_x import (
+ PegasusXConfig,
+ )
+ from .models.perceiver import (
+ PerceiverConfig,
+ PerceiverTokenizer,
+ )
+ from .models.persimmon import (
+ PersimmonConfig,
+ )
+ from .models.phi import PhiConfig
+ from .models.phi3 import Phi3Config
+ from .models.phimoe import PhimoeConfig
+ from .models.phobert import PhobertTokenizer
+ from .models.pix2struct import (
+ Pix2StructConfig,
+ Pix2StructProcessor,
+ Pix2StructTextConfig,
+ Pix2StructVisionConfig,
+ )
+ from .models.pixtral import (
+ PixtralProcessor,
+ PixtralVisionConfig,
+ )
+ from .models.plbart import PLBartConfig
+ from .models.poolformer import (
+ PoolFormerConfig,
+ )
+ from .models.pop2piano import (
+ Pop2PianoConfig,
+ )
+ from .models.prophetnet import (
+ ProphetNetConfig,
+ ProphetNetTokenizer,
+ )
+ from .models.pvt import PvtConfig
+ from .models.pvt_v2 import PvtV2Config
+ from .models.qwen2 import Qwen2Config, Qwen2Tokenizer
+ from .models.qwen2_audio import (
+ Qwen2AudioConfig,
+ Qwen2AudioEncoderConfig,
+ Qwen2AudioProcessor,
+ )
+ from .models.qwen2_moe import Qwen2MoeConfig
+ from .models.qwen2_vl import (
+ Qwen2VLConfig,
+ Qwen2VLProcessor,
+ )
+ from .models.rag import RagConfig, RagRetriever, RagTokenizer
+ from .models.recurrent_gemma import RecurrentGemmaConfig
+ from .models.reformer import ReformerConfig
+ from .models.regnet import RegNetConfig
+ from .models.rembert import RemBertConfig
+ from .models.resnet import ResNetConfig
+ from .models.roberta import (
+ RobertaConfig,
+ RobertaTokenizer,
+ )
+ from .models.roberta_prelayernorm import (
+ RobertaPreLayerNormConfig,
+ )
+ from .models.roc_bert import (
+ RoCBertConfig,
+ RoCBertTokenizer,
+ )
+ from .models.roformer import (
+ RoFormerConfig,
+ RoFormerTokenizer,
+ )
+ from .models.rt_detr import (
+ RTDetrConfig,
+ RTDetrResNetConfig,
+ )
+ from .models.rwkv import RwkvConfig
+ from .models.sam import (
+ SamConfig,
+ SamMaskDecoderConfig,
+ SamProcessor,
+ SamPromptEncoderConfig,
+ SamVisionConfig,
+ )
+ from .models.seamless_m4t import (
+ SeamlessM4TConfig,
+ SeamlessM4TFeatureExtractor,
+ SeamlessM4TProcessor,
+ )
+ from .models.seamless_m4t_v2 import (
+ SeamlessM4Tv2Config,
+ )
+ from .models.segformer import SegformerConfig
+ from .models.seggpt import SegGptConfig
+ from .models.sew import SEWConfig
+ from .models.sew_d import SEWDConfig
+ from .models.siglip import (
+ SiglipConfig,
+ SiglipProcessor,
+ SiglipTextConfig,
+ SiglipVisionConfig,
+ )
+ from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
+ from .models.speech_to_text import (
+ Speech2TextConfig,
+ Speech2TextFeatureExtractor,
+ Speech2TextProcessor,
+ )
+ from .models.speecht5 import (
+ SpeechT5Config,
+ SpeechT5FeatureExtractor,
+ SpeechT5HifiGanConfig,
+ SpeechT5Processor,
+ )
+ from .models.splinter import (
+ SplinterConfig,
+ SplinterTokenizer,
+ )
+ from .models.squeezebert import (
+ SqueezeBertConfig,
+ SqueezeBertTokenizer,
+ )
+ from .models.stablelm import StableLmConfig
+ from .models.starcoder2 import Starcoder2Config
+ from .models.superpoint import SuperPointConfig
+ from .models.swiftformer import (
+ SwiftFormerConfig,
+ )
+ from .models.swin import SwinConfig
+ from .models.swin2sr import Swin2SRConfig
+ from .models.swinv2 import Swinv2Config
+ from .models.switch_transformers import (
+ SwitchTransformersConfig,
+ )
+ from .models.t5 import T5Config
+ from .models.table_transformer import (
+ TableTransformerConfig,
+ )
+ from .models.tapas import (
+ TapasConfig,
+ TapasTokenizer,
+ )
+ from .models.textnet import TextNetConfig
+ from .models.time_series_transformer import (
+ TimeSeriesTransformerConfig,
+ )
+ from .models.timesformer import (
+ TimesformerConfig,
+ )
+ from .models.timm_backbone import TimmBackboneConfig
+ from .models.timm_wrapper import TimmWrapperConfig
+ from .models.trocr import (
+ TrOCRConfig,
+ TrOCRProcessor,
+ )
+ from .models.tvp import (
+ TvpConfig,
+ TvpProcessor,
+ )
+ from .models.udop import UdopConfig, UdopProcessor
+ from .models.umt5 import UMT5Config
+ from .models.unispeech import (
+ UniSpeechConfig,
+ )
+ from .models.unispeech_sat import (
+ UniSpeechSatConfig,
+ )
+ from .models.univnet import (
+ UnivNetConfig,
+ UnivNetFeatureExtractor,
+ )
+ from .models.upernet import UperNetConfig
+ from .models.video_llava import VideoLlavaConfig
+ from .models.videomae import VideoMAEConfig
+ from .models.vilt import (
+ ViltConfig,
+ ViltFeatureExtractor,
+ ViltImageProcessor,
+ ViltProcessor,
+ )
+ from .models.vipllava import (
+ VipLlavaConfig,
+ )
+ from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
+ from .models.vision_text_dual_encoder import (
+ VisionTextDualEncoderConfig,
+ VisionTextDualEncoderProcessor,
+ )
+ from .models.visual_bert import (
+ VisualBertConfig,
+ )
+ from .models.vit import ViTConfig
+ from .models.vit_mae import ViTMAEConfig
+ from .models.vit_msn import ViTMSNConfig
+ from .models.vitdet import VitDetConfig
+ from .models.vitmatte import VitMatteConfig
+ from .models.vitpose import VitPoseConfig
+ from .models.vitpose_backbone import VitPoseBackboneConfig
+ from .models.vits import (
+ VitsConfig,
+ VitsTokenizer,
+ )
+ from .models.vivit import VivitConfig
+ from .models.wav2vec2 import (
+ Wav2Vec2Config,
+ Wav2Vec2CTCTokenizer,
+ Wav2Vec2FeatureExtractor,
+ Wav2Vec2Processor,
+ Wav2Vec2Tokenizer,
+ )
+ from .models.wav2vec2_bert import (
+ Wav2Vec2BertConfig,
+ Wav2Vec2BertProcessor,
+ )
+ from .models.wav2vec2_conformer import (
+ Wav2Vec2ConformerConfig,
+ )
+ from .models.wav2vec2_phoneme import Wav2Vec2PhonemeCTCTokenizer
+ from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
+ from .models.wavlm import WavLMConfig
+ from .models.whisper import (
+ WhisperConfig,
+ WhisperFeatureExtractor,
+ WhisperProcessor,
+ WhisperTokenizer,
+ )
+ from .models.x_clip import (
+ XCLIPConfig,
+ XCLIPProcessor,
+ XCLIPTextConfig,
+ XCLIPVisionConfig,
+ )
+ from .models.xglm import XGLMConfig
+ from .models.xlm import XLMConfig, XLMTokenizer
+ from .models.xlm_roberta import (
+ XLMRobertaConfig,
+ )
+ from .models.xlm_roberta_xl import (
+ XLMRobertaXLConfig,
+ )
+ from .models.xlnet import XLNetConfig
+ from .models.xmod import XmodConfig
+ from .models.yolos import YolosConfig
+ from .models.yoso import YosoConfig
+ from .models.zamba import ZambaConfig
+ from .models.zoedepth import ZoeDepthConfig
+
+ # Pipelines
+ from .pipelines import (
+ AudioClassificationPipeline,
+ AutomaticSpeechRecognitionPipeline,
+ CsvPipelineDataFormat,
+ DepthEstimationPipeline,
+ DocumentQuestionAnsweringPipeline,
+ FeatureExtractionPipeline,
+ FillMaskPipeline,
+ ImageClassificationPipeline,
+ ImageFeatureExtractionPipeline,
+ ImageSegmentationPipeline,
+ ImageTextToTextPipeline,
+ ImageToImagePipeline,
+ ImageToTextPipeline,
+ JsonPipelineDataFormat,
+ MaskGenerationPipeline,
+ NerPipeline,
+ ObjectDetectionPipeline,
+ PipedPipelineDataFormat,
+ Pipeline,
+ PipelineDataFormat,
+ QuestionAnsweringPipeline,
+ SummarizationPipeline,
+ TableQuestionAnsweringPipeline,
+ Text2TextGenerationPipeline,
+ TextClassificationPipeline,
+ TextGenerationPipeline,
+ TextToAudioPipeline,
+ TokenClassificationPipeline,
+ TranslationPipeline,
+ VideoClassificationPipeline,
+ VisualQuestionAnsweringPipeline,
+ ZeroShotAudioClassificationPipeline,
+ ZeroShotClassificationPipeline,
+ ZeroShotImageClassificationPipeline,
+ ZeroShotObjectDetectionPipeline,
+ pipeline,
+ )
+ from .processing_utils import ProcessorMixin
+
+ # Tokenization
+ from .tokenization_utils import PreTrainedTokenizer
+ from .tokenization_utils_base import (
+ AddedToken,
+ BatchEncoding,
+ CharSpan,
+ PreTrainedTokenizerBase,
+ SpecialTokensMixin,
+ TokenSpan,
+ )
+
+ # Trainer
+ from .trainer_callback import (
+ DefaultFlowCallback,
+ EarlyStoppingCallback,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+ )
+ from .trainer_utils import (
+ EvalPrediction,
+ IntervalStrategy,
+ SchedulerType,
+ enable_full_determinism,
+ set_seed,
+ )
+ from .training_args import TrainingArguments
+ from .training_args_seq2seq import Seq2SeqTrainingArguments
+ from .training_args_tf import TFTrainingArguments
+
+ # Files and general utilities
+ from .utils import (
+ CONFIG_NAME,
+ MODEL_CARD_NAME,
+ PYTORCH_PRETRAINED_BERT_CACHE,
+ PYTORCH_TRANSFORMERS_CACHE,
+ SPIECE_UNDERLINE,
+ TF2_WEIGHTS_NAME,
+ TF_WEIGHTS_NAME,
+ TRANSFORMERS_CACHE,
+ WEIGHTS_NAME,
+ TensorType,
+ add_end_docstrings,
+ add_start_docstrings,
+ is_apex_available,
+ is_av_available,
+ is_bitsandbytes_available,
+ is_datasets_available,
+ is_faiss_available,
+ is_flax_available,
+ is_keras_nlp_available,
+ is_phonemizer_available,
+ is_psutil_available,
+ is_py3nvml_available,
+ is_pyctcdecode_available,
+ is_sacremoses_available,
+ is_safetensors_available,
+ is_scipy_available,
+ is_sentencepiece_available,
+ is_sklearn_available,
+ is_speech_available,
+ is_tensorflow_text_available,
+ is_tf_available,
+ is_timm_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_torch_mlu_available,
+ is_torch_musa_available,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_tpu_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchvision_available,
+ is_vision_available,
+ logging,
+ )
+
+ # bitsandbytes config
+ from .utils.quantization_config import (
+ AqlmConfig,
+ AwqConfig,
+ BitNetConfig,
+ BitsAndBytesConfig,
+ CompressedTensorsConfig,
+ EetqConfig,
+ FbgemmFp8Config,
+ GPTQConfig,
+ HiggsConfig,
+ HqqConfig,
+ QuantoConfig,
+ TorchAoConfig,
+ VptqConfig,
+ )
+
+ try:
+ if not is_sentencepiece_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_sentencepiece_objects import *
+ else:
+ from .models.albert import AlbertTokenizer
+ from .models.barthez import BarthezTokenizer
+ from .models.bartpho import BartphoTokenizer
+ from .models.bert_generation import BertGenerationTokenizer
+ from .models.big_bird import BigBirdTokenizer
+ from .models.camembert import CamembertTokenizer
+ from .models.code_llama import CodeLlamaTokenizer
+ from .models.cpm import CpmTokenizer
+ from .models.deberta_v2 import DebertaV2Tokenizer
+ from .models.deprecated.ernie_m import ErnieMTokenizer
+ from .models.deprecated.xlm_prophetnet import XLMProphetNetTokenizer
+ from .models.fnet import FNetTokenizer
+ from .models.gemma import GemmaTokenizer
+ from .models.gpt_sw3 import GPTSw3Tokenizer
+ from .models.layoutxlm import LayoutXLMTokenizer
+ from .models.llama import LlamaTokenizer
+ from .models.m2m_100 import M2M100Tokenizer
+ from .models.marian import MarianTokenizer
+ from .models.mbart import MBartTokenizer
+ from .models.mbart50 import MBart50Tokenizer
+ from .models.mluke import MLukeTokenizer
+ from .models.mt5 import MT5Tokenizer
+ from .models.nllb import NllbTokenizer
+ from .models.pegasus import PegasusTokenizer
+ from .models.plbart import PLBartTokenizer
+ from .models.reformer import ReformerTokenizer
+ from .models.rembert import RemBertTokenizer
+ from .models.seamless_m4t import SeamlessM4TTokenizer
+ from .models.siglip import SiglipTokenizer
+ from .models.speech_to_text import Speech2TextTokenizer
+ from .models.speecht5 import SpeechT5Tokenizer
+ from .models.t5 import T5Tokenizer
+ from .models.udop import UdopTokenizer
+ from .models.xglm import XGLMTokenizer
+ from .models.xlm_roberta import XLMRobertaTokenizer
+ from .models.xlnet import XLNetTokenizer
+
+ try:
+ if not is_tokenizers_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_tokenizers_objects import *
+ else:
+ # Fast tokenizers imports
+ from .models.albert import AlbertTokenizerFast
+ from .models.bart import BartTokenizerFast
+ from .models.barthez import BarthezTokenizerFast
+ from .models.bert import BertTokenizerFast
+ from .models.big_bird import BigBirdTokenizerFast
+ from .models.blenderbot import BlenderbotTokenizerFast
+ from .models.blenderbot_small import BlenderbotSmallTokenizerFast
+ from .models.bloom import BloomTokenizerFast
+ from .models.camembert import CamembertTokenizerFast
+ from .models.clip import CLIPTokenizerFast
+ from .models.code_llama import CodeLlamaTokenizerFast
+ from .models.codegen import CodeGenTokenizerFast
+ from .models.cohere import CohereTokenizerFast
+ from .models.convbert import ConvBertTokenizerFast
+ from .models.cpm import CpmTokenizerFast
+ from .models.deberta import DebertaTokenizerFast
+ from .models.deberta_v2 import DebertaV2TokenizerFast
+ from .models.deprecated.realm import RealmTokenizerFast
+ from .models.deprecated.retribert import RetriBertTokenizerFast
+ from .models.distilbert import DistilBertTokenizerFast
+ from .models.dpr import (
+ DPRContextEncoderTokenizerFast,
+ DPRQuestionEncoderTokenizerFast,
+ DPRReaderTokenizerFast,
+ )
+ from .models.electra import ElectraTokenizerFast
+ from .models.fnet import FNetTokenizerFast
+ from .models.funnel import FunnelTokenizerFast
+ from .models.gemma import GemmaTokenizerFast
+ from .models.gpt2 import GPT2TokenizerFast
+ from .models.gpt_neox import GPTNeoXTokenizerFast
+ from .models.gpt_neox_japanese import GPTNeoXJapaneseTokenizer
+ from .models.herbert import HerbertTokenizerFast
+ from .models.layoutlm import LayoutLMTokenizerFast
+ from .models.layoutlmv2 import LayoutLMv2TokenizerFast
+ from .models.layoutlmv3 import LayoutLMv3TokenizerFast
+ from .models.layoutxlm import LayoutXLMTokenizerFast
+ from .models.led import LEDTokenizerFast
+ from .models.llama import LlamaTokenizerFast
+ from .models.longformer import LongformerTokenizerFast
+ from .models.lxmert import LxmertTokenizerFast
+ from .models.markuplm import MarkupLMTokenizerFast
+ from .models.mbart import MBartTokenizerFast
+ from .models.mbart50 import MBart50TokenizerFast
+ from .models.mobilebert import MobileBertTokenizerFast
+ from .models.mpnet import MPNetTokenizerFast
+ from .models.mt5 import MT5TokenizerFast
+ from .models.mvp import MvpTokenizerFast
+ from .models.nllb import NllbTokenizerFast
+ from .models.nougat import NougatTokenizerFast
+ from .models.openai import OpenAIGPTTokenizerFast
+ from .models.pegasus import PegasusTokenizerFast
+ from .models.qwen2 import Qwen2TokenizerFast
+ from .models.reformer import ReformerTokenizerFast
+ from .models.rembert import RemBertTokenizerFast
+ from .models.roberta import RobertaTokenizerFast
+ from .models.roformer import RoFormerTokenizerFast
+ from .models.seamless_m4t import SeamlessM4TTokenizerFast
+ from .models.splinter import SplinterTokenizerFast
+ from .models.squeezebert import SqueezeBertTokenizerFast
+ from .models.t5 import T5TokenizerFast
+ from .models.udop import UdopTokenizerFast
+ from .models.whisper import WhisperTokenizerFast
+ from .models.xglm import XGLMTokenizerFast
+ from .models.xlm_roberta import XLMRobertaTokenizerFast
+ from .models.xlnet import XLNetTokenizerFast
+ from .tokenization_utils_fast import PreTrainedTokenizerFast
+
+ try:
+ if not (is_sentencepiece_available() and is_tokenizers_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummies_sentencepiece_and_tokenizers_objects import *
+ else:
+ from .convert_slow_tokenizer import (
+ SLOW_TO_FAST_CONVERTERS,
+ convert_slow_tokenizer,
+ )
+
+ try:
+ if not is_tensorflow_text_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_tensorflow_text_objects import *
+ else:
+ from .models.bert import TFBertTokenizer
+
+ try:
+ if not is_keras_nlp_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_keras_nlp_objects import *
+ else:
+ from .models.gpt2 import TFGPT2Tokenizer
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_vision_objects import *
+ else:
+ from .image_processing_base import ImageProcessingMixin
+ from .image_processing_utils import BaseImageProcessor
+ from .image_utils import ImageFeatureExtractionMixin
+ from .models.aria import AriaImageProcessor
+ from .models.beit import BeitFeatureExtractor, BeitImageProcessor
+ from .models.bit import BitImageProcessor
+ from .models.blip import BlipImageProcessor
+ from .models.bridgetower import BridgeTowerImageProcessor
+ from .models.chameleon import ChameleonImageProcessor
+ from .models.chinese_clip import (
+ ChineseCLIPFeatureExtractor,
+ ChineseCLIPImageProcessor,
+ )
+ from .models.clip import CLIPFeatureExtractor, CLIPImageProcessor
+ from .models.conditional_detr import (
+ ConditionalDetrFeatureExtractor,
+ ConditionalDetrImageProcessor,
+ )
+ from .models.convnext import ConvNextFeatureExtractor, ConvNextImageProcessor
+ from .models.deformable_detr import DeformableDetrFeatureExtractor, DeformableDetrImageProcessor
+ from .models.deit import DeiTFeatureExtractor, DeiTImageProcessor
+ from .models.deprecated.deta import DetaImageProcessor
+ from .models.deprecated.efficientformer import EfficientFormerImageProcessor
+ from .models.deprecated.tvlt import TvltImageProcessor
+ from .models.deprecated.vit_hybrid import ViTHybridImageProcessor
+ from .models.detr import DetrFeatureExtractor, DetrImageProcessor
+ from .models.donut import DonutFeatureExtractor, DonutImageProcessor
+ from .models.dpt import DPTFeatureExtractor, DPTImageProcessor
+ from .models.efficientnet import EfficientNetImageProcessor
+ from .models.emu3 import Emu3ImageProcessor
+ from .models.flava import (
+ FlavaFeatureExtractor,
+ FlavaImageProcessor,
+ FlavaProcessor,
+ )
+ from .models.fuyu import FuyuImageProcessor, FuyuProcessor
+ from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
+ from .models.grounding_dino import GroundingDinoImageProcessor
+ from .models.idefics import IdeficsImageProcessor
+ from .models.idefics2 import Idefics2ImageProcessor
+ from .models.idefics3 import Idefics3ImageProcessor
+ from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
+ from .models.instructblipvideo import InstructBlipVideoImageProcessor
+ from .models.layoutlmv2 import (
+ LayoutLMv2FeatureExtractor,
+ LayoutLMv2ImageProcessor,
+ )
+ from .models.layoutlmv3 import (
+ LayoutLMv3FeatureExtractor,
+ LayoutLMv3ImageProcessor,
+ )
+ from .models.levit import LevitFeatureExtractor, LevitImageProcessor
+ from .models.llava_next import LlavaNextImageProcessor
+ from .models.llava_next_video import LlavaNextVideoImageProcessor
+ from .models.llava_onevision import LlavaOnevisionImageProcessor, LlavaOnevisionVideoProcessor
+ from .models.mask2former import Mask2FormerImageProcessor
+ from .models.maskformer import (
+ MaskFormerFeatureExtractor,
+ MaskFormerImageProcessor,
+ )
+ from .models.mllama import MllamaImageProcessor
+ from .models.mobilenet_v1 import (
+ MobileNetV1FeatureExtractor,
+ MobileNetV1ImageProcessor,
+ )
+ from .models.mobilenet_v2 import (
+ MobileNetV2FeatureExtractor,
+ MobileNetV2ImageProcessor,
+ )
+ from .models.mobilevit import MobileViTFeatureExtractor, MobileViTImageProcessor
+ from .models.nougat import NougatImageProcessor
+ from .models.oneformer import OneFormerImageProcessor
+ from .models.owlv2 import Owlv2ImageProcessor
+ from .models.owlvit import OwlViTFeatureExtractor, OwlViTImageProcessor
+ from .models.perceiver import PerceiverFeatureExtractor, PerceiverImageProcessor
+ from .models.pix2struct import Pix2StructImageProcessor
+ from .models.pixtral import PixtralImageProcessor
+ from .models.poolformer import (
+ PoolFormerFeatureExtractor,
+ PoolFormerImageProcessor,
+ )
+ from .models.pvt import PvtImageProcessor
+ from .models.qwen2_vl import Qwen2VLImageProcessor
+ from .models.rt_detr import RTDetrImageProcessor
+ from .models.sam import SamImageProcessor
+ from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
+ from .models.seggpt import SegGptImageProcessor
+ from .models.siglip import SiglipImageProcessor
+ from .models.superpoint import SuperPointImageProcessor
+ from .models.swin2sr import Swin2SRImageProcessor
+ from .models.textnet import TextNetImageProcessor
+ from .models.tvp import TvpImageProcessor
+ from .models.video_llava import VideoLlavaImageProcessor
+ from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
+ from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor
+ from .models.vit import ViTFeatureExtractor, ViTImageProcessor
+ from .models.vitmatte import VitMatteImageProcessor
+ from .models.vitpose import VitPoseImageProcessor
+ from .models.vivit import VivitImageProcessor
+ from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
+ from .models.zoedepth import ZoeDepthImageProcessor
+
+ try:
+ if not is_torchvision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_torchvision_objects import *
+ else:
+ from .image_processing_utils_fast import BaseImageProcessorFast
+ from .models.deformable_detr import DeformableDetrImageProcessorFast
+ from .models.detr import DetrImageProcessorFast
+ from .models.pixtral import PixtralImageProcessorFast
+ from .models.rt_detr import RTDetrImageProcessorFast
+ from .models.vit import ViTImageProcessorFast
+
+ try:
+ if not is_torchvision_available() and not is_timm_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_timm_and_torchvision_objects import *
+ else:
+ from .models.timm_wrapper import TimmWrapperImageProcessor
+
+ # Modeling
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_pt_objects import *
+ else:
+ # Benchmarks
+ from .benchmark.benchmark import PyTorchBenchmark
+ from .benchmark.benchmark_args import PyTorchBenchmarkArguments
+ from .cache_utils import (
+ Cache,
+ CacheConfig,
+ DynamicCache,
+ EncoderDecoderCache,
+ HQQQuantizedCache,
+ HybridCache,
+ MambaCache,
+ OffloadedCache,
+ OffloadedStaticCache,
+ QuantizedCache,
+ QuantizedCacheConfig,
+ QuantoQuantizedCache,
+ SinkCache,
+ SlidingWindowCache,
+ StaticCache,
+ )
+ from .data.datasets import (
+ GlueDataset,
+ GlueDataTrainingArguments,
+ LineByLineTextDataset,
+ LineByLineWithRefDataset,
+ LineByLineWithSOPTextDataset,
+ SquadDataset,
+ SquadDataTrainingArguments,
+ TextDataset,
+ TextDatasetForNextSentencePrediction,
+ )
+ from .generation import (
+ AlternatingCodebooksLogitsProcessor,
+ BayesianDetectorConfig,
+ BayesianDetectorModel,
+ BeamScorer,
+ BeamSearchScorer,
+ ClassifierFreeGuidanceLogitsProcessor,
+ ConstrainedBeamSearchScorer,
+ Constraint,
+ ConstraintListState,
+ DisjunctiveConstraint,
+ EncoderNoRepeatNGramLogitsProcessor,
+ EncoderRepetitionPenaltyLogitsProcessor,
+ EosTokenCriteria,
+ EpsilonLogitsWarper,
+ EtaLogitsWarper,
+ ExponentialDecayLengthPenalty,
+ ForcedBOSTokenLogitsProcessor,
+ ForcedEOSTokenLogitsProcessor,
+ GenerationMixin,
+ HammingDiversityLogitsProcessor,
+ InfNanRemoveLogitsProcessor,
+ LogitNormalization,
+ LogitsProcessor,
+ LogitsProcessorList,
+ LogitsWarper,
+ MaxLengthCriteria,
+ MaxTimeCriteria,
+ MinLengthLogitsProcessor,
+ MinNewTokensLengthLogitsProcessor,
+ MinPLogitsWarper,
+ NoBadWordsLogitsProcessor,
+ NoRepeatNGramLogitsProcessor,
+ PhrasalConstraint,
+ PrefixConstrainedLogitsProcessor,
+ RepetitionPenaltyLogitsProcessor,
+ SequenceBiasLogitsProcessor,
+ StoppingCriteria,
+ StoppingCriteriaList,
+ StopStringCriteria,
+ SuppressTokensAtBeginLogitsProcessor,
+ SuppressTokensLogitsProcessor,
+ SynthIDTextWatermarkDetector,
+ SynthIDTextWatermarkingConfig,
+ SynthIDTextWatermarkLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ TypicalLogitsWarper,
+ UnbatchedClassifierFreeGuidanceLogitsProcessor,
+ WatermarkDetector,
+ WatermarkLogitsProcessor,
+ WhisperTimeStampLogitsProcessor,
+ )
+ from .integrations.executorch import (
+ TorchExportableModuleWithStaticCache,
+ convert_and_export_with_cache,
+ )
+ from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
+ from .modeling_utils import PreTrainedModel
+ from .models.albert import (
+ AlbertForMaskedLM,
+ AlbertForMultipleChoice,
+ AlbertForPreTraining,
+ AlbertForQuestionAnswering,
+ AlbertForSequenceClassification,
+ AlbertForTokenClassification,
+ AlbertModel,
+ AlbertPreTrainedModel,
+ load_tf_weights_in_albert,
+ )
+ from .models.align import (
+ AlignModel,
+ AlignPreTrainedModel,
+ AlignTextModel,
+ AlignVisionModel,
+ )
+ from .models.altclip import (
+ AltCLIPModel,
+ AltCLIPPreTrainedModel,
+ AltCLIPTextModel,
+ AltCLIPVisionModel,
+ )
+ from .models.aria import (
+ AriaForConditionalGeneration,
+ AriaPreTrainedModel,
+ AriaTextForCausalLM,
+ AriaTextModel,
+ AriaTextPreTrainedModel,
+ )
+ from .models.audio_spectrogram_transformer import (
+ ASTForAudioClassification,
+ ASTModel,
+ ASTPreTrainedModel,
+ )
+ from .models.auto import (
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING,
+ MODEL_FOR_AUDIO_XVECTOR_MAPPING,
+ MODEL_FOR_BACKBONE_MAPPING,
+ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
+ MODEL_FOR_CAUSAL_LM_MAPPING,
+ MODEL_FOR_CTC_MAPPING,
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
+ MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ MODEL_FOR_IMAGE_MAPPING,
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
+ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
+ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
+ MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
+ MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
+ MODEL_FOR_MASK_GENERATION_MAPPING,
+ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
+ MODEL_FOR_MASKED_LM_MAPPING,
+ MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+ MODEL_FOR_OBJECT_DETECTION_MAPPING,
+ MODEL_FOR_PRETRAINING_MAPPING,
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ MODEL_FOR_RETRIEVAL_MAPPING,
+ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
+ MODEL_FOR_TEXT_ENCODING_MAPPING,
+ MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
+ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
+ MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING,
+ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING,
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
+ MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
+ MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING,
+ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+ MODEL_MAPPING,
+ MODEL_WITH_LM_HEAD_MAPPING,
+ AutoBackbone,
+ AutoModel,
+ AutoModelForAudioClassification,
+ AutoModelForAudioFrameClassification,
+ AutoModelForAudioXVector,
+ AutoModelForCausalLM,
+ AutoModelForCTC,
+ AutoModelForDepthEstimation,
+ AutoModelForDocumentQuestionAnswering,
+ AutoModelForImageClassification,
+ AutoModelForImageSegmentation,
+ AutoModelForImageTextToText,
+ AutoModelForImageToImage,
+ AutoModelForInstanceSegmentation,
+ AutoModelForKeypointDetection,
+ AutoModelForMaskedImageModeling,
+ AutoModelForMaskedLM,
+ AutoModelForMaskGeneration,
+ AutoModelForMultipleChoice,
+ AutoModelForNextSentencePrediction,
+ AutoModelForObjectDetection,
+ AutoModelForPreTraining,
+ AutoModelForQuestionAnswering,
+ AutoModelForSemanticSegmentation,
+ AutoModelForSeq2SeqLM,
+ AutoModelForSequenceClassification,
+ AutoModelForSpeechSeq2Seq,
+ AutoModelForTableQuestionAnswering,
+ AutoModelForTextEncoding,
+ AutoModelForTextToSpectrogram,
+ AutoModelForTextToWaveform,
+ AutoModelForTokenClassification,
+ AutoModelForUniversalSegmentation,
+ AutoModelForVideoClassification,
+ AutoModelForVision2Seq,
+ AutoModelForVisualQuestionAnswering,
+ AutoModelForZeroShotImageClassification,
+ AutoModelForZeroShotObjectDetection,
+ AutoModelWithLMHead,
+ )
+ from .models.autoformer import (
+ AutoformerForPrediction,
+ AutoformerModel,
+ AutoformerPreTrainedModel,
+ )
+ from .models.bamba import BambaForCausalLM, BambaModel, BambaPreTrainedModel
+ from .models.bark import (
+ BarkCausalModel,
+ BarkCoarseModel,
+ BarkFineModel,
+ BarkModel,
+ BarkPreTrainedModel,
+ BarkSemanticModel,
+ )
+ from .models.bart import (
+ BartForCausalLM,
+ BartForConditionalGeneration,
+ BartForQuestionAnswering,
+ BartForSequenceClassification,
+ BartModel,
+ BartPreTrainedModel,
+ BartPretrainedModel,
+ PretrainedBartModel,
+ )
+ from .models.beit import (
+ BeitBackbone,
+ BeitForImageClassification,
+ BeitForMaskedImageModeling,
+ BeitForSemanticSegmentation,
+ BeitModel,
+ BeitPreTrainedModel,
+ )
+ from .models.bert import (
+ BertForMaskedLM,
+ BertForMultipleChoice,
+ BertForNextSentencePrediction,
+ BertForPreTraining,
+ BertForQuestionAnswering,
+ BertForSequenceClassification,
+ BertForTokenClassification,
+ BertLMHeadModel,
+ BertModel,
+ BertPreTrainedModel,
+ load_tf_weights_in_bert,
+ )
+ from .models.bert_generation import (
+ BertGenerationDecoder,
+ BertGenerationEncoder,
+ BertGenerationPreTrainedModel,
+ load_tf_weights_in_bert_generation,
+ )
+ from .models.big_bird import (
+ BigBirdForCausalLM,
+ BigBirdForMaskedLM,
+ BigBirdForMultipleChoice,
+ BigBirdForPreTraining,
+ BigBirdForQuestionAnswering,
+ BigBirdForSequenceClassification,
+ BigBirdForTokenClassification,
+ BigBirdModel,
+ BigBirdPreTrainedModel,
+ load_tf_weights_in_big_bird,
+ )
+ from .models.bigbird_pegasus import (
+ BigBirdPegasusForCausalLM,
+ BigBirdPegasusForConditionalGeneration,
+ BigBirdPegasusForQuestionAnswering,
+ BigBirdPegasusForSequenceClassification,
+ BigBirdPegasusModel,
+ BigBirdPegasusPreTrainedModel,
+ )
+ from .models.biogpt import (
+ BioGptForCausalLM,
+ BioGptForSequenceClassification,
+ BioGptForTokenClassification,
+ BioGptModel,
+ BioGptPreTrainedModel,
+ )
+ from .models.bit import (
+ BitBackbone,
+ BitForImageClassification,
+ BitModel,
+ BitPreTrainedModel,
+ )
+ from .models.blenderbot import (
+ BlenderbotForCausalLM,
+ BlenderbotForConditionalGeneration,
+ BlenderbotModel,
+ BlenderbotPreTrainedModel,
+ )
+ from .models.blenderbot_small import (
+ BlenderbotSmallForCausalLM,
+ BlenderbotSmallForConditionalGeneration,
+ BlenderbotSmallModel,
+ BlenderbotSmallPreTrainedModel,
+ )
+ from .models.blip import (
+ BlipForConditionalGeneration,
+ BlipForImageTextRetrieval,
+ BlipForQuestionAnswering,
+ BlipModel,
+ BlipPreTrainedModel,
+ BlipTextModel,
+ BlipVisionModel,
+ )
+ from .models.blip_2 import (
+ Blip2ForConditionalGeneration,
+ Blip2ForImageTextRetrieval,
+ Blip2Model,
+ Blip2PreTrainedModel,
+ Blip2QFormerModel,
+ Blip2TextModelWithProjection,
+ Blip2VisionModel,
+ Blip2VisionModelWithProjection,
+ )
+ from .models.bloom import (
+ BloomForCausalLM,
+ BloomForQuestionAnswering,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+ BloomPreTrainedModel,
+ )
+ from .models.bridgetower import (
+ BridgeTowerForContrastiveLearning,
+ BridgeTowerForImageAndTextRetrieval,
+ BridgeTowerForMaskedLM,
+ BridgeTowerModel,
+ BridgeTowerPreTrainedModel,
+ )
+ from .models.bros import (
+ BrosForTokenClassification,
+ BrosModel,
+ BrosPreTrainedModel,
+ BrosProcessor,
+ BrosSpadeEEForTokenClassification,
+ BrosSpadeELForTokenClassification,
+ )
+ from .models.camembert import (
+ CamembertForCausalLM,
+ CamembertForMaskedLM,
+ CamembertForMultipleChoice,
+ CamembertForQuestionAnswering,
+ CamembertForSequenceClassification,
+ CamembertForTokenClassification,
+ CamembertModel,
+ CamembertPreTrainedModel,
+ )
+ from .models.canine import (
+ CanineForMultipleChoice,
+ CanineForQuestionAnswering,
+ CanineForSequenceClassification,
+ CanineForTokenClassification,
+ CanineModel,
+ CaninePreTrainedModel,
+ load_tf_weights_in_canine,
+ )
+ from .models.chameleon import (
+ ChameleonForConditionalGeneration,
+ ChameleonModel,
+ ChameleonPreTrainedModel,
+ ChameleonProcessor,
+ ChameleonVQVAE,
+ )
+ from .models.chinese_clip import (
+ ChineseCLIPModel,
+ ChineseCLIPPreTrainedModel,
+ ChineseCLIPTextModel,
+ ChineseCLIPVisionModel,
+ )
+ from .models.clap import (
+ ClapAudioModel,
+ ClapAudioModelWithProjection,
+ ClapFeatureExtractor,
+ ClapModel,
+ ClapPreTrainedModel,
+ ClapTextModel,
+ ClapTextModelWithProjection,
+ )
+ from .models.clip import (
+ CLIPForImageClassification,
+ CLIPModel,
+ CLIPPreTrainedModel,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPVisionModel,
+ CLIPVisionModelWithProjection,
+ )
+ from .models.clipseg import (
+ CLIPSegForImageSegmentation,
+ CLIPSegModel,
+ CLIPSegPreTrainedModel,
+ CLIPSegTextModel,
+ CLIPSegVisionModel,
+ )
+ from .models.clvp import (
+ ClvpDecoder,
+ ClvpEncoder,
+ ClvpForCausalLM,
+ ClvpModel,
+ ClvpModelForConditionalGeneration,
+ ClvpPreTrainedModel,
+ )
+ from .models.codegen import (
+ CodeGenForCausalLM,
+ CodeGenModel,
+ CodeGenPreTrainedModel,
+ )
+ from .models.cohere import (
+ CohereForCausalLM,
+ CohereModel,
+ CoherePreTrainedModel,
+ )
+ from .models.cohere2 import (
+ Cohere2ForCausalLM,
+ Cohere2Model,
+ Cohere2PreTrainedModel,
+ )
+ from .models.colpali import (
+ ColPaliForRetrieval,
+ ColPaliPreTrainedModel,
+ )
+ from .models.conditional_detr import (
+ ConditionalDetrForObjectDetection,
+ ConditionalDetrForSegmentation,
+ ConditionalDetrModel,
+ ConditionalDetrPreTrainedModel,
+ )
+ from .models.convbert import (
+ ConvBertForMaskedLM,
+ ConvBertForMultipleChoice,
+ ConvBertForQuestionAnswering,
+ ConvBertForSequenceClassification,
+ ConvBertForTokenClassification,
+ ConvBertModel,
+ ConvBertPreTrainedModel,
+ load_tf_weights_in_convbert,
+ )
+ from .models.convnext import (
+ ConvNextBackbone,
+ ConvNextForImageClassification,
+ ConvNextModel,
+ ConvNextPreTrainedModel,
+ )
+ from .models.convnextv2 import (
+ ConvNextV2Backbone,
+ ConvNextV2ForImageClassification,
+ ConvNextV2Model,
+ ConvNextV2PreTrainedModel,
+ )
+ from .models.cpmant import (
+ CpmAntForCausalLM,
+ CpmAntModel,
+ CpmAntPreTrainedModel,
+ )
+ from .models.ctrl import (
+ CTRLForSequenceClassification,
+ CTRLLMHeadModel,
+ CTRLModel,
+ CTRLPreTrainedModel,
+ )
+ from .models.cvt import (
+ CvtForImageClassification,
+ CvtModel,
+ CvtPreTrainedModel,
+ )
+ from .models.dac import (
+ DacModel,
+ DacPreTrainedModel,
+ )
+ from .models.data2vec import (
+ Data2VecAudioForAudioFrameClassification,
+ Data2VecAudioForCTC,
+ Data2VecAudioForSequenceClassification,
+ Data2VecAudioForXVector,
+ Data2VecAudioModel,
+ Data2VecAudioPreTrainedModel,
+ Data2VecTextForCausalLM,
+ Data2VecTextForMaskedLM,
+ Data2VecTextForMultipleChoice,
+ Data2VecTextForQuestionAnswering,
+ Data2VecTextForSequenceClassification,
+ Data2VecTextForTokenClassification,
+ Data2VecTextModel,
+ Data2VecTextPreTrainedModel,
+ Data2VecVisionForImageClassification,
+ Data2VecVisionForSemanticSegmentation,
+ Data2VecVisionModel,
+ Data2VecVisionPreTrainedModel,
+ )
+
+ # PyTorch model imports
+ from .models.dbrx import (
+ DbrxForCausalLM,
+ DbrxModel,
+ DbrxPreTrainedModel,
+ )
+ from .models.deberta import (
+ DebertaForMaskedLM,
+ DebertaForQuestionAnswering,
+ DebertaForSequenceClassification,
+ DebertaForTokenClassification,
+ DebertaModel,
+ DebertaPreTrainedModel,
+ )
+ from .models.deberta_v2 import (
+ DebertaV2ForMaskedLM,
+ DebertaV2ForMultipleChoice,
+ DebertaV2ForQuestionAnswering,
+ DebertaV2ForSequenceClassification,
+ DebertaV2ForTokenClassification,
+ DebertaV2Model,
+ DebertaV2PreTrainedModel,
+ )
+ from .models.decision_transformer import (
+ DecisionTransformerGPT2Model,
+ DecisionTransformerGPT2PreTrainedModel,
+ DecisionTransformerModel,
+ DecisionTransformerPreTrainedModel,
+ )
+ from .models.deformable_detr import (
+ DeformableDetrForObjectDetection,
+ DeformableDetrModel,
+ DeformableDetrPreTrainedModel,
+ )
+ from .models.deit import (
+ DeiTForImageClassification,
+ DeiTForImageClassificationWithTeacher,
+ DeiTForMaskedImageModeling,
+ DeiTModel,
+ DeiTPreTrainedModel,
+ )
+ from .models.deprecated.deta import (
+ DetaForObjectDetection,
+ DetaModel,
+ DetaPreTrainedModel,
+ )
+ from .models.deprecated.efficientformer import (
+ EfficientFormerForImageClassification,
+ EfficientFormerForImageClassificationWithTeacher,
+ EfficientFormerModel,
+ EfficientFormerPreTrainedModel,
+ )
+ from .models.deprecated.ernie_m import (
+ ErnieMForInformationExtraction,
+ ErnieMForMultipleChoice,
+ ErnieMForQuestionAnswering,
+ ErnieMForSequenceClassification,
+ ErnieMForTokenClassification,
+ ErnieMModel,
+ ErnieMPreTrainedModel,
+ )
+ from .models.deprecated.gptsan_japanese import (
+ GPTSanJapaneseForConditionalGeneration,
+ GPTSanJapaneseModel,
+ GPTSanJapanesePreTrainedModel,
+ )
+ from .models.deprecated.graphormer import (
+ GraphormerForGraphClassification,
+ GraphormerModel,
+ GraphormerPreTrainedModel,
+ )
+ from .models.deprecated.jukebox import (
+ JukeboxModel,
+ JukeboxPreTrainedModel,
+ JukeboxPrior,
+ JukeboxVQVAE,
+ )
+ from .models.deprecated.mctct import (
+ MCTCTForCTC,
+ MCTCTModel,
+ MCTCTPreTrainedModel,
+ )
+ from .models.deprecated.mega import (
+ MegaForCausalLM,
+ MegaForMaskedLM,
+ MegaForMultipleChoice,
+ MegaForQuestionAnswering,
+ MegaForSequenceClassification,
+ MegaForTokenClassification,
+ MegaModel,
+ MegaPreTrainedModel,
+ )
+ from .models.deprecated.mmbt import (
+ MMBTForClassification,
+ MMBTModel,
+ ModalEmbeddings,
+ )
+ from .models.deprecated.nat import (
+ NatBackbone,
+ NatForImageClassification,
+ NatModel,
+ NatPreTrainedModel,
+ )
+ from .models.deprecated.nezha import (
+ NezhaForMaskedLM,
+ NezhaForMultipleChoice,
+ NezhaForNextSentencePrediction,
+ NezhaForPreTraining,
+ NezhaForQuestionAnswering,
+ NezhaForSequenceClassification,
+ NezhaForTokenClassification,
+ NezhaModel,
+ NezhaPreTrainedModel,
+ )
+ from .models.deprecated.open_llama import (
+ OpenLlamaForCausalLM,
+ OpenLlamaForSequenceClassification,
+ OpenLlamaModel,
+ OpenLlamaPreTrainedModel,
+ )
+ from .models.deprecated.qdqbert import (
+ QDQBertForMaskedLM,
+ QDQBertForMultipleChoice,
+ QDQBertForNextSentencePrediction,
+ QDQBertForQuestionAnswering,
+ QDQBertForSequenceClassification,
+ QDQBertForTokenClassification,
+ QDQBertLMHeadModel,
+ QDQBertModel,
+ QDQBertPreTrainedModel,
+ load_tf_weights_in_qdqbert,
+ )
+ from .models.deprecated.realm import (
+ RealmEmbedder,
+ RealmForOpenQA,
+ RealmKnowledgeAugEncoder,
+ RealmPreTrainedModel,
+ RealmReader,
+ RealmRetriever,
+ RealmScorer,
+ load_tf_weights_in_realm,
+ )
+ from .models.deprecated.retribert import (
+ RetriBertModel,
+ RetriBertPreTrainedModel,
+ )
+ from .models.deprecated.speech_to_text_2 import (
+ Speech2Text2ForCausalLM,
+ Speech2Text2PreTrainedModel,
+ )
+ from .models.deprecated.trajectory_transformer import (
+ TrajectoryTransformerModel,
+ TrajectoryTransformerPreTrainedModel,
+ )
+ from .models.deprecated.transfo_xl import (
+ AdaptiveEmbedding,
+ TransfoXLForSequenceClassification,
+ TransfoXLLMHeadModel,
+ TransfoXLModel,
+ TransfoXLPreTrainedModel,
+ load_tf_weights_in_transfo_xl,
+ )
+ from .models.deprecated.tvlt import (
+ TvltForAudioVisualClassification,
+ TvltForPreTraining,
+ TvltModel,
+ TvltPreTrainedModel,
+ )
+ from .models.deprecated.van import (
+ VanForImageClassification,
+ VanModel,
+ VanPreTrainedModel,
+ )
+ from .models.deprecated.vit_hybrid import (
+ ViTHybridForImageClassification,
+ ViTHybridModel,
+ ViTHybridPreTrainedModel,
+ )
+ from .models.deprecated.xlm_prophetnet import (
+ XLMProphetNetDecoder,
+ XLMProphetNetEncoder,
+ XLMProphetNetForCausalLM,
+ XLMProphetNetForConditionalGeneration,
+ XLMProphetNetModel,
+ XLMProphetNetPreTrainedModel,
+ )
+ from .models.depth_anything import (
+ DepthAnythingForDepthEstimation,
+ DepthAnythingPreTrainedModel,
+ )
+ from .models.detr import (
+ DetrForObjectDetection,
+ DetrForSegmentation,
+ DetrModel,
+ DetrPreTrainedModel,
+ )
+ from .models.diffllama import (
+ DiffLlamaForCausalLM,
+ DiffLlamaForQuestionAnswering,
+ DiffLlamaForSequenceClassification,
+ DiffLlamaForTokenClassification,
+ DiffLlamaModel,
+ DiffLlamaPreTrainedModel,
+ )
+ from .models.dinat import (
+ DinatBackbone,
+ DinatForImageClassification,
+ DinatModel,
+ DinatPreTrainedModel,
+ )
+ from .models.dinov2 import (
+ Dinov2Backbone,
+ Dinov2ForImageClassification,
+ Dinov2Model,
+ Dinov2PreTrainedModel,
+ )
+ from .models.dinov2_with_registers import (
+ Dinov2WithRegistersBackbone,
+ Dinov2WithRegistersForImageClassification,
+ Dinov2WithRegistersModel,
+ Dinov2WithRegistersPreTrainedModel,
+ )
+ from .models.distilbert import (
+ DistilBertForMaskedLM,
+ DistilBertForMultipleChoice,
+ DistilBertForQuestionAnswering,
+ DistilBertForSequenceClassification,
+ DistilBertForTokenClassification,
+ DistilBertModel,
+ DistilBertPreTrainedModel,
+ )
+ from .models.donut import (
+ DonutSwinModel,
+ DonutSwinPreTrainedModel,
+ )
+ from .models.dpr import (
+ DPRContextEncoder,
+ DPRPretrainedContextEncoder,
+ DPRPreTrainedModel,
+ DPRPretrainedQuestionEncoder,
+ DPRPretrainedReader,
+ DPRQuestionEncoder,
+ DPRReader,
+ )
+ from .models.dpt import (
+ DPTForDepthEstimation,
+ DPTForSemanticSegmentation,
+ DPTModel,
+ DPTPreTrainedModel,
+ )
+ from .models.efficientnet import (
+ EfficientNetForImageClassification,
+ EfficientNetModel,
+ EfficientNetPreTrainedModel,
+ )
+ from .models.electra import (
+ ElectraForCausalLM,
+ ElectraForMaskedLM,
+ ElectraForMultipleChoice,
+ ElectraForPreTraining,
+ ElectraForQuestionAnswering,
+ ElectraForSequenceClassification,
+ ElectraForTokenClassification,
+ ElectraModel,
+ ElectraPreTrainedModel,
+ load_tf_weights_in_electra,
+ )
+ from .models.emu3 import (
+ Emu3ForCausalLM,
+ Emu3ForConditionalGeneration,
+ Emu3PreTrainedModel,
+ Emu3TextModel,
+ Emu3VQVAE,
+ )
+ from .models.encodec import (
+ EncodecModel,
+ EncodecPreTrainedModel,
+ )
+ from .models.encoder_decoder import EncoderDecoderModel
+ from .models.ernie import (
+ ErnieForCausalLM,
+ ErnieForMaskedLM,
+ ErnieForMultipleChoice,
+ ErnieForNextSentencePrediction,
+ ErnieForPreTraining,
+ ErnieForQuestionAnswering,
+ ErnieForSequenceClassification,
+ ErnieForTokenClassification,
+ ErnieModel,
+ ErniePreTrainedModel,
+ )
+ from .models.esm import (
+ EsmFoldPreTrainedModel,
+ EsmForMaskedLM,
+ EsmForProteinFolding,
+ EsmForSequenceClassification,
+ EsmForTokenClassification,
+ EsmModel,
+ EsmPreTrainedModel,
+ )
+ from .models.falcon import (
+ FalconForCausalLM,
+ FalconForQuestionAnswering,
+ FalconForSequenceClassification,
+ FalconForTokenClassification,
+ FalconModel,
+ FalconPreTrainedModel,
+ )
+ from .models.falcon_mamba import (
+ FalconMambaForCausalLM,
+ FalconMambaModel,
+ FalconMambaPreTrainedModel,
+ )
+ from .models.fastspeech2_conformer import (
+ FastSpeech2ConformerHifiGan,
+ FastSpeech2ConformerModel,
+ FastSpeech2ConformerPreTrainedModel,
+ FastSpeech2ConformerWithHifiGan,
+ )
+ from .models.flaubert import (
+ FlaubertForMultipleChoice,
+ FlaubertForQuestionAnswering,
+ FlaubertForQuestionAnsweringSimple,
+ FlaubertForSequenceClassification,
+ FlaubertForTokenClassification,
+ FlaubertModel,
+ FlaubertPreTrainedModel,
+ FlaubertWithLMHeadModel,
+ )
+ from .models.flava import (
+ FlavaForPreTraining,
+ FlavaImageCodebook,
+ FlavaImageModel,
+ FlavaModel,
+ FlavaMultimodalModel,
+ FlavaPreTrainedModel,
+ FlavaTextModel,
+ )
+ from .models.fnet import (
+ FNetForMaskedLM,
+ FNetForMultipleChoice,
+ FNetForNextSentencePrediction,
+ FNetForPreTraining,
+ FNetForQuestionAnswering,
+ FNetForSequenceClassification,
+ FNetForTokenClassification,
+ FNetModel,
+ FNetPreTrainedModel,
+ )
+ from .models.focalnet import (
+ FocalNetBackbone,
+ FocalNetForImageClassification,
+ FocalNetForMaskedImageModeling,
+ FocalNetModel,
+ FocalNetPreTrainedModel,
+ )
+ from .models.fsmt import (
+ FSMTForConditionalGeneration,
+ FSMTModel,
+ PretrainedFSMTModel,
+ )
+ from .models.funnel import (
+ FunnelBaseModel,
+ FunnelForMaskedLM,
+ FunnelForMultipleChoice,
+ FunnelForPreTraining,
+ FunnelForQuestionAnswering,
+ FunnelForSequenceClassification,
+ FunnelForTokenClassification,
+ FunnelModel,
+ FunnelPreTrainedModel,
+ load_tf_weights_in_funnel,
+ )
+ from .models.fuyu import (
+ FuyuForCausalLM,
+ FuyuPreTrainedModel,
+ )
+ from .models.gemma import (
+ GemmaForCausalLM,
+ GemmaForSequenceClassification,
+ GemmaForTokenClassification,
+ GemmaModel,
+ GemmaPreTrainedModel,
+ )
+ from .models.gemma2 import (
+ Gemma2ForCausalLM,
+ Gemma2ForSequenceClassification,
+ Gemma2ForTokenClassification,
+ Gemma2Model,
+ Gemma2PreTrainedModel,
+ )
+ from .models.git import (
+ GitForCausalLM,
+ GitModel,
+ GitPreTrainedModel,
+ GitVisionModel,
+ )
+ from .models.glm import (
+ GlmForCausalLM,
+ GlmForSequenceClassification,
+ GlmForTokenClassification,
+ GlmModel,
+ GlmPreTrainedModel,
+ )
+ from .models.glpn import (
+ GLPNForDepthEstimation,
+ GLPNModel,
+ GLPNPreTrainedModel,
+ )
+ from .models.gpt2 import (
+ GPT2DoubleHeadsModel,
+ GPT2ForQuestionAnswering,
+ GPT2ForSequenceClassification,
+ GPT2ForTokenClassification,
+ GPT2LMHeadModel,
+ GPT2Model,
+ GPT2PreTrainedModel,
+ load_tf_weights_in_gpt2,
+ )
+ from .models.gpt_bigcode import (
+ GPTBigCodeForCausalLM,
+ GPTBigCodeForSequenceClassification,
+ GPTBigCodeForTokenClassification,
+ GPTBigCodeModel,
+ GPTBigCodePreTrainedModel,
+ )
+ from .models.gpt_neo import (
+ GPTNeoForCausalLM,
+ GPTNeoForQuestionAnswering,
+ GPTNeoForSequenceClassification,
+ GPTNeoForTokenClassification,
+ GPTNeoModel,
+ GPTNeoPreTrainedModel,
+ load_tf_weights_in_gpt_neo,
+ )
+ from .models.gpt_neox import (
+ GPTNeoXForCausalLM,
+ GPTNeoXForQuestionAnswering,
+ GPTNeoXForSequenceClassification,
+ GPTNeoXForTokenClassification,
+ GPTNeoXModel,
+ GPTNeoXPreTrainedModel,
+ )
+ from .models.gpt_neox_japanese import (
+ GPTNeoXJapaneseForCausalLM,
+ GPTNeoXJapaneseModel,
+ GPTNeoXJapanesePreTrainedModel,
+ )
+ from .models.gptj import (
+ GPTJForCausalLM,
+ GPTJForQuestionAnswering,
+ GPTJForSequenceClassification,
+ GPTJModel,
+ GPTJPreTrainedModel,
+ )
+ from .models.granite import (
+ GraniteForCausalLM,
+ GraniteModel,
+ GranitePreTrainedModel,
+ )
+ from .models.granitemoe import (
+ GraniteMoeForCausalLM,
+ GraniteMoeModel,
+ GraniteMoePreTrainedModel,
+ )
+ from .models.grounding_dino import (
+ GroundingDinoForObjectDetection,
+ GroundingDinoModel,
+ GroundingDinoPreTrainedModel,
+ )
+ from .models.groupvit import (
+ GroupViTModel,
+ GroupViTPreTrainedModel,
+ GroupViTTextModel,
+ GroupViTVisionModel,
+ )
+ from .models.hiera import (
+ HieraBackbone,
+ HieraForImageClassification,
+ HieraForPreTraining,
+ HieraModel,
+ HieraPreTrainedModel,
+ )
+ from .models.hubert import (
+ HubertForCTC,
+ HubertForSequenceClassification,
+ HubertModel,
+ HubertPreTrainedModel,
+ )
+ from .models.ibert import (
+ IBertForMaskedLM,
+ IBertForMultipleChoice,
+ IBertForQuestionAnswering,
+ IBertForSequenceClassification,
+ IBertForTokenClassification,
+ IBertModel,
+ IBertPreTrainedModel,
+ )
+ from .models.idefics import (
+ IdeficsForVisionText2Text,
+ IdeficsModel,
+ IdeficsPreTrainedModel,
+ IdeficsProcessor,
+ )
+ from .models.idefics2 import (
+ Idefics2ForConditionalGeneration,
+ Idefics2Model,
+ Idefics2PreTrainedModel,
+ Idefics2Processor,
+ )
+ from .models.idefics3 import (
+ Idefics3ForConditionalGeneration,
+ Idefics3Model,
+ Idefics3PreTrainedModel,
+ Idefics3Processor,
+ Idefics3VisionConfig,
+ Idefics3VisionTransformer,
+ )
+ from .models.ijepa import (
+ IJepaForImageClassification,
+ IJepaModel,
+ IJepaPreTrainedModel,
+ )
+ from .models.imagegpt import (
+ ImageGPTForCausalImageModeling,
+ ImageGPTForImageClassification,
+ ImageGPTModel,
+ ImageGPTPreTrainedModel,
+ load_tf_weights_in_imagegpt,
+ )
+ from .models.informer import (
+ InformerForPrediction,
+ InformerModel,
+ InformerPreTrainedModel,
+ )
+ from .models.instructblip import (
+ InstructBlipForConditionalGeneration,
+ InstructBlipPreTrainedModel,
+ InstructBlipQFormerModel,
+ InstructBlipVisionModel,
+ )
+ from .models.instructblipvideo import (
+ InstructBlipVideoForConditionalGeneration,
+ InstructBlipVideoPreTrainedModel,
+ InstructBlipVideoQFormerModel,
+ InstructBlipVideoVisionModel,
+ )
+ from .models.jamba import (
+ JambaForCausalLM,
+ JambaForSequenceClassification,
+ JambaModel,
+ JambaPreTrainedModel,
+ )
+ from .models.jetmoe import (
+ JetMoeForCausalLM,
+ JetMoeForSequenceClassification,
+ JetMoeModel,
+ JetMoePreTrainedModel,
+ )
+ from .models.kosmos2 import (
+ Kosmos2ForConditionalGeneration,
+ Kosmos2Model,
+ Kosmos2PreTrainedModel,
+ )
+ from .models.layoutlm import (
+ LayoutLMForMaskedLM,
+ LayoutLMForQuestionAnswering,
+ LayoutLMForSequenceClassification,
+ LayoutLMForTokenClassification,
+ LayoutLMModel,
+ LayoutLMPreTrainedModel,
+ )
+ from .models.layoutlmv2 import (
+ LayoutLMv2ForQuestionAnswering,
+ LayoutLMv2ForSequenceClassification,
+ LayoutLMv2ForTokenClassification,
+ LayoutLMv2Model,
+ LayoutLMv2PreTrainedModel,
+ )
+ from .models.layoutlmv3 import (
+ LayoutLMv3ForQuestionAnswering,
+ LayoutLMv3ForSequenceClassification,
+ LayoutLMv3ForTokenClassification,
+ LayoutLMv3Model,
+ LayoutLMv3PreTrainedModel,
+ )
+ from .models.led import (
+ LEDForConditionalGeneration,
+ LEDForQuestionAnswering,
+ LEDForSequenceClassification,
+ LEDModel,
+ LEDPreTrainedModel,
+ )
+ from .models.levit import (
+ LevitForImageClassification,
+ LevitForImageClassificationWithTeacher,
+ LevitModel,
+ LevitPreTrainedModel,
+ )
+ from .models.lilt import (
+ LiltForQuestionAnswering,
+ LiltForSequenceClassification,
+ LiltForTokenClassification,
+ LiltModel,
+ LiltPreTrainedModel,
+ )
+ from .models.llama import (
+ LlamaForCausalLM,
+ LlamaForQuestionAnswering,
+ LlamaForSequenceClassification,
+ LlamaForTokenClassification,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ )
+ from .models.llava import (
+ LlavaForConditionalGeneration,
+ LlavaPreTrainedModel,
+ )
+ from .models.llava_next import (
+ LlavaNextForConditionalGeneration,
+ LlavaNextPreTrainedModel,
+ )
+ from .models.llava_next_video import (
+ LlavaNextVideoForConditionalGeneration,
+ LlavaNextVideoPreTrainedModel,
+ )
+ from .models.llava_onevision import (
+ LlavaOnevisionForConditionalGeneration,
+ LlavaOnevisionPreTrainedModel,
+ )
+ from .models.longformer import (
+ LongformerForMaskedLM,
+ LongformerForMultipleChoice,
+ LongformerForQuestionAnswering,
+ LongformerForSequenceClassification,
+ LongformerForTokenClassification,
+ LongformerModel,
+ LongformerPreTrainedModel,
+ )
+ from .models.longt5 import (
+ LongT5EncoderModel,
+ LongT5ForConditionalGeneration,
+ LongT5Model,
+ LongT5PreTrainedModel,
+ )
+ from .models.luke import (
+ LukeForEntityClassification,
+ LukeForEntityPairClassification,
+ LukeForEntitySpanClassification,
+ LukeForMaskedLM,
+ LukeForMultipleChoice,
+ LukeForQuestionAnswering,
+ LukeForSequenceClassification,
+ LukeForTokenClassification,
+ LukeModel,
+ LukePreTrainedModel,
+ )
+ from .models.lxmert import (
+ LxmertEncoder,
+ LxmertForPreTraining,
+ LxmertForQuestionAnswering,
+ LxmertModel,
+ LxmertPreTrainedModel,
+ LxmertVisualFeatureEncoder,
+ )
+ from .models.m2m_100 import (
+ M2M100ForConditionalGeneration,
+ M2M100Model,
+ M2M100PreTrainedModel,
+ )
+ from .models.mamba import (
+ MambaForCausalLM,
+ MambaModel,
+ MambaPreTrainedModel,
+ )
+ from .models.mamba2 import (
+ Mamba2ForCausalLM,
+ Mamba2Model,
+ Mamba2PreTrainedModel,
+ )
+ from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel, MarianPreTrainedModel
+ from .models.markuplm import (
+ MarkupLMForQuestionAnswering,
+ MarkupLMForSequenceClassification,
+ MarkupLMForTokenClassification,
+ MarkupLMModel,
+ MarkupLMPreTrainedModel,
+ )
+ from .models.mask2former import (
+ Mask2FormerForUniversalSegmentation,
+ Mask2FormerModel,
+ Mask2FormerPreTrainedModel,
+ )
+ from .models.maskformer import (
+ MaskFormerForInstanceSegmentation,
+ MaskFormerModel,
+ MaskFormerPreTrainedModel,
+ MaskFormerSwinBackbone,
+ )
+ from .models.mbart import (
+ MBartForCausalLM,
+ MBartForConditionalGeneration,
+ MBartForQuestionAnswering,
+ MBartForSequenceClassification,
+ MBartModel,
+ MBartPreTrainedModel,
+ )
+ from .models.megatron_bert import (
+ MegatronBertForCausalLM,
+ MegatronBertForMaskedLM,
+ MegatronBertForMultipleChoice,
+ MegatronBertForNextSentencePrediction,
+ MegatronBertForPreTraining,
+ MegatronBertForQuestionAnswering,
+ MegatronBertForSequenceClassification,
+ MegatronBertForTokenClassification,
+ MegatronBertModel,
+ MegatronBertPreTrainedModel,
+ )
+ from .models.mgp_str import (
+ MgpstrForSceneTextRecognition,
+ MgpstrModel,
+ MgpstrPreTrainedModel,
+ )
+ from .models.mimi import (
+ MimiModel,
+ MimiPreTrainedModel,
+ )
+ from .models.mistral import (
+ MistralForCausalLM,
+ MistralForQuestionAnswering,
+ MistralForSequenceClassification,
+ MistralForTokenClassification,
+ MistralModel,
+ MistralPreTrainedModel,
+ )
+ from .models.mixtral import (
+ MixtralForCausalLM,
+ MixtralForQuestionAnswering,
+ MixtralForSequenceClassification,
+ MixtralForTokenClassification,
+ MixtralModel,
+ MixtralPreTrainedModel,
+ )
+ from .models.mllama import (
+ MllamaForCausalLM,
+ MllamaForConditionalGeneration,
+ MllamaPreTrainedModel,
+ MllamaProcessor,
+ MllamaTextModel,
+ MllamaVisionModel,
+ )
+ from .models.mobilebert import (
+ MobileBertForMaskedLM,
+ MobileBertForMultipleChoice,
+ MobileBertForNextSentencePrediction,
+ MobileBertForPreTraining,
+ MobileBertForQuestionAnswering,
+ MobileBertForSequenceClassification,
+ MobileBertForTokenClassification,
+ MobileBertModel,
+ MobileBertPreTrainedModel,
+ load_tf_weights_in_mobilebert,
+ )
+ from .models.mobilenet_v1 import (
+ MobileNetV1ForImageClassification,
+ MobileNetV1Model,
+ MobileNetV1PreTrainedModel,
+ load_tf_weights_in_mobilenet_v1,
+ )
+ from .models.mobilenet_v2 import (
+ MobileNetV2ForImageClassification,
+ MobileNetV2ForSemanticSegmentation,
+ MobileNetV2Model,
+ MobileNetV2PreTrainedModel,
+ load_tf_weights_in_mobilenet_v2,
+ )
+ from .models.mobilevit import (
+ MobileViTForImageClassification,
+ MobileViTForSemanticSegmentation,
+ MobileViTModel,
+ MobileViTPreTrainedModel,
+ )
+ from .models.mobilevitv2 import (
+ MobileViTV2ForImageClassification,
+ MobileViTV2ForSemanticSegmentation,
+ MobileViTV2Model,
+ MobileViTV2PreTrainedModel,
+ )
+ from .models.modernbert import (
+ ModernBertForMaskedLM,
+ ModernBertForSequenceClassification,
+ ModernBertForTokenClassification,
+ ModernBertModel,
+ ModernBertPreTrainedModel,
+ )
+ from .models.moonshine import (
+ MoonshineForConditionalGeneration,
+ MoonshineModel,
+ MoonshinePreTrainedModel,
+ )
+ from .models.moshi import (
+ MoshiForCausalLM,
+ MoshiForConditionalGeneration,
+ MoshiModel,
+ MoshiPreTrainedModel,
+ )
+ from .models.mpnet import (
+ MPNetForMaskedLM,
+ MPNetForMultipleChoice,
+ MPNetForQuestionAnswering,
+ MPNetForSequenceClassification,
+ MPNetForTokenClassification,
+ MPNetModel,
+ MPNetPreTrainedModel,
+ )
+ from .models.mpt import (
+ MptForCausalLM,
+ MptForQuestionAnswering,
+ MptForSequenceClassification,
+ MptForTokenClassification,
+ MptModel,
+ MptPreTrainedModel,
+ )
+ from .models.mra import (
+ MraForMaskedLM,
+ MraForMultipleChoice,
+ MraForQuestionAnswering,
+ MraForSequenceClassification,
+ MraForTokenClassification,
+ MraModel,
+ MraPreTrainedModel,
+ )
+ from .models.mt5 import (
+ MT5EncoderModel,
+ MT5ForConditionalGeneration,
+ MT5ForQuestionAnswering,
+ MT5ForSequenceClassification,
+ MT5ForTokenClassification,
+ MT5Model,
+ MT5PreTrainedModel,
+ )
+ from .models.musicgen import (
+ MusicgenForCausalLM,
+ MusicgenForConditionalGeneration,
+ MusicgenModel,
+ MusicgenPreTrainedModel,
+ MusicgenProcessor,
+ )
+ from .models.musicgen_melody import (
+ MusicgenMelodyForCausalLM,
+ MusicgenMelodyForConditionalGeneration,
+ MusicgenMelodyModel,
+ MusicgenMelodyPreTrainedModel,
+ )
+ from .models.mvp import (
+ MvpForCausalLM,
+ MvpForConditionalGeneration,
+ MvpForQuestionAnswering,
+ MvpForSequenceClassification,
+ MvpModel,
+ MvpPreTrainedModel,
+ )
+ from .models.nemotron import (
+ NemotronForCausalLM,
+ NemotronForQuestionAnswering,
+ NemotronForSequenceClassification,
+ NemotronForTokenClassification,
+ NemotronModel,
+ NemotronPreTrainedModel,
+ )
+ from .models.nllb_moe import (
+ NllbMoeForConditionalGeneration,
+ NllbMoeModel,
+ NllbMoePreTrainedModel,
+ NllbMoeSparseMLP,
+ NllbMoeTop2Router,
+ )
+ from .models.nystromformer import (
+ NystromformerForMaskedLM,
+ NystromformerForMultipleChoice,
+ NystromformerForQuestionAnswering,
+ NystromformerForSequenceClassification,
+ NystromformerForTokenClassification,
+ NystromformerModel,
+ NystromformerPreTrainedModel,
+ )
+ from .models.olmo import (
+ OlmoForCausalLM,
+ OlmoModel,
+ OlmoPreTrainedModel,
+ )
+ from .models.olmo2 import (
+ Olmo2ForCausalLM,
+ Olmo2Model,
+ Olmo2PreTrainedModel,
+ )
+ from .models.olmoe import (
+ OlmoeForCausalLM,
+ OlmoeModel,
+ OlmoePreTrainedModel,
+ )
+ from .models.omdet_turbo import (
+ OmDetTurboForObjectDetection,
+ OmDetTurboPreTrainedModel,
+ )
+ from .models.oneformer import (
+ OneFormerForUniversalSegmentation,
+ OneFormerModel,
+ OneFormerPreTrainedModel,
+ )
+ from .models.openai import (
+ OpenAIGPTDoubleHeadsModel,
+ OpenAIGPTForSequenceClassification,
+ OpenAIGPTLMHeadModel,
+ OpenAIGPTModel,
+ OpenAIGPTPreTrainedModel,
+ load_tf_weights_in_openai_gpt,
+ )
+ from .models.opt import (
+ OPTForCausalLM,
+ OPTForQuestionAnswering,
+ OPTForSequenceClassification,
+ OPTModel,
+ OPTPreTrainedModel,
+ )
+ from .models.owlv2 import (
+ Owlv2ForObjectDetection,
+ Owlv2Model,
+ Owlv2PreTrainedModel,
+ Owlv2TextModel,
+ Owlv2VisionModel,
+ )
+ from .models.owlvit import (
+ OwlViTForObjectDetection,
+ OwlViTModel,
+ OwlViTPreTrainedModel,
+ OwlViTTextModel,
+ OwlViTVisionModel,
+ )
+ from .models.paligemma import (
+ PaliGemmaForConditionalGeneration,
+ PaliGemmaPreTrainedModel,
+ PaliGemmaProcessor,
+ )
+ from .models.patchtsmixer import (
+ PatchTSMixerForPrediction,
+ PatchTSMixerForPretraining,
+ PatchTSMixerForRegression,
+ PatchTSMixerForTimeSeriesClassification,
+ PatchTSMixerModel,
+ PatchTSMixerPreTrainedModel,
+ )
+ from .models.patchtst import (
+ PatchTSTForClassification,
+ PatchTSTForPrediction,
+ PatchTSTForPretraining,
+ PatchTSTForRegression,
+ PatchTSTModel,
+ PatchTSTPreTrainedModel,
+ )
+ from .models.pegasus import (
+ PegasusForCausalLM,
+ PegasusForConditionalGeneration,
+ PegasusModel,
+ PegasusPreTrainedModel,
+ )
+ from .models.pegasus_x import (
+ PegasusXForConditionalGeneration,
+ PegasusXModel,
+ PegasusXPreTrainedModel,
+ )
+ from .models.perceiver import (
+ PerceiverForImageClassificationConvProcessing,
+ PerceiverForImageClassificationFourier,
+ PerceiverForImageClassificationLearned,
+ PerceiverForMaskedLM,
+ PerceiverForMultimodalAutoencoding,
+ PerceiverForOpticalFlow,
+ PerceiverForSequenceClassification,
+ PerceiverModel,
+ PerceiverPreTrainedModel,
+ )
+ from .models.persimmon import (
+ PersimmonForCausalLM,
+ PersimmonForSequenceClassification,
+ PersimmonForTokenClassification,
+ PersimmonModel,
+ PersimmonPreTrainedModel,
+ )
+ from .models.phi import (
+ PhiForCausalLM,
+ PhiForSequenceClassification,
+ PhiForTokenClassification,
+ PhiModel,
+ PhiPreTrainedModel,
+ )
+ from .models.phi3 import (
+ Phi3ForCausalLM,
+ Phi3ForSequenceClassification,
+ Phi3ForTokenClassification,
+ Phi3Model,
+ Phi3PreTrainedModel,
+ )
+ from .models.phimoe import (
+ PhimoeForCausalLM,
+ PhimoeForSequenceClassification,
+ PhimoeModel,
+ PhimoePreTrainedModel,
+ )
+ from .models.pix2struct import (
+ Pix2StructForConditionalGeneration,
+ Pix2StructPreTrainedModel,
+ Pix2StructTextModel,
+ Pix2StructVisionModel,
+ )
+ from .models.pixtral import (
+ PixtralPreTrainedModel,
+ PixtralVisionModel,
+ )
+ from .models.plbart import (
+ PLBartForCausalLM,
+ PLBartForConditionalGeneration,
+ PLBartForSequenceClassification,
+ PLBartModel,
+ PLBartPreTrainedModel,
+ )
+ from .models.poolformer import (
+ PoolFormerForImageClassification,
+ PoolFormerModel,
+ PoolFormerPreTrainedModel,
+ )
+ from .models.pop2piano import (
+ Pop2PianoForConditionalGeneration,
+ Pop2PianoPreTrainedModel,
+ )
+ from .models.prophetnet import (
+ ProphetNetDecoder,
+ ProphetNetEncoder,
+ ProphetNetForCausalLM,
+ ProphetNetForConditionalGeneration,
+ ProphetNetModel,
+ ProphetNetPreTrainedModel,
+ )
+ from .models.pvt import (
+ PvtForImageClassification,
+ PvtModel,
+ PvtPreTrainedModel,
+ )
+ from .models.pvt_v2 import (
+ PvtV2Backbone,
+ PvtV2ForImageClassification,
+ PvtV2Model,
+ PvtV2PreTrainedModel,
+ )
+ from .models.qwen2 import (
+ Qwen2ForCausalLM,
+ Qwen2ForQuestionAnswering,
+ Qwen2ForSequenceClassification,
+ Qwen2ForTokenClassification,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ )
+ from .models.qwen2_audio import (
+ Qwen2AudioEncoder,
+ Qwen2AudioForConditionalGeneration,
+ Qwen2AudioPreTrainedModel,
+ )
+ from .models.qwen2_moe import (
+ Qwen2MoeForCausalLM,
+ Qwen2MoeForQuestionAnswering,
+ Qwen2MoeForSequenceClassification,
+ Qwen2MoeForTokenClassification,
+ Qwen2MoeModel,
+ Qwen2MoePreTrainedModel,
+ )
+ from .models.qwen2_vl import (
+ Qwen2VLForConditionalGeneration,
+ Qwen2VLModel,
+ Qwen2VLPreTrainedModel,
+ )
+ from .models.rag import (
+ RagModel,
+ RagPreTrainedModel,
+ RagSequenceForGeneration,
+ RagTokenForGeneration,
+ )
+ from .models.recurrent_gemma import (
+ RecurrentGemmaForCausalLM,
+ RecurrentGemmaModel,
+ RecurrentGemmaPreTrainedModel,
+ )
+ from .models.reformer import (
+ ReformerForMaskedLM,
+ ReformerForQuestionAnswering,
+ ReformerForSequenceClassification,
+ ReformerModel,
+ ReformerModelWithLMHead,
+ ReformerPreTrainedModel,
+ )
+ from .models.regnet import (
+ RegNetForImageClassification,
+ RegNetModel,
+ RegNetPreTrainedModel,
+ )
+ from .models.rembert import (
+ RemBertForCausalLM,
+ RemBertForMaskedLM,
+ RemBertForMultipleChoice,
+ RemBertForQuestionAnswering,
+ RemBertForSequenceClassification,
+ RemBertForTokenClassification,
+ RemBertModel,
+ RemBertPreTrainedModel,
+ load_tf_weights_in_rembert,
+ )
+ from .models.resnet import (
+ ResNetBackbone,
+ ResNetForImageClassification,
+ ResNetModel,
+ ResNetPreTrainedModel,
+ )
+ from .models.roberta import (
+ RobertaForCausalLM,
+ RobertaForMaskedLM,
+ RobertaForMultipleChoice,
+ RobertaForQuestionAnswering,
+ RobertaForSequenceClassification,
+ RobertaForTokenClassification,
+ RobertaModel,
+ RobertaPreTrainedModel,
+ )
+ from .models.roberta_prelayernorm import (
+ RobertaPreLayerNormForCausalLM,
+ RobertaPreLayerNormForMaskedLM,
+ RobertaPreLayerNormForMultipleChoice,
+ RobertaPreLayerNormForQuestionAnswering,
+ RobertaPreLayerNormForSequenceClassification,
+ RobertaPreLayerNormForTokenClassification,
+ RobertaPreLayerNormModel,
+ RobertaPreLayerNormPreTrainedModel,
+ )
+ from .models.roc_bert import (
+ RoCBertForCausalLM,
+ RoCBertForMaskedLM,
+ RoCBertForMultipleChoice,
+ RoCBertForPreTraining,
+ RoCBertForQuestionAnswering,
+ RoCBertForSequenceClassification,
+ RoCBertForTokenClassification,
+ RoCBertModel,
+ RoCBertPreTrainedModel,
+ load_tf_weights_in_roc_bert,
+ )
+ from .models.roformer import (
+ RoFormerForCausalLM,
+ RoFormerForMaskedLM,
+ RoFormerForMultipleChoice,
+ RoFormerForQuestionAnswering,
+ RoFormerForSequenceClassification,
+ RoFormerForTokenClassification,
+ RoFormerModel,
+ RoFormerPreTrainedModel,
+ load_tf_weights_in_roformer,
+ )
+ from .models.rt_detr import (
+ RTDetrForObjectDetection,
+ RTDetrModel,
+ RTDetrPreTrainedModel,
+ RTDetrResNetBackbone,
+ RTDetrResNetPreTrainedModel,
+ )
+ from .models.rwkv import (
+ RwkvForCausalLM,
+ RwkvModel,
+ RwkvPreTrainedModel,
+ )
+ from .models.sam import (
+ SamModel,
+ SamPreTrainedModel,
+ )
+ from .models.seamless_m4t import (
+ SeamlessM4TCodeHifiGan,
+ SeamlessM4TForSpeechToSpeech,
+ SeamlessM4TForSpeechToText,
+ SeamlessM4TForTextToSpeech,
+ SeamlessM4TForTextToText,
+ SeamlessM4THifiGan,
+ SeamlessM4TModel,
+ SeamlessM4TPreTrainedModel,
+ SeamlessM4TTextToUnitForConditionalGeneration,
+ SeamlessM4TTextToUnitModel,
+ )
+ from .models.seamless_m4t_v2 import (
+ SeamlessM4Tv2ForSpeechToSpeech,
+ SeamlessM4Tv2ForSpeechToText,
+ SeamlessM4Tv2ForTextToSpeech,
+ SeamlessM4Tv2ForTextToText,
+ SeamlessM4Tv2Model,
+ SeamlessM4Tv2PreTrainedModel,
+ )
+ from .models.segformer import (
+ SegformerDecodeHead,
+ SegformerForImageClassification,
+ SegformerForSemanticSegmentation,
+ SegformerModel,
+ SegformerPreTrainedModel,
+ )
+ from .models.seggpt import (
+ SegGptForImageSegmentation,
+ SegGptModel,
+ SegGptPreTrainedModel,
+ )
+ from .models.sew import (
+ SEWForCTC,
+ SEWForSequenceClassification,
+ SEWModel,
+ SEWPreTrainedModel,
+ )
+ from .models.sew_d import (
+ SEWDForCTC,
+ SEWDForSequenceClassification,
+ SEWDModel,
+ SEWDPreTrainedModel,
+ )
+ from .models.siglip import (
+ SiglipForImageClassification,
+ SiglipModel,
+ SiglipPreTrainedModel,
+ SiglipTextModel,
+ SiglipVisionModel,
+ )
+ from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
+ from .models.speech_to_text import (
+ Speech2TextForConditionalGeneration,
+ Speech2TextModel,
+ Speech2TextPreTrainedModel,
+ )
+ from .models.speecht5 import (
+ SpeechT5ForSpeechToSpeech,
+ SpeechT5ForSpeechToText,
+ SpeechT5ForTextToSpeech,
+ SpeechT5HifiGan,
+ SpeechT5Model,
+ SpeechT5PreTrainedModel,
+ )
+ from .models.splinter import (
+ SplinterForPreTraining,
+ SplinterForQuestionAnswering,
+ SplinterModel,
+ SplinterPreTrainedModel,
+ )
+ from .models.squeezebert import (
+ SqueezeBertForMaskedLM,
+ SqueezeBertForMultipleChoice,
+ SqueezeBertForQuestionAnswering,
+ SqueezeBertForSequenceClassification,
+ SqueezeBertForTokenClassification,
+ SqueezeBertModel,
+ SqueezeBertPreTrainedModel,
+ )
+ from .models.stablelm import (
+ StableLmForCausalLM,
+ StableLmForSequenceClassification,
+ StableLmForTokenClassification,
+ StableLmModel,
+ StableLmPreTrainedModel,
+ )
+ from .models.starcoder2 import (
+ Starcoder2ForCausalLM,
+ Starcoder2ForSequenceClassification,
+ Starcoder2ForTokenClassification,
+ Starcoder2Model,
+ Starcoder2PreTrainedModel,
+ )
+ from .models.superpoint import (
+ SuperPointForKeypointDetection,
+ SuperPointPreTrainedModel,
+ )
+ from .models.swiftformer import (
+ SwiftFormerForImageClassification,
+ SwiftFormerModel,
+ SwiftFormerPreTrainedModel,
+ )
+ from .models.swin import (
+ SwinBackbone,
+ SwinForImageClassification,
+ SwinForMaskedImageModeling,
+ SwinModel,
+ SwinPreTrainedModel,
+ )
+ from .models.swin2sr import (
+ Swin2SRForImageSuperResolution,
+ Swin2SRModel,
+ Swin2SRPreTrainedModel,
+ )
+ from .models.swinv2 import (
+ Swinv2Backbone,
+ Swinv2ForImageClassification,
+ Swinv2ForMaskedImageModeling,
+ Swinv2Model,
+ Swinv2PreTrainedModel,
+ )
+ from .models.switch_transformers import (
+ SwitchTransformersEncoderModel,
+ SwitchTransformersForConditionalGeneration,
+ SwitchTransformersModel,
+ SwitchTransformersPreTrainedModel,
+ SwitchTransformersSparseMLP,
+ SwitchTransformersTop1Router,
+ )
+ from .models.t5 import (
+ T5EncoderModel,
+ T5ForConditionalGeneration,
+ T5ForQuestionAnswering,
+ T5ForSequenceClassification,
+ T5ForTokenClassification,
+ T5Model,
+ T5PreTrainedModel,
+ load_tf_weights_in_t5,
+ )
+ from .models.table_transformer import (
+ TableTransformerForObjectDetection,
+ TableTransformerModel,
+ TableTransformerPreTrainedModel,
+ )
+ from .models.tapas import (
+ TapasForMaskedLM,
+ TapasForQuestionAnswering,
+ TapasForSequenceClassification,
+ TapasModel,
+ TapasPreTrainedModel,
+ load_tf_weights_in_tapas,
+ )
+ from .models.textnet import (
+ TextNetBackbone,
+ TextNetForImageClassification,
+ TextNetModel,
+ TextNetPreTrainedModel,
+ )
+ from .models.time_series_transformer import (
+ TimeSeriesTransformerForPrediction,
+ TimeSeriesTransformerModel,
+ TimeSeriesTransformerPreTrainedModel,
+ )
+ from .models.timesformer import (
+ TimesformerForVideoClassification,
+ TimesformerModel,
+ TimesformerPreTrainedModel,
+ )
+ from .models.timm_backbone import TimmBackbone
+ from .models.timm_wrapper import (
+ TimmWrapperForImageClassification,
+ TimmWrapperModel,
+ TimmWrapperPreTrainedModel,
+ )
+ from .models.trocr import (
+ TrOCRForCausalLM,
+ TrOCRPreTrainedModel,
+ )
+ from .models.tvp import (
+ TvpForVideoGrounding,
+ TvpModel,
+ TvpPreTrainedModel,
+ )
+ from .models.udop import (
+ UdopEncoderModel,
+ UdopForConditionalGeneration,
+ UdopModel,
+ UdopPreTrainedModel,
+ )
+ from .models.umt5 import (
+ UMT5EncoderModel,
+ UMT5ForConditionalGeneration,
+ UMT5ForQuestionAnswering,
+ UMT5ForSequenceClassification,
+ UMT5ForTokenClassification,
+ UMT5Model,
+ UMT5PreTrainedModel,
+ )
+ from .models.unispeech import (
+ UniSpeechForCTC,
+ UniSpeechForPreTraining,
+ UniSpeechForSequenceClassification,
+ UniSpeechModel,
+ UniSpeechPreTrainedModel,
+ )
+ from .models.unispeech_sat import (
+ UniSpeechSatForAudioFrameClassification,
+ UniSpeechSatForCTC,
+ UniSpeechSatForPreTraining,
+ UniSpeechSatForSequenceClassification,
+ UniSpeechSatForXVector,
+ UniSpeechSatModel,
+ UniSpeechSatPreTrainedModel,
+ )
+ from .models.univnet import UnivNetModel
+ from .models.upernet import (
+ UperNetForSemanticSegmentation,
+ UperNetPreTrainedModel,
+ )
+ from .models.video_llava import (
+ VideoLlavaForConditionalGeneration,
+ VideoLlavaPreTrainedModel,
+ VideoLlavaProcessor,
+ )
+ from .models.videomae import (
+ VideoMAEForPreTraining,
+ VideoMAEForVideoClassification,
+ VideoMAEModel,
+ VideoMAEPreTrainedModel,
+ )
+ from .models.vilt import (
+ ViltForImageAndTextRetrieval,
+ ViltForImagesAndTextClassification,
+ ViltForMaskedLM,
+ ViltForQuestionAnswering,
+ ViltForTokenClassification,
+ ViltModel,
+ ViltPreTrainedModel,
+ )
+ from .models.vipllava import (
+ VipLlavaForConditionalGeneration,
+ VipLlavaPreTrainedModel,
+ )
+ from .models.vision_encoder_decoder import VisionEncoderDecoderModel
+ from .models.vision_text_dual_encoder import VisionTextDualEncoderModel
+ from .models.visual_bert import (
+ VisualBertForMultipleChoice,
+ VisualBertForPreTraining,
+ VisualBertForQuestionAnswering,
+ VisualBertForRegionToPhraseAlignment,
+ VisualBertForVisualReasoning,
+ VisualBertModel,
+ VisualBertPreTrainedModel,
+ )
+ from .models.vit import (
+ ViTForImageClassification,
+ ViTForMaskedImageModeling,
+ ViTModel,
+ ViTPreTrainedModel,
+ )
+ from .models.vit_mae import (
+ ViTMAEForPreTraining,
+ ViTMAEModel,
+ ViTMAEPreTrainedModel,
+ )
+ from .models.vit_msn import (
+ ViTMSNForImageClassification,
+ ViTMSNModel,
+ ViTMSNPreTrainedModel,
+ )
+ from .models.vitdet import (
+ VitDetBackbone,
+ VitDetModel,
+ VitDetPreTrainedModel,
+ )
+ from .models.vitmatte import (
+ VitMatteForImageMatting,
+ VitMattePreTrainedModel,
+ )
+ from .models.vitpose import (
+ VitPoseForPoseEstimation,
+ VitPosePreTrainedModel,
+ )
+ from .models.vitpose_backbone import VitPoseBackbone, VitPoseBackbonePreTrainedModel
+ from .models.vits import (
+ VitsModel,
+ VitsPreTrainedModel,
+ )
+ from .models.vivit import (
+ VivitForVideoClassification,
+ VivitModel,
+ VivitPreTrainedModel,
+ )
+ from .models.wav2vec2 import (
+ Wav2Vec2ForAudioFrameClassification,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForMaskedLM,
+ Wav2Vec2ForPreTraining,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2ForXVector,
+ Wav2Vec2Model,
+ Wav2Vec2PreTrainedModel,
+ )
+ from .models.wav2vec2_bert import (
+ Wav2Vec2BertForAudioFrameClassification,
+ Wav2Vec2BertForCTC,
+ Wav2Vec2BertForSequenceClassification,
+ Wav2Vec2BertForXVector,
+ Wav2Vec2BertModel,
+ Wav2Vec2BertPreTrainedModel,
+ )
+ from .models.wav2vec2_conformer import (
+ Wav2Vec2ConformerForAudioFrameClassification,
+ Wav2Vec2ConformerForCTC,
+ Wav2Vec2ConformerForPreTraining,
+ Wav2Vec2ConformerForSequenceClassification,
+ Wav2Vec2ConformerForXVector,
+ Wav2Vec2ConformerModel,
+ Wav2Vec2ConformerPreTrainedModel,
+ )
+ from .models.wavlm import (
+ WavLMForAudioFrameClassification,
+ WavLMForCTC,
+ WavLMForSequenceClassification,
+ WavLMForXVector,
+ WavLMModel,
+ WavLMPreTrainedModel,
+ )
+ from .models.whisper import (
+ WhisperForAudioClassification,
+ WhisperForCausalLM,
+ WhisperForConditionalGeneration,
+ WhisperModel,
+ WhisperPreTrainedModel,
+ )
+ from .models.x_clip import (
+ XCLIPModel,
+ XCLIPPreTrainedModel,
+ XCLIPTextModel,
+ XCLIPVisionModel,
+ )
+ from .models.xglm import (
+ XGLMForCausalLM,
+ XGLMModel,
+ XGLMPreTrainedModel,
+ )
+ from .models.xlm import (
+ XLMForMultipleChoice,
+ XLMForQuestionAnswering,
+ XLMForQuestionAnsweringSimple,
+ XLMForSequenceClassification,
+ XLMForTokenClassification,
+ XLMModel,
+ XLMPreTrainedModel,
+ XLMWithLMHeadModel,
+ )
+ from .models.xlm_roberta import (
+ XLMRobertaForCausalLM,
+ XLMRobertaForMaskedLM,
+ XLMRobertaForMultipleChoice,
+ XLMRobertaForQuestionAnswering,
+ XLMRobertaForSequenceClassification,
+ XLMRobertaForTokenClassification,
+ XLMRobertaModel,
+ XLMRobertaPreTrainedModel,
+ )
+ from .models.xlm_roberta_xl import (
+ XLMRobertaXLForCausalLM,
+ XLMRobertaXLForMaskedLM,
+ XLMRobertaXLForMultipleChoice,
+ XLMRobertaXLForQuestionAnswering,
+ XLMRobertaXLForSequenceClassification,
+ XLMRobertaXLForTokenClassification,
+ XLMRobertaXLModel,
+ XLMRobertaXLPreTrainedModel,
+ )
+ from .models.xlnet import (
+ XLNetForMultipleChoice,
+ XLNetForQuestionAnswering,
+ XLNetForQuestionAnsweringSimple,
+ XLNetForSequenceClassification,
+ XLNetForTokenClassification,
+ XLNetLMHeadModel,
+ XLNetModel,
+ XLNetPreTrainedModel,
+ load_tf_weights_in_xlnet,
+ )
+ from .models.xmod import (
+ XmodForCausalLM,
+ XmodForMaskedLM,
+ XmodForMultipleChoice,
+ XmodForQuestionAnswering,
+ XmodForSequenceClassification,
+ XmodForTokenClassification,
+ XmodModel,
+ XmodPreTrainedModel,
+ )
+ from .models.yolos import (
+ YolosForObjectDetection,
+ YolosModel,
+ YolosPreTrainedModel,
+ )
+ from .models.yoso import (
+ YosoForMaskedLM,
+ YosoForMultipleChoice,
+ YosoForQuestionAnswering,
+ YosoForSequenceClassification,
+ YosoForTokenClassification,
+ YosoModel,
+ YosoPreTrainedModel,
+ )
+ from .models.zamba import (
+ ZambaForCausalLM,
+ ZambaForSequenceClassification,
+ ZambaModel,
+ ZambaPreTrainedModel,
+ )
+ from .models.zoedepth import (
+ ZoeDepthForDepthEstimation,
+ ZoeDepthPreTrainedModel,
+ )
+
+ # Optimization
+ from .optimization import (
+ Adafactor,
+ AdamW,
+ get_constant_schedule,
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ get_cosine_with_hard_restarts_schedule_with_warmup,
+ get_inverse_sqrt_schedule,
+ get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
+ get_scheduler,
+ get_wsd_schedule,
+ )
+ from .pytorch_utils import Conv1D, apply_chunking_to_forward, prune_layer
+
+ # Trainer
+ from .trainer import Trainer
+ from .trainer_pt_utils import torch_distributed_zero_first
+ from .trainer_seq2seq import Seq2SeqTrainer
+
+ # TensorFlow
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ # Import the same objects as dummies to get them in the namespace.
+ # They will raise an import error if the user tries to instantiate / use them.
+ from .utils.dummy_tf_objects import *
+ else:
+ from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
+
+ # Benchmarks
+ from .benchmark.benchmark_tf import TensorFlowBenchmark
+ from .generation import (
+ TFForcedBOSTokenLogitsProcessor,
+ TFForcedEOSTokenLogitsProcessor,
+ TFForceTokensLogitsProcessor,
+ TFGenerationMixin,
+ TFLogitsProcessor,
+ TFLogitsProcessorList,
+ TFLogitsWarper,
+ TFMinLengthLogitsProcessor,
+ TFNoBadWordsLogitsProcessor,
+ TFNoRepeatNGramLogitsProcessor,
+ TFRepetitionPenaltyLogitsProcessor,
+ TFSuppressTokensAtBeginLogitsProcessor,
+ TFSuppressTokensLogitsProcessor,
+ TFTemperatureLogitsWarper,
+ TFTopKLogitsWarper,
+ TFTopPLogitsWarper,
+ )
+ from .keras_callbacks import KerasMetricCallback, PushToHubCallback
+ from .modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceSummary,
+ TFSharedEmbeddings,
+ shape_list,
+ )
+
+ # TensorFlow model imports
+ from .models.albert import (
+ TFAlbertForMaskedLM,
+ TFAlbertForMultipleChoice,
+ TFAlbertForPreTraining,
+ TFAlbertForQuestionAnswering,
+ TFAlbertForSequenceClassification,
+ TFAlbertForTokenClassification,
+ TFAlbertMainLayer,
+ TFAlbertModel,
+ TFAlbertPreTrainedModel,
+ )
+ from .models.auto import (
+ TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+ TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING,
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_MASK_GENERATION_MAPPING,
+ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
+ TF_MODEL_FOR_MASKED_LM_MAPPING,
+ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+ TF_MODEL_FOR_PRETRAINING_MAPPING,
+ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
+ TF_MODEL_FOR_TEXT_ENCODING_MAPPING,
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
+ TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING,
+ TF_MODEL_MAPPING,
+ TF_MODEL_WITH_LM_HEAD_MAPPING,
+ TFAutoModel,
+ TFAutoModelForAudioClassification,
+ TFAutoModelForCausalLM,
+ TFAutoModelForDocumentQuestionAnswering,
+ TFAutoModelForImageClassification,
+ TFAutoModelForMaskedImageModeling,
+ TFAutoModelForMaskedLM,
+ TFAutoModelForMaskGeneration,
+ TFAutoModelForMultipleChoice,
+ TFAutoModelForNextSentencePrediction,
+ TFAutoModelForPreTraining,
+ TFAutoModelForQuestionAnswering,
+ TFAutoModelForSemanticSegmentation,
+ TFAutoModelForSeq2SeqLM,
+ TFAutoModelForSequenceClassification,
+ TFAutoModelForSpeechSeq2Seq,
+ TFAutoModelForTableQuestionAnswering,
+ TFAutoModelForTextEncoding,
+ TFAutoModelForTokenClassification,
+ TFAutoModelForVision2Seq,
+ TFAutoModelForZeroShotImageClassification,
+ TFAutoModelWithLMHead,
+ )
+ from .models.bart import (
+ TFBartForConditionalGeneration,
+ TFBartForSequenceClassification,
+ TFBartModel,
+ TFBartPretrainedModel,
+ )
+ from .models.bert import (
+ TFBertForMaskedLM,
+ TFBertForMultipleChoice,
+ TFBertForNextSentencePrediction,
+ TFBertForPreTraining,
+ TFBertForQuestionAnswering,
+ TFBertForSequenceClassification,
+ TFBertForTokenClassification,
+ TFBertLMHeadModel,
+ TFBertMainLayer,
+ TFBertModel,
+ TFBertPreTrainedModel,
+ )
+ from .models.blenderbot import (
+ TFBlenderbotForConditionalGeneration,
+ TFBlenderbotModel,
+ TFBlenderbotPreTrainedModel,
+ )
+ from .models.blenderbot_small import (
+ TFBlenderbotSmallForConditionalGeneration,
+ TFBlenderbotSmallModel,
+ TFBlenderbotSmallPreTrainedModel,
+ )
+ from .models.blip import (
+ TFBlipForConditionalGeneration,
+ TFBlipForImageTextRetrieval,
+ TFBlipForQuestionAnswering,
+ TFBlipModel,
+ TFBlipPreTrainedModel,
+ TFBlipTextModel,
+ TFBlipVisionModel,
+ )
+ from .models.camembert import (
+ TFCamembertForCausalLM,
+ TFCamembertForMaskedLM,
+ TFCamembertForMultipleChoice,
+ TFCamembertForQuestionAnswering,
+ TFCamembertForSequenceClassification,
+ TFCamembertForTokenClassification,
+ TFCamembertModel,
+ TFCamembertPreTrainedModel,
+ )
+ from .models.clip import (
+ TFCLIPModel,
+ TFCLIPPreTrainedModel,
+ TFCLIPTextModel,
+ TFCLIPVisionModel,
+ )
+ from .models.convbert import (
+ TFConvBertForMaskedLM,
+ TFConvBertForMultipleChoice,
+ TFConvBertForQuestionAnswering,
+ TFConvBertForSequenceClassification,
+ TFConvBertForTokenClassification,
+ TFConvBertModel,
+ TFConvBertPreTrainedModel,
+ )
+ from .models.convnext import (
+ TFConvNextForImageClassification,
+ TFConvNextModel,
+ TFConvNextPreTrainedModel,
+ )
+ from .models.convnextv2 import (
+ TFConvNextV2ForImageClassification,
+ TFConvNextV2Model,
+ TFConvNextV2PreTrainedModel,
+ )
+ from .models.ctrl import (
+ TFCTRLForSequenceClassification,
+ TFCTRLLMHeadModel,
+ TFCTRLModel,
+ TFCTRLPreTrainedModel,
+ )
+ from .models.cvt import (
+ TFCvtForImageClassification,
+ TFCvtModel,
+ TFCvtPreTrainedModel,
+ )
+ from .models.data2vec import (
+ TFData2VecVisionForImageClassification,
+ TFData2VecVisionForSemanticSegmentation,
+ TFData2VecVisionModel,
+ TFData2VecVisionPreTrainedModel,
+ )
+ from .models.deberta import (
+ TFDebertaForMaskedLM,
+ TFDebertaForQuestionAnswering,
+ TFDebertaForSequenceClassification,
+ TFDebertaForTokenClassification,
+ TFDebertaModel,
+ TFDebertaPreTrainedModel,
+ )
+ from .models.deberta_v2 import (
+ TFDebertaV2ForMaskedLM,
+ TFDebertaV2ForMultipleChoice,
+ TFDebertaV2ForQuestionAnswering,
+ TFDebertaV2ForSequenceClassification,
+ TFDebertaV2ForTokenClassification,
+ TFDebertaV2Model,
+ TFDebertaV2PreTrainedModel,
+ )
+ from .models.deit import (
+ TFDeiTForImageClassification,
+ TFDeiTForImageClassificationWithTeacher,
+ TFDeiTForMaskedImageModeling,
+ TFDeiTModel,
+ TFDeiTPreTrainedModel,
+ )
+ from .models.deprecated.efficientformer import (
+ TFEfficientFormerForImageClassification,
+ TFEfficientFormerForImageClassificationWithTeacher,
+ TFEfficientFormerModel,
+ TFEfficientFormerPreTrainedModel,
+ )
+ from .models.deprecated.transfo_xl import (
+ TFAdaptiveEmbedding,
+ TFTransfoXLForSequenceClassification,
+ TFTransfoXLLMHeadModel,
+ TFTransfoXLMainLayer,
+ TFTransfoXLModel,
+ TFTransfoXLPreTrainedModel,
+ )
+ from .models.distilbert import (
+ TFDistilBertForMaskedLM,
+ TFDistilBertForMultipleChoice,
+ TFDistilBertForQuestionAnswering,
+ TFDistilBertForSequenceClassification,
+ TFDistilBertForTokenClassification,
+ TFDistilBertMainLayer,
+ TFDistilBertModel,
+ TFDistilBertPreTrainedModel,
+ )
+ from .models.dpr import (
+ TFDPRContextEncoder,
+ TFDPRPretrainedContextEncoder,
+ TFDPRPretrainedQuestionEncoder,
+ TFDPRPretrainedReader,
+ TFDPRQuestionEncoder,
+ TFDPRReader,
+ )
+ from .models.electra import (
+ TFElectraForMaskedLM,
+ TFElectraForMultipleChoice,
+ TFElectraForPreTraining,
+ TFElectraForQuestionAnswering,
+ TFElectraForSequenceClassification,
+ TFElectraForTokenClassification,
+ TFElectraModel,
+ TFElectraPreTrainedModel,
+ )
+ from .models.encoder_decoder import TFEncoderDecoderModel
+ from .models.esm import (
+ TFEsmForMaskedLM,
+ TFEsmForSequenceClassification,
+ TFEsmForTokenClassification,
+ TFEsmModel,
+ TFEsmPreTrainedModel,
+ )
+ from .models.flaubert import (
+ TFFlaubertForMultipleChoice,
+ TFFlaubertForQuestionAnsweringSimple,
+ TFFlaubertForSequenceClassification,
+ TFFlaubertForTokenClassification,
+ TFFlaubertModel,
+ TFFlaubertPreTrainedModel,
+ TFFlaubertWithLMHeadModel,
+ )
+ from .models.funnel import (
+ TFFunnelBaseModel,
+ TFFunnelForMaskedLM,
+ TFFunnelForMultipleChoice,
+ TFFunnelForPreTraining,
+ TFFunnelForQuestionAnswering,
+ TFFunnelForSequenceClassification,
+ TFFunnelForTokenClassification,
+ TFFunnelModel,
+ TFFunnelPreTrainedModel,
+ )
+ from .models.gpt2 import (
+ TFGPT2DoubleHeadsModel,
+ TFGPT2ForSequenceClassification,
+ TFGPT2LMHeadModel,
+ TFGPT2MainLayer,
+ TFGPT2Model,
+ TFGPT2PreTrainedModel,
+ )
+ from .models.gptj import (
+ TFGPTJForCausalLM,
+ TFGPTJForQuestionAnswering,
+ TFGPTJForSequenceClassification,
+ TFGPTJModel,
+ TFGPTJPreTrainedModel,
+ )
+ from .models.groupvit import (
+ TFGroupViTModel,
+ TFGroupViTPreTrainedModel,
+ TFGroupViTTextModel,
+ TFGroupViTVisionModel,
+ )
+ from .models.hubert import (
+ TFHubertForCTC,
+ TFHubertModel,
+ TFHubertPreTrainedModel,
+ )
+ from .models.idefics import (
+ TFIdeficsForVisionText2Text,
+ TFIdeficsModel,
+ TFIdeficsPreTrainedModel,
+ )
+ from .models.layoutlm import (
+ TFLayoutLMForMaskedLM,
+ TFLayoutLMForQuestionAnswering,
+ TFLayoutLMForSequenceClassification,
+ TFLayoutLMForTokenClassification,
+ TFLayoutLMMainLayer,
+ TFLayoutLMModel,
+ TFLayoutLMPreTrainedModel,
+ )
+ from .models.layoutlmv3 import (
+ TFLayoutLMv3ForQuestionAnswering,
+ TFLayoutLMv3ForSequenceClassification,
+ TFLayoutLMv3ForTokenClassification,
+ TFLayoutLMv3Model,
+ TFLayoutLMv3PreTrainedModel,
+ )
+ from .models.led import (
+ TFLEDForConditionalGeneration,
+ TFLEDModel,
+ TFLEDPreTrainedModel,
+ )
+ from .models.longformer import (
+ TFLongformerForMaskedLM,
+ TFLongformerForMultipleChoice,
+ TFLongformerForQuestionAnswering,
+ TFLongformerForSequenceClassification,
+ TFLongformerForTokenClassification,
+ TFLongformerModel,
+ TFLongformerPreTrainedModel,
+ )
+ from .models.lxmert import (
+ TFLxmertForPreTraining,
+ TFLxmertMainLayer,
+ TFLxmertModel,
+ TFLxmertPreTrainedModel,
+ TFLxmertVisualFeatureEncoder,
+ )
+ from .models.marian import (
+ TFMarianModel,
+ TFMarianMTModel,
+ TFMarianPreTrainedModel,
+ )
+ from .models.mbart import (
+ TFMBartForConditionalGeneration,
+ TFMBartModel,
+ TFMBartPreTrainedModel,
+ )
+ from .models.mistral import (
+ TFMistralForCausalLM,
+ TFMistralForSequenceClassification,
+ TFMistralModel,
+ TFMistralPreTrainedModel,
+ )
+ from .models.mobilebert import (
+ TFMobileBertForMaskedLM,
+ TFMobileBertForMultipleChoice,
+ TFMobileBertForNextSentencePrediction,
+ TFMobileBertForPreTraining,
+ TFMobileBertForQuestionAnswering,
+ TFMobileBertForSequenceClassification,
+ TFMobileBertForTokenClassification,
+ TFMobileBertMainLayer,
+ TFMobileBertModel,
+ TFMobileBertPreTrainedModel,
+ )
+ from .models.mobilevit import (
+ TFMobileViTForImageClassification,
+ TFMobileViTForSemanticSegmentation,
+ TFMobileViTModel,
+ TFMobileViTPreTrainedModel,
+ )
+ from .models.mpnet import (
+ TFMPNetForMaskedLM,
+ TFMPNetForMultipleChoice,
+ TFMPNetForQuestionAnswering,
+ TFMPNetForSequenceClassification,
+ TFMPNetForTokenClassification,
+ TFMPNetMainLayer,
+ TFMPNetModel,
+ TFMPNetPreTrainedModel,
+ )
+ from .models.mt5 import (
+ TFMT5EncoderModel,
+ TFMT5ForConditionalGeneration,
+ TFMT5Model,
+ )
+ from .models.openai import (
+ TFOpenAIGPTDoubleHeadsModel,
+ TFOpenAIGPTForSequenceClassification,
+ TFOpenAIGPTLMHeadModel,
+ TFOpenAIGPTMainLayer,
+ TFOpenAIGPTModel,
+ TFOpenAIGPTPreTrainedModel,
+ )
+ from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
+ from .models.pegasus import (
+ TFPegasusForConditionalGeneration,
+ TFPegasusModel,
+ TFPegasusPreTrainedModel,
+ )
+ from .models.rag import (
+ TFRagModel,
+ TFRagPreTrainedModel,
+ TFRagSequenceForGeneration,
+ TFRagTokenForGeneration,
+ )
+ from .models.regnet import (
+ TFRegNetForImageClassification,
+ TFRegNetModel,
+ TFRegNetPreTrainedModel,
+ )
+ from .models.rembert import (
+ TFRemBertForCausalLM,
+ TFRemBertForMaskedLM,
+ TFRemBertForMultipleChoice,
+ TFRemBertForQuestionAnswering,
+ TFRemBertForSequenceClassification,
+ TFRemBertForTokenClassification,
+ TFRemBertModel,
+ TFRemBertPreTrainedModel,
+ )
+ from .models.resnet import (
+ TFResNetForImageClassification,
+ TFResNetModel,
+ TFResNetPreTrainedModel,
+ )
+ from .models.roberta import (
+ TFRobertaForCausalLM,
+ TFRobertaForMaskedLM,
+ TFRobertaForMultipleChoice,
+ TFRobertaForQuestionAnswering,
+ TFRobertaForSequenceClassification,
+ TFRobertaForTokenClassification,
+ TFRobertaMainLayer,
+ TFRobertaModel,
+ TFRobertaPreTrainedModel,
+ )
+ from .models.roberta_prelayernorm import (
+ TFRobertaPreLayerNormForCausalLM,
+ TFRobertaPreLayerNormForMaskedLM,
+ TFRobertaPreLayerNormForMultipleChoice,
+ TFRobertaPreLayerNormForQuestionAnswering,
+ TFRobertaPreLayerNormForSequenceClassification,
+ TFRobertaPreLayerNormForTokenClassification,
+ TFRobertaPreLayerNormMainLayer,
+ TFRobertaPreLayerNormModel,
+ TFRobertaPreLayerNormPreTrainedModel,
+ )
+ from .models.roformer import (
+ TFRoFormerForCausalLM,
+ TFRoFormerForMaskedLM,
+ TFRoFormerForMultipleChoice,
+ TFRoFormerForQuestionAnswering,
+ TFRoFormerForSequenceClassification,
+ TFRoFormerForTokenClassification,
+ TFRoFormerModel,
+ TFRoFormerPreTrainedModel,
+ )
+ from .models.sam import (
+ TFSamModel,
+ TFSamPreTrainedModel,
+ )
+ from .models.segformer import (
+ TFSegformerDecodeHead,
+ TFSegformerForImageClassification,
+ TFSegformerForSemanticSegmentation,
+ TFSegformerModel,
+ TFSegformerPreTrainedModel,
+ )
+ from .models.speech_to_text import (
+ TFSpeech2TextForConditionalGeneration,
+ TFSpeech2TextModel,
+ TFSpeech2TextPreTrainedModel,
+ )
+ from .models.swiftformer import (
+ TFSwiftFormerForImageClassification,
+ TFSwiftFormerModel,
+ TFSwiftFormerPreTrainedModel,
+ )
+ from .models.swin import (
+ TFSwinForImageClassification,
+ TFSwinForMaskedImageModeling,
+ TFSwinModel,
+ TFSwinPreTrainedModel,
+ )
+ from .models.t5 import (
+ TFT5EncoderModel,
+ TFT5ForConditionalGeneration,
+ TFT5Model,
+ TFT5PreTrainedModel,
+ )
+ from .models.tapas import (
+ TFTapasForMaskedLM,
+ TFTapasForQuestionAnswering,
+ TFTapasForSequenceClassification,
+ TFTapasModel,
+ TFTapasPreTrainedModel,
+ )
+ from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
+ from .models.vision_text_dual_encoder import TFVisionTextDualEncoderModel
+ from .models.vit import (
+ TFViTForImageClassification,
+ TFViTModel,
+ TFViTPreTrainedModel,
+ )
+ from .models.vit_mae import (
+ TFViTMAEForPreTraining,
+ TFViTMAEModel,
+ TFViTMAEPreTrainedModel,
+ )
+ from .models.wav2vec2 import (
+ TFWav2Vec2ForCTC,
+ TFWav2Vec2ForSequenceClassification,
+ TFWav2Vec2Model,
+ TFWav2Vec2PreTrainedModel,
+ )
+ from .models.whisper import (
+ TFWhisperForConditionalGeneration,
+ TFWhisperModel,
+ TFWhisperPreTrainedModel,
+ )
+ from .models.xglm import (
+ TFXGLMForCausalLM,
+ TFXGLMModel,
+ TFXGLMPreTrainedModel,
+ )
+ from .models.xlm import (
+ TFXLMForMultipleChoice,
+ TFXLMForQuestionAnsweringSimple,
+ TFXLMForSequenceClassification,
+ TFXLMForTokenClassification,
+ TFXLMMainLayer,
+ TFXLMModel,
+ TFXLMPreTrainedModel,
+ TFXLMWithLMHeadModel,
+ )
+ from .models.xlm_roberta import (
+ TFXLMRobertaForCausalLM,
+ TFXLMRobertaForMaskedLM,
+ TFXLMRobertaForMultipleChoice,
+ TFXLMRobertaForQuestionAnswering,
+ TFXLMRobertaForSequenceClassification,
+ TFXLMRobertaForTokenClassification,
+ TFXLMRobertaModel,
+ TFXLMRobertaPreTrainedModel,
+ )
+ from .models.xlnet import (
+ TFXLNetForMultipleChoice,
+ TFXLNetForQuestionAnsweringSimple,
+ TFXLNetForSequenceClassification,
+ TFXLNetForTokenClassification,
+ TFXLNetLMHeadModel,
+ TFXLNetMainLayer,
+ TFXLNetModel,
+ TFXLNetPreTrainedModel,
+ )
+
+ # Optimization
+ from .optimization_tf import (
+ AdamWeightDecay,
+ GradientAccumulator,
+ WarmUp,
+ create_optimizer,
+ )
+
+ try:
+ if not (
+ is_librosa_available()
+ and is_essentia_available()
+ and is_scipy_available()
+ and is_torch_available()
+ and is_pretty_midi_available()
+ ):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects import *
+ else:
+ from .models.pop2piano import (
+ Pop2PianoFeatureExtractor,
+ Pop2PianoProcessor,
+ Pop2PianoTokenizer,
+ )
+
+ try:
+ if not is_torchaudio_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_torchaudio_objects import *
+ else:
+ from .models.musicgen_melody import MusicgenMelodyFeatureExtractor, MusicgenMelodyProcessor
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ # Import the same objects as dummies to get them in the namespace.
+ # They will raise an import error if the user tries to instantiate / use them.
+ from .utils.dummy_flax_objects import *
+ else:
+ from .generation import (
+ FlaxForcedBOSTokenLogitsProcessor,
+ FlaxForcedEOSTokenLogitsProcessor,
+ FlaxForceTokensLogitsProcessor,
+ FlaxGenerationMixin,
+ FlaxLogitsProcessor,
+ FlaxLogitsProcessorList,
+ FlaxLogitsWarper,
+ FlaxMinLengthLogitsProcessor,
+ FlaxSuppressTokensAtBeginLogitsProcessor,
+ FlaxSuppressTokensLogitsProcessor,
+ FlaxTemperatureLogitsWarper,
+ FlaxTopKLogitsWarper,
+ FlaxTopPLogitsWarper,
+ FlaxWhisperTimeStampLogitsProcessor,
+ )
+ from .modeling_flax_utils import FlaxPreTrainedModel
+
+ # Flax model imports
+ from .models.albert import (
+ FlaxAlbertForMaskedLM,
+ FlaxAlbertForMultipleChoice,
+ FlaxAlbertForPreTraining,
+ FlaxAlbertForQuestionAnswering,
+ FlaxAlbertForSequenceClassification,
+ FlaxAlbertForTokenClassification,
+ FlaxAlbertModel,
+ FlaxAlbertPreTrainedModel,
+ )
+ from .models.auto import (
+ FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
+ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+ FLAX_MODEL_FOR_PRETRAINING_MAPPING,
+ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
+ FLAX_MODEL_MAPPING,
+ FlaxAutoModel,
+ FlaxAutoModelForCausalLM,
+ FlaxAutoModelForImageClassification,
+ FlaxAutoModelForMaskedLM,
+ FlaxAutoModelForMultipleChoice,
+ FlaxAutoModelForNextSentencePrediction,
+ FlaxAutoModelForPreTraining,
+ FlaxAutoModelForQuestionAnswering,
+ FlaxAutoModelForSeq2SeqLM,
+ FlaxAutoModelForSequenceClassification,
+ FlaxAutoModelForSpeechSeq2Seq,
+ FlaxAutoModelForTokenClassification,
+ FlaxAutoModelForVision2Seq,
+ )
+ from .models.bart import (
+ FlaxBartDecoderPreTrainedModel,
+ FlaxBartForCausalLM,
+ FlaxBartForConditionalGeneration,
+ FlaxBartForQuestionAnswering,
+ FlaxBartForSequenceClassification,
+ FlaxBartModel,
+ FlaxBartPreTrainedModel,
+ )
+ from .models.beit import (
+ FlaxBeitForImageClassification,
+ FlaxBeitForMaskedImageModeling,
+ FlaxBeitModel,
+ FlaxBeitPreTrainedModel,
+ )
+ from .models.bert import (
+ FlaxBertForCausalLM,
+ FlaxBertForMaskedLM,
+ FlaxBertForMultipleChoice,
+ FlaxBertForNextSentencePrediction,
+ FlaxBertForPreTraining,
+ FlaxBertForQuestionAnswering,
+ FlaxBertForSequenceClassification,
+ FlaxBertForTokenClassification,
+ FlaxBertModel,
+ FlaxBertPreTrainedModel,
+ )
+ from .models.big_bird import (
+ FlaxBigBirdForCausalLM,
+ FlaxBigBirdForMaskedLM,
+ FlaxBigBirdForMultipleChoice,
+ FlaxBigBirdForPreTraining,
+ FlaxBigBirdForQuestionAnswering,
+ FlaxBigBirdForSequenceClassification,
+ FlaxBigBirdForTokenClassification,
+ FlaxBigBirdModel,
+ FlaxBigBirdPreTrainedModel,
+ )
+ from .models.blenderbot import (
+ FlaxBlenderbotForConditionalGeneration,
+ FlaxBlenderbotModel,
+ FlaxBlenderbotPreTrainedModel,
+ )
+ from .models.blenderbot_small import (
+ FlaxBlenderbotSmallForConditionalGeneration,
+ FlaxBlenderbotSmallModel,
+ FlaxBlenderbotSmallPreTrainedModel,
+ )
+ from .models.bloom import (
+ FlaxBloomForCausalLM,
+ FlaxBloomModel,
+ FlaxBloomPreTrainedModel,
+ )
+ from .models.clip import (
+ FlaxCLIPModel,
+ FlaxCLIPPreTrainedModel,
+ FlaxCLIPTextModel,
+ FlaxCLIPTextModelWithProjection,
+ FlaxCLIPTextPreTrainedModel,
+ FlaxCLIPVisionModel,
+ FlaxCLIPVisionPreTrainedModel,
+ )
+ from .models.dinov2 import (
+ FlaxDinov2ForImageClassification,
+ FlaxDinov2Model,
+ FlaxDinov2PreTrainedModel,
+ )
+ from .models.distilbert import (
+ FlaxDistilBertForMaskedLM,
+ FlaxDistilBertForMultipleChoice,
+ FlaxDistilBertForQuestionAnswering,
+ FlaxDistilBertForSequenceClassification,
+ FlaxDistilBertForTokenClassification,
+ FlaxDistilBertModel,
+ FlaxDistilBertPreTrainedModel,
+ )
+ from .models.electra import (
+ FlaxElectraForCausalLM,
+ FlaxElectraForMaskedLM,
+ FlaxElectraForMultipleChoice,
+ FlaxElectraForPreTraining,
+ FlaxElectraForQuestionAnswering,
+ FlaxElectraForSequenceClassification,
+ FlaxElectraForTokenClassification,
+ FlaxElectraModel,
+ FlaxElectraPreTrainedModel,
+ )
+ from .models.encoder_decoder import FlaxEncoderDecoderModel
+ from .models.gemma import (
+ FlaxGemmaForCausalLM,
+ FlaxGemmaModel,
+ FlaxGemmaPreTrainedModel,
+ )
+ from .models.gpt2 import (
+ FlaxGPT2LMHeadModel,
+ FlaxGPT2Model,
+ FlaxGPT2PreTrainedModel,
+ )
+ from .models.gpt_neo import (
+ FlaxGPTNeoForCausalLM,
+ FlaxGPTNeoModel,
+ FlaxGPTNeoPreTrainedModel,
+ )
+ from .models.gptj import (
+ FlaxGPTJForCausalLM,
+ FlaxGPTJModel,
+ FlaxGPTJPreTrainedModel,
+ )
+ from .models.llama import (
+ FlaxLlamaForCausalLM,
+ FlaxLlamaModel,
+ FlaxLlamaPreTrainedModel,
+ )
+ from .models.longt5 import (
+ FlaxLongT5ForConditionalGeneration,
+ FlaxLongT5Model,
+ FlaxLongT5PreTrainedModel,
+ )
+ from .models.marian import (
+ FlaxMarianModel,
+ FlaxMarianMTModel,
+ FlaxMarianPreTrainedModel,
+ )
+ from .models.mbart import (
+ FlaxMBartForConditionalGeneration,
+ FlaxMBartForQuestionAnswering,
+ FlaxMBartForSequenceClassification,
+ FlaxMBartModel,
+ FlaxMBartPreTrainedModel,
+ )
+ from .models.mistral import (
+ FlaxMistralForCausalLM,
+ FlaxMistralModel,
+ FlaxMistralPreTrainedModel,
+ )
+ from .models.mt5 import (
+ FlaxMT5EncoderModel,
+ FlaxMT5ForConditionalGeneration,
+ FlaxMT5Model,
+ )
+ from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
+ from .models.pegasus import (
+ FlaxPegasusForConditionalGeneration,
+ FlaxPegasusModel,
+ FlaxPegasusPreTrainedModel,
+ )
+ from .models.regnet import (
+ FlaxRegNetForImageClassification,
+ FlaxRegNetModel,
+ FlaxRegNetPreTrainedModel,
+ )
+ from .models.resnet import (
+ FlaxResNetForImageClassification,
+ FlaxResNetModel,
+ FlaxResNetPreTrainedModel,
+ )
+ from .models.roberta import (
+ FlaxRobertaForCausalLM,
+ FlaxRobertaForMaskedLM,
+ FlaxRobertaForMultipleChoice,
+ FlaxRobertaForQuestionAnswering,
+ FlaxRobertaForSequenceClassification,
+ FlaxRobertaForTokenClassification,
+ FlaxRobertaModel,
+ FlaxRobertaPreTrainedModel,
+ )
+ from .models.roberta_prelayernorm import (
+ FlaxRobertaPreLayerNormForCausalLM,
+ FlaxRobertaPreLayerNormForMaskedLM,
+ FlaxRobertaPreLayerNormForMultipleChoice,
+ FlaxRobertaPreLayerNormForQuestionAnswering,
+ FlaxRobertaPreLayerNormForSequenceClassification,
+ FlaxRobertaPreLayerNormForTokenClassification,
+ FlaxRobertaPreLayerNormModel,
+ FlaxRobertaPreLayerNormPreTrainedModel,
+ )
+ from .models.roformer import (
+ FlaxRoFormerForMaskedLM,
+ FlaxRoFormerForMultipleChoice,
+ FlaxRoFormerForQuestionAnswering,
+ FlaxRoFormerForSequenceClassification,
+ FlaxRoFormerForTokenClassification,
+ FlaxRoFormerModel,
+ FlaxRoFormerPreTrainedModel,
+ )
+ from .models.speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
+ from .models.t5 import (
+ FlaxT5EncoderModel,
+ FlaxT5ForConditionalGeneration,
+ FlaxT5Model,
+ FlaxT5PreTrainedModel,
+ )
+ from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
+ from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
+ from .models.vit import (
+ FlaxViTForImageClassification,
+ FlaxViTModel,
+ FlaxViTPreTrainedModel,
+ )
+ from .models.wav2vec2 import (
+ FlaxWav2Vec2ForCTC,
+ FlaxWav2Vec2ForPreTraining,
+ FlaxWav2Vec2Model,
+ FlaxWav2Vec2PreTrainedModel,
+ )
+ from .models.whisper import (
+ FlaxWhisperForAudioClassification,
+ FlaxWhisperForConditionalGeneration,
+ FlaxWhisperModel,
+ FlaxWhisperPreTrainedModel,
+ )
+ from .models.xglm import (
+ FlaxXGLMForCausalLM,
+ FlaxXGLMModel,
+ FlaxXGLMPreTrainedModel,
+ )
+ from .models.xlm_roberta import (
+ FlaxXLMRobertaForCausalLM,
+ FlaxXLMRobertaForMaskedLM,
+ FlaxXLMRobertaForMultipleChoice,
+ FlaxXLMRobertaForQuestionAnswering,
+ FlaxXLMRobertaForSequenceClassification,
+ FlaxXLMRobertaForTokenClassification,
+ FlaxXLMRobertaModel,
+ FlaxXLMRobertaPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ extra_objects={"__version__": __version__},
+ )
+
+
+if not is_tf_available() and not is_torch_available() and not is_flax_available():
+ logger.warning_advice(
+ "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. "
+ "Models won't be available and only tokenizers, configuration "
+ "and file/data utilities can be used."
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/activations.py b/.venv/lib/python3.11/site-packages/transformers/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..2355fb5fed678d0de6e2c53f52644a35a691a34e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/activations.py
@@ -0,0 +1,239 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import OrderedDict
+
+import torch
+from packaging import version
+from torch import Tensor, nn
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class PytorchGELUTanh(nn.Module):
+ """
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
+ https://arxiv.org/abs/1606.08415.
+
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
+ match due to rounding errors.
+ """
+
+ def __init__(self):
+ super().__init__()
+ if version.parse(torch.__version__) < version.parse("1.12.0"):
+ raise ImportError(
+ f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
+ "PytorchGELUTanh. Please upgrade torch."
+ )
+
+ def forward(self, input: Tensor) -> Tensor:
+ return nn.functional.gelu(input, approximate="tanh")
+
+
+class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class GELUActivation(nn.Module):
+ """
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, use_gelu_python: bool = False):
+ super().__init__()
+ if use_gelu_python:
+ self.act = self._gelu_python
+ else:
+ self.act = nn.functional.gelu
+
+ def _gelu_python(self, input: Tensor) -> Tensor:
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class FastGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+
+
+class QuickGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input * torch.sigmoid(1.702 * input)
+
+
+class ClippedGELUActivation(nn.Module):
+ """
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
+ https://arxiv.org/abs/2004.09602.
+
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created.
+
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, min: float, max: float):
+ if min > max:
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+
+ super().__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.clip(gelu(x), self.min, self.max)
+
+
+class AccurateGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
+ https://github.com/hendrycks/GELUs
+
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.precomputed_constant = math.sqrt(2 / math.pi)
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
+
+
+class MishActivation(nn.Module):
+ """
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
+ """
+
+ def __init__(self):
+ super().__init__()
+ if version.parse(torch.__version__) < version.parse("1.9.0"):
+ self.act = self._mish_python
+ else:
+ self.act = nn.functional.mish
+
+ def _mish_python(self, input: Tensor) -> Tensor:
+ return input * torch.tanh(nn.functional.softplus(input))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class LinearActivation(nn.Module):
+ """
+ Applies the linear activation function, i.e. forwarding input directly to output.
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input
+
+
+class LaplaceActivation(nn.Module):
+ """
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
+ https://arxiv.org/abs/2209.10655
+
+ Inspired by squared relu, but with bounded range and gradient for better stability
+ """
+
+ def forward(self, input, mu=0.707107, sigma=0.282095):
+ input = (input - mu).div(sigma * math.sqrt(2.0))
+ return 0.5 * (1.0 + torch.erf(input))
+
+
+class ReLUSquaredActivation(nn.Module):
+ """
+ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
+ """
+
+ def forward(self, input):
+ relu_applied = nn.functional.relu(input)
+ squared = torch.square(relu_applied)
+ return squared
+
+
+class ClassInstantier(OrderedDict):
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+ACT2CLS = {
+ "gelu": GELUActivation,
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+ "gelu_fast": FastGELUActivation,
+ "gelu_new": NewGELUActivation,
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+ "gelu_pytorch_tanh": PytorchGELUTanh,
+ "gelu_accurate": AccurateGELUActivation,
+ "laplace": LaplaceActivation,
+ "leaky_relu": nn.LeakyReLU,
+ "linear": LinearActivation,
+ "mish": MishActivation,
+ "quick_gelu": QuickGELUActivation,
+ "relu": nn.ReLU,
+ "relu2": ReLUSquaredActivation,
+ "relu6": nn.ReLU6,
+ "sigmoid": nn.Sigmoid,
+ "silu": nn.SiLU,
+ "swish": nn.SiLU,
+ "tanh": nn.Tanh,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
+
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+
+
+# For backwards compatibility with: from activations import gelu_python
+gelu_python = get_activation("gelu_python")
+gelu_new = get_activation("gelu_new")
+gelu = get_activation("gelu")
+gelu_fast = get_activation("gelu_fast")
+quick_gelu = get_activation("quick_gelu")
+silu = get_activation("silu")
+mish = get_activation("mish")
+linear_act = get_activation("linear")
diff --git a/.venv/lib/python3.11/site-packages/transformers/activations_tf.py b/.venv/lib/python3.11/site-packages/transformers/activations_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12b73ea45176f3a4bc42cdabe8b73078a3b90f2
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/activations_tf.py
@@ -0,0 +1,147 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import tensorflow as tf
+from packaging.version import parse
+
+
+try:
+ import tf_keras as keras
+except (ModuleNotFoundError, ImportError):
+ import keras
+
+ if parse(keras.__version__).major > 2:
+ raise ValueError(
+ "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
+ "Transformers. Please install the backwards-compatible tf-keras package with "
+ "`pip install tf-keras`."
+ )
+
+
+def _gelu(x):
+ """
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
+ https://arxiv.org/abs/1606.08415
+ """
+ x = tf.convert_to_tensor(x)
+ cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
+
+ return x * cdf
+
+
+def _gelu_new(x):
+ """
+ Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841
+
+ Args:
+ x: float Tensor to perform activation
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ x = tf.convert_to_tensor(x)
+ pi = tf.cast(math.pi, x.dtype)
+ coeff = tf.cast(0.044715, x.dtype)
+ cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+
+ return x * cdf
+
+
+def mish(x):
+ x = tf.convert_to_tensor(x)
+
+ return x * tf.tanh(tf.math.softplus(x))
+
+
+def gelu_fast(x):
+ x = tf.convert_to_tensor(x)
+ coeff1 = tf.cast(0.044715, x.dtype)
+ coeff2 = tf.cast(0.7978845608, x.dtype)
+
+ return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
+
+
+def quick_gelu(x):
+ x = tf.convert_to_tensor(x)
+ coeff = tf.cast(1.702, x.dtype)
+ return x * tf.math.sigmoid(coeff * x)
+
+
+def gelu_10(x):
+ """
+ Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
+ it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to
+ https://arxiv.org/abs/2004.09602
+
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
+ https://arxiv.org/abs/1606.08415 :param x: :return:
+ """
+ return tf.clip_by_value(_gelu(x), -10, 10)
+
+
+def glu(x, axis=-1):
+ """
+ Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
+ the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
+
+ Args:
+ `x`: float Tensor to perform activation
+ `axis`: dimension across which `x` be split in half
+
+ Returns:
+ `x` with the GLU activation applied (with its size halved across the dimension `axis`).
+ """
+ a, b = tf.split(x, 2, axis=axis)
+ return a * tf.math.sigmoid(b)
+
+
+if parse(tf.version.VERSION) >= parse("2.4"):
+
+ def approximate_gelu_wrap(x):
+ return keras.activations.gelu(x, approximate=True)
+
+ gelu = keras.activations.gelu
+ gelu_new = approximate_gelu_wrap
+else:
+ gelu = _gelu
+ gelu_new = _gelu_new
+
+
+ACT2FN = {
+ "gelu": gelu,
+ "gelu_10": gelu_10,
+ "gelu_fast": gelu_fast,
+ "gelu_new": gelu_new,
+ "glu": glu,
+ "mish": mish,
+ "quick_gelu": quick_gelu,
+ "relu": keras.activations.relu,
+ "sigmoid": keras.activations.sigmoid,
+ "silu": keras.activations.swish,
+ "swish": keras.activations.swish,
+ "tanh": keras.activations.tanh,
+}
+
+
+def get_tf_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
diff --git a/.venv/lib/python3.11/site-packages/transformers/audio_utils.py b/.venv/lib/python3.11/site-packages/transformers/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f11287f309cf1437b84928b6721052b7ba4531
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/audio_utils.py
@@ -0,0 +1,1123 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
+and remove unnecessary dependencies.
+"""
+
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+
+def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
+ """
+ Convert frequency from hertz to mels.
+
+ Args:
+ freq (`float` or `np.ndarray`):
+ The frequency, or multiple frequencies, in hertz (Hz).
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+
+ Returns:
+ `float` or `np.ndarray`: The frequencies on the mel scale.
+ """
+
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
+
+ if mel_scale == "htk":
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
+ elif mel_scale == "kaldi":
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
+
+ min_log_hertz = 1000.0
+ min_log_mel = 15.0
+ logstep = 27.0 / np.log(6.4)
+ mels = 3.0 * freq / 200.0
+
+ if isinstance(freq, np.ndarray):
+ log_region = freq >= min_log_hertz
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
+ elif freq >= min_log_hertz:
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
+
+ return mels
+
+
+def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
+ """
+ Convert frequency from mels to hertz.
+
+ Args:
+ mels (`float` or `np.ndarray`):
+ The frequency, or multiple frequencies, in mels.
+ mel_scale (`str`, *optional*, `"htk"`):
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+
+ Returns:
+ `float` or `np.ndarray`: The frequencies in hertz.
+ """
+
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
+
+ if mel_scale == "htk":
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
+ elif mel_scale == "kaldi":
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
+
+ min_log_hertz = 1000.0
+ min_log_mel = 15.0
+ logstep = np.log(6.4) / 27.0
+ freq = 200.0 * mels / 3.0
+
+ if isinstance(mels, np.ndarray):
+ log_region = mels >= min_log_mel
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
+ elif mels >= min_log_mel:
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
+
+ return freq
+
+
+def hertz_to_octave(
+ freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
+):
+ """
+ Convert frequency from hertz to fractional octave numbers.
+ Adapted from *librosa*.
+
+ Args:
+ freq (`float` or `np.ndarray`):
+ The frequency, or multiple frequencies, in hertz (Hz).
+ tuning (`float`, defaults to `0.`):
+ Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
+ bins_per_octave (`int`, defaults to `12`):
+ Number of bins per octave.
+
+ Returns:
+ `float` or `np.ndarray`: The frequencies on the octave scale.
+ """
+ stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
+ octave = np.log2(freq / (float(stuttgart_pitch) / 16))
+ return octave
+
+
+def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
+ """
+ Creates a triangular filter bank.
+
+ Adapted from *torchaudio* and *librosa*.
+
+ Args:
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
+ Discrete frequencies of the FFT bins in Hz.
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
+ Center frequencies of the triangular filters to create, in Hz.
+
+ Returns:
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
+ """
+ filter_diff = np.diff(filter_freqs)
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
+
+
+def chroma_filter_bank(
+ num_frequency_bins: int,
+ num_chroma: int,
+ sampling_rate: int,
+ tuning: float = 0.0,
+ power: Optional[float] = 2.0,
+ weighting_parameters: Optional[Tuple[float]] = (5.0, 2),
+ start_at_c_chroma: Optional[bool] = True,
+):
+ """
+ Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.
+
+ Adapted from *librosa*.
+
+ Args:
+ num_frequency_bins (`int`):
+ Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
+ num_chroma (`int`):
+ Number of chroma bins (i.e pitch classes).
+ sampling_rate (`float`):
+ Sample rate of the audio waveform.
+ tuning (`float`):
+ Tuning deviation from A440 in fractions of a chroma bin.
+ power (`float`, *optional*, defaults to 2.0):
+ If 12.0, normalizes each column with their L2 norm. If 1.0, normalizes each column with their L1 norm.
+ weighting_parameters (`Tuple[float]`, *optional*, defaults to `(5., 2.)`):
+ If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
+ the second element being the Gaussian half-width.
+ start_at_c_chroma (`float`, *optional*, defaults to `True`):
+ If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
+ Returns:
+ `np.ndarray` of shape `(num_frequency_bins, num_chroma)`
+ """
+ # Get the FFT bins, not counting the DC component
+ frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]
+
+ freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)
+
+ # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
+ # (so chroma is 50% rotated from bin 1, and bin width is broad)
+ freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))
+
+ bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))
+
+ chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T
+
+ num_chroma2 = np.round(float(num_chroma) / 2)
+
+ # Project into range -num_chroma/2 .. num_chroma/2
+ # add on fixed offset of 10*num_chroma to ensure all values passed to
+ # rem are positive
+ chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2
+
+ # Gaussian bumps - 2*D to make them narrower
+ chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)
+
+ # normalize each column
+ if power is not None:
+ chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)
+
+ # Maybe apply scaling for fft bins
+ if weighting_parameters is not None:
+ center, half_width = weighting_parameters
+ chroma_filters *= np.tile(
+ np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
+ (num_chroma, 1),
+ )
+
+ if start_at_c_chroma:
+ chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)
+
+ # remove aliasing columns, copy to ensure row-contiguity
+ return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
+
+
+def mel_filter_bank(
+ num_frequency_bins: int,
+ num_mel_filters: int,
+ min_frequency: float,
+ max_frequency: float,
+ sampling_rate: int,
+ norm: Optional[str] = None,
+ mel_scale: str = "htk",
+ triangularize_in_mel_space: bool = False,
+) -> np.ndarray:
+ """
+ Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
+ various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
+ are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
+ features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
+
+ Different banks of mel filters were introduced in the literature. The following variations are supported:
+
+ - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
+ bandwidth of `[0, 4600]` Hz.
+ - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
+ bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
+ - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
+ speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
+ - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
+ 12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
+
+ This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
+ `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
+
+ Args:
+ num_frequency_bins (`int`):
+ Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
+ num_mel_filters (`int`):
+ Number of mel filters to generate.
+ min_frequency (`float`):
+ Lowest frequency of interest in Hz.
+ max_frequency (`float`):
+ Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
+ sampling_rate (`int`):
+ Sample rate of the audio waveform.
+ norm (`str`, *optional*):
+ If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+ triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
+ If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
+ should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
+
+ Returns:
+ `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
+ projection matrix to go from a spectrogram to a mel spectrogram.
+ """
+ if norm is not None and norm != "slaney":
+ raise ValueError('norm must be one of None or "slaney"')
+
+ # center points of the triangular mel filters
+ mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
+ mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
+ mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
+
+ if triangularize_in_mel_space:
+ # frequencies of FFT bins in Hz, but filters triangularized in mel space
+ fft_bin_width = sampling_rate / (num_frequency_bins * 2)
+ fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
+ filter_freqs = mel_freqs
+ else:
+ # frequencies of FFT bins in Hz
+ fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
+
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
+
+ if norm is not None and norm == "slaney":
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
+ mel_filters *= np.expand_dims(enorm, 0)
+
+ if (mel_filters.max(axis=0) == 0.0).any():
+ warnings.warn(
+ "At least one mel filter has all zero values. "
+ f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
+ f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
+ )
+
+ return mel_filters
+
+
+def optimal_fft_length(window_length: int) -> int:
+ """
+ Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
+ already a power of two, rounds it up to the next power or two.
+
+ The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
+ of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
+ is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
+ it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
+ """
+ return 2 ** int(np.ceil(np.log2(window_length)))
+
+
+def window_function(
+ window_length: int,
+ name: str = "hann",
+ periodic: bool = True,
+ frame_length: Optional[int] = None,
+ center: bool = True,
+) -> np.ndarray:
+ """
+ Returns an array containing the specified window. This window is intended to be used with `stft`.
+
+ The following window types are supported:
+
+ - `"boxcar"`: a rectangular window
+ - `"hamming"`: the Hamming window
+ - `"hann"`: the Hann window
+ - `"povey"`: the Povey window
+
+ Args:
+ window_length (`int`):
+ The length of the window in samples.
+ name (`str`, *optional*, defaults to `"hann"`):
+ The name of the window function.
+ periodic (`bool`, *optional*, defaults to `True`):
+ Whether the window is periodic or symmetric.
+ frame_length (`int`, *optional*):
+ The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
+ than the frame length, so that it will be zero-padded.
+ center (`bool`, *optional*, defaults to `True`):
+ Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
+
+ Returns:
+ `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
+ """
+ length = window_length + 1 if periodic else window_length
+
+ if name == "boxcar":
+ window = np.ones(length)
+ elif name in ["hamming", "hamming_window"]:
+ window = np.hamming(length)
+ elif name in ["hann", "hann_window"]:
+ window = np.hanning(length)
+ elif name in ["povey"]:
+ window = np.power(np.hanning(length), 0.85)
+ else:
+ raise ValueError(f"Unknown window function '{name}'")
+
+ if periodic:
+ window = window[:-1]
+
+ if frame_length is None:
+ return window
+
+ if window_length > frame_length:
+ raise ValueError(
+ f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
+ )
+
+ padded_window = np.zeros(frame_length)
+ offset = (frame_length - window_length) // 2 if center else 0
+ padded_window[offset : offset + window_length] = window
+ return padded_window
+
+
+# TODO This method does not support batching yet as we are mainly focused on inference.
+def spectrogram(
+ waveform: np.ndarray,
+ window: np.ndarray,
+ frame_length: int,
+ hop_length: int,
+ fft_length: Optional[int] = None,
+ power: Optional[float] = 1.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ onesided: bool = True,
+ preemphasis: Optional[float] = None,
+ mel_filters: Optional[np.ndarray] = None,
+ mel_floor: float = 1e-10,
+ log_mel: Optional[str] = None,
+ reference: float = 1.0,
+ min_value: float = 1e-10,
+ db_range: Optional[float] = None,
+ remove_dc_offset: Optional[bool] = None,
+ dtype: np.dtype = np.float32,
+) -> np.ndarray:
+ """
+ Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
+
+ This function can create the following kinds of spectrograms:
+
+ - amplitude spectrogram (`power = 1.0`)
+ - power spectrogram (`power = 2.0`)
+ - complex-valued spectrogram (`power = None`)
+ - log spectrogram (use `log_mel` argument)
+ - mel spectrogram (provide `mel_filters`)
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
+
+ How this works:
+
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
+ - hop_length` samples.
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
+ 3. The DFT is taken of each windowed frame.
+ 4. The results are stacked into a spectrogram.
+
+ We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
+
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
+
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
+ typically the next power of two.
+
+ Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
+ `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
+ can be constructed.
+
+ Args:
+ waveform (`np.ndarray` of shape `(length,)`):
+ The input waveform. This must be a single real-valued, mono waveform.
+ window (`np.ndarray` of shape `(frame_length,)`):
+ The windowing function to apply, including zero-padding if necessary. The actual window length may be
+ shorter than `frame_length`, but we're assuming the array has already been zero-padded.
+ frame_length (`int`):
+ The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
+ allow smaller sizes.
+ hop_length (`int`):
+ The stride between successive analysis frames in samples.
+ fft_length (`int`, *optional*):
+ The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
+ For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
+ power (`float`, *optional*, defaults to 1.0):
+ If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
+ complex numbers.
+ center (`bool`, *optional*, defaults to `True`):
+ Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
+ `t` will start at time `t * hop_length`.
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
+ Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
+ (pad with edge values), `"reflect"` (pads with mirrored values).
+ onesided (`bool`, *optional*, defaults to `True`):
+ If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
+ frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
+        preemphasis (`float`, *optional*):
+            Coefficient for a high-pass (first-difference) filter that applies pre-emphasis before the DFT.
+ mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
+            The mel filter bank. If supplied, applies this filter bank to create a mel spectrogram.
+ mel_floor (`float`, *optional*, defaults to 1e-10):
+ Minimum value of mel frequency banks.
+ log_mel (`str`, *optional*):
+            How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
+            the natural logarithm), `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
+            used when `power` is not `None`.
+ reference (`float`, *optional*, defaults to 1.0):
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+ the loudest part to 0 dB. Must be greater than zero.
+ min_value (`float`, *optional*, defaults to `1e-10`):
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+ `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
+ amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
+ db_range (`float`, *optional*):
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+ remove_dc_offset (`bool`, *optional*):
+ Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
+ order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+ Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
+ `np.complex64`.
+
+ Returns:
+        `np.ndarray` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or
+        shape `(num_mel_filters, length)` for a mel spectrogram.
+ """
+ window_length = len(window)
+
+ if fft_length is None:
+ fft_length = frame_length
+
+ if frame_length > fft_length:
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
+
+ if window_length != frame_length:
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
+
+ if hop_length <= 0:
+ raise ValueError("hop_length must be greater than zero")
+
+ if waveform.ndim != 1:
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
+
+ if np.iscomplexobj(waveform):
+ raise ValueError("Complex-valued input waveforms are not currently supported")
+
+ if power is None and mel_filters is not None:
+ raise ValueError(
+ "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
+ "Specify `power` to fix this issue."
+ )
+
+ # center pad the waveform
+ if center:
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
+ waveform = np.pad(waveform, padding, mode=pad_mode)
+
+ # promote to float64, since np.fft uses float64 internally
+ waveform = waveform.astype(np.float64)
+ window = window.astype(np.float64)
+
+ # split waveform into frames of frame_length size
+ num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
+
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
+ spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
+
+ # rfft is faster than fft
+ fft_func = np.fft.rfft if onesided else np.fft.fft
+ buffer = np.zeros(fft_length)
+
+ timestep = 0
+ for frame_idx in range(num_frames):
+ buffer[:frame_length] = waveform[timestep : timestep + frame_length]
+
+ if remove_dc_offset:
+ buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
+
+ if preemphasis is not None:
+ buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
+ buffer[0] *= 1 - preemphasis
+
+ buffer[:frame_length] *= window
+
+ spectrogram[frame_idx] = fft_func(buffer)
+ timestep += hop_length
+
+ # note: ** is much faster than np.power
+ if power is not None:
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
+
+ spectrogram = spectrogram.T
+
+ if mel_filters is not None:
+ spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
+
+ if power is not None and log_mel is not None:
+ if log_mel == "log":
+ spectrogram = np.log(spectrogram)
+ elif log_mel == "log10":
+ spectrogram = np.log10(spectrogram)
+ elif log_mel == "dB":
+ if power == 1.0:
+ spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
+ elif power == 2.0:
+ spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
+ else:
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
+ else:
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
+
+ spectrogram = np.asarray(spectrogram, dtype)
+
+ return spectrogram
+
+
def spectrogram_batch(
    waveform_list: List[np.ndarray],
    window: np.ndarray,
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: Optional[float] = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    preemphasis: Optional[float] = None,
    mel_filters: Optional[np.ndarray] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = None,
    dtype: np.dtype = np.float32,
) -> List[np.ndarray]:
    """
    Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
    This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.

    It supports generating various types of spectrograms:

    - amplitude spectrogram (`power = 1.0`)
    - power spectrogram (`power = 2.0`)
    - complex-valued spectrogram (`power = None`)
    - log spectrogram (use `log_mel` argument)
    - mel spectrogram (provide `mel_filters`)
    - log-mel spectrogram (provide `mel_filters` and `log_mel`)

    How this works:

    1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
    - hop_length` samples.
    2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
    3. The DFT is taken of each windowed frame.
    4. The results are stacked into a spectrogram.

    We make a distinction between the following "blocks" of sample data, each of which may have different lengths:

    - The analysis frame. This is the size of the time slices that the input waveform is split into.
    - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
    - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.

    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
    typically the next power of two.

    Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.

    Args:
        waveform_list (`List[np.ndarray]` with arrays of shape `(length,)`):
            The list of input waveforms, each a single-channel (mono) signal.
        window (`np.ndarray` of shape `(frame_length,)`):
            The windowing function to apply, including zero-padding if necessary.
        frame_length (`int`):
            The length of each frame for analysis.
        hop_length (`int`):
            The step size between successive frames.
        fft_length (`int`, *optional*):
            The size of the FFT buffer, defining frequency bin resolution. If `None`, uses `frame_length`.
        power (`float`, *optional*, defaults to 1.0):
            Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
        center (`bool`, *optional*, defaults to `True`):
            Whether to center-pad the waveform frames.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            The padding strategy when `center` is `True`.
        onesided (`bool`, *optional*, defaults to `True`):
            If True, returns a one-sided spectrogram for real input signals.
        preemphasis (`float`, *optional*):
            Applies a pre-emphasis filter to each frame.
        mel_filters (`np.ndarray`, *optional*):
            Mel filter bank for converting to mel spectrogram.
        mel_floor (`float`, *optional*, defaults to 1e-10):
            Floor value for mel spectrogram to avoid log(0).
        log_mel (`str`, *optional*):
            Specifies log scaling strategy; options are None, "log", "log10", "dB".
        reference (`float`, *optional*, defaults to 1.0):
            Reference value for dB conversion in log_mel.
        min_value (`float`, *optional*, defaults to 1e-10):
            Minimum floor value for log scale conversions.
        db_range (`float`, *optional*):
            Dynamic range for dB scale spectrograms.
        remove_dc_offset (`bool`, *optional*):
            Whether to remove the DC offset from each frame.
        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
            Data type of the output spectrogram.

    Returns:
        List[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
    """
    window_length = len(window)

    # The FFT buffer defaults to the analysis-frame size (librosa behavior).
    if fft_length is None:
        fft_length = frame_length

    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")

    if window_length != frame_length:
        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")

    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")

    # Validate every waveform up front: each must be 1-D (mono) and real-valued.
    for waveform in waveform_list:
        if waveform.ndim != 1:
            raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
        if np.iscomplexobj(waveform):
            raise ValueError("Complex-valued input waveforms are not currently supported")
    # Center pad the waveform so that frame `t` is centered around sample `t * hop_length`.
    if center:
        padding = [(int(frame_length // 2), int(frame_length // 2))]
        waveform_list = [
            np.pad(
                waveform,
                padding,
                mode=pad_mode,
            )
            for waveform in waveform_list
        ]
    original_waveform_lengths = [
        len(waveform) for waveform in waveform_list
    ]  # these lengths will be used to remove padding later

    # Batch pad the waveforms (zero-pad shorter ones up to the longest) so they stack into one 2-D array.
    # NOTE(review): the waveforms are cast to `dtype` (default float32) here, *before* the float64 promotion
    # below — this may round float64 inputs; confirm this matches the intended precision contract.
    max_length = max(original_waveform_lengths)
    padded_waveform_batch = np.array(
        [
            np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
            for waveform in waveform_list
        ],
        dtype=dtype,
    )

    # Promote to float64, since np.fft uses float64 internally
    padded_waveform_batch = padded_waveform_batch.astype(np.float64)
    window = window.astype(np.float64)

    # Number of frames computed over the common (batch-padded) length...
    num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
    # ...and the per-waveform frame counts (used to trim the batch padding from each result at the end).
    true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
    num_batches = padded_waveform_batch.shape[0]

    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
    spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)

    # rfft is faster than fft
    fft_func = np.fft.rfft if onesided else np.fft.fft
    # Reusable FFT input buffer; only [:, :frame_length] is ever written, so the tail stays zero
    # (this is the zero-padding when fft_length > frame_length).
    buffer = np.zeros((num_batches, fft_length))

    for frame_idx in range(num_frames):
        timestep = frame_idx * hop_length
        buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

        # Subtract each frame's mean (per waveform) before pre-emphasis, as Kaldi-style fbanks do.
        if remove_dc_offset:
            buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

        # First-difference pre-emphasis filter: y[n] = x[n] - c * x[n-1].
        if preemphasis is not None:
            buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
            buffer[:, 0] *= 1 - preemphasis

        buffer[:, :frame_length] *= window

        spectrogram[:, frame_idx] = fft_func(buffer)

    # Note: ** is much faster than np.power
    if power is not None:
        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power

    # Apply mel filters if provided (contracting the frequency axis), then floor to avoid log(0) later.
    if mel_filters is not None:
        result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
        spectrogram = np.maximum(mel_floor, result)

    # Convert to log scale if specified
    if power is not None and log_mel is not None:
        if log_mel == "log":
            spectrogram = np.log(spectrogram)
        elif log_mel == "log10":
            spectrogram = np.log10(spectrogram)
        elif log_mel == "dB":
            if power == 1.0:
                spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
            elif power == 2.0:
                spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
            else:
                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
        else:
            raise ValueError(f"Unknown log_mel option: {log_mel}")

    spectrogram = np.asarray(spectrogram, dtype)

    # Trim each item back to its own frame count and transpose to `(num_frequency_bins, num_frames)`.
    spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]

    return spectrogram_list
+
+
def power_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts a power (mel) spectrogram to the decibel scale, i.e. `10 * log10(spectrogram / reference)`, computed
    via logarithm identities for numerical stability.

    Why the log? Human hearing is not linear in energy — roughly 8x the energy is needed to double perceived
    loudness — so log-compressed (mel) spectrogram features track perception far better than raw magnitudes.

    Based on the implementation of `librosa.power_to_db`.

    Args:
        spectrogram (`np.ndarray`):
            The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value that maps to 0 dB (e.g. pass `np.max(spectrogram)` to pin the loudest bin to
            0 dB). Must be strictly positive.
        min_value (`float`, *optional*, defaults to `1e-10`):
            Values below this floor are clipped before taking the logarithm, avoiding `log(0)`. The default of
            `1e-10` corresponds to -100 dB. Must be strictly positive.
        db_range (`float`, *optional*):
            If given, limits the dynamic range: no output value lies more than `db_range` dB below the peak.
            Must be strictly positive.

    Returns:
        `np.ndarray`: the spectrogram in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # A reference below the clipping floor would be meaningless; keep it at least `min_value`.
    ref_level = max(min_value, reference)

    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 10.0 * (np.log10(clipped) - np.log10(ref_level))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        db_spectrogram = np.clip(db_spectrogram, a_min=db_spectrogram.max() - db_range, a_max=None)

    return db_spectrogram
+
+
def power_to_db_batch(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts a batch of power (mel) spectrograms to the decibel scale, i.e. `10 * log10(spectrogram / reference)`,
    computed via logarithm identities for numerical stability.

    Batch-aware variant of `power_to_db`: when `db_range` is given, the dynamic-range clipping is applied relative
    to each batch item's own peak rather than the global maximum.

    Args:
        spectrogram (`np.ndarray`):
            The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
            Note that a power spectrogram has the amplitudes squared!
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value that maps to 0 dB. Must be strictly positive.
        min_value (`float`, *optional*, defaults to `1e-10`):
            Values below this floor are clipped before taking the logarithm, avoiding `log(0)`. The default of
            `1e-10` corresponds to -100 dB. Must be strictly positive.
        db_range (`float`, *optional*):
            If given, limits each item's dynamic range to at most `db_range` dB below its peak. Must be strictly
            positive.

    Returns:
        `np.ndarray`: the batch of spectrograms in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # A reference below the clipping floor would be meaningless; keep it at least `min_value`.
    ref_level = max(min_value, reference)

    db_spectrogram = 10.0 * (np.log10(np.clip(spectrogram, a_min=min_value, a_max=None)) - np.log10(ref_level))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Clip relative to each batch item's own peak (expects 3-D input: batch, frames, bins).
        per_item_peak = db_spectrogram.max(axis=(1, 2), keepdims=True)
        db_spectrogram = np.clip(db_spectrogram, a_min=per_item_peak - db_range, a_max=None)

    return db_spectrogram
+
+
def amplitude_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-5,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts an amplitude (mel) spectrogram to the decibel scale, i.e. `20 * log10(spectrogram / reference)`,
    computed via logarithm identities for numerical stability.

    Why the log? Human hearing is not linear in energy — roughly 8x the energy is needed to double perceived
    loudness — so log-compressed (mel) spectrogram features track perception far better than raw magnitudes.

    Args:
        spectrogram (`np.ndarray`):
            The input amplitude (mel) spectrogram.
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value that maps to 0 dB (e.g. pass `np.max(spectrogram)` to pin the loudest bin to
            0 dB). Must be strictly positive.
        min_value (`float`, *optional*, defaults to `1e-5`):
            Values below this floor are clipped before taking the logarithm, avoiding `log(0)`. The default of
            `1e-5` corresponds to -100 dB. Must be strictly positive.
        db_range (`float`, *optional*):
            If given, limits the dynamic range: no output value lies more than `db_range` dB below the peak.
            Must be strictly positive.

    Returns:
        `np.ndarray`: the spectrogram in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # A reference below the clipping floor would be meaningless; keep it at least `min_value`.
    ref_level = max(min_value, reference)

    clipped = np.clip(spectrogram, a_min=min_value, a_max=None)
    db_spectrogram = 20.0 * (np.log10(clipped) - np.log10(ref_level))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        db_spectrogram = np.clip(db_spectrogram, a_min=db_spectrogram.max() - db_range, a_max=None)

    return db_spectrogram
+
+
def amplitude_to_db_batch(
    spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None
) -> np.ndarray:
    """
    Converts a batch of amplitude (mel) spectrograms to the decibel scale, i.e.
    `20 * log10(spectrogram / reference)`, computed via logarithm identities for numerical stability.

    Batch-aware variant of `amplitude_to_db`: when `db_range` is given, the dynamic-range clipping is applied
    relative to each batch item's own peak rather than the global maximum.

    Args:
        spectrogram (`np.ndarray`):
            The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
        reference (`float`, *optional*, defaults to 1.0):
            The spectrogram value that maps to 0 dB. Must be strictly positive.
        min_value (`float`, *optional*, defaults to `1e-5`):
            Values below this floor are clipped before taking the logarithm, avoiding `log(0)`. The default of
            `1e-5` corresponds to -100 dB. Must be strictly positive.
        db_range (`float`, *optional*):
            If given, limits each item's dynamic range to at most `db_range` dB below its peak. Must be strictly
            positive.

    Returns:
        `np.ndarray`: the batch of spectrograms in decibels
    """
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # A reference below the clipping floor would be meaningless; keep it at least `min_value`.
    ref_level = max(min_value, reference)

    db_spectrogram = 20.0 * (np.log10(np.clip(spectrogram, a_min=min_value, a_max=None)) - np.log10(ref_level))

    if db_range is not None:
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # Clip relative to each batch item's own peak (expects 3-D input: batch, frames, bins).
        per_item_peak = db_spectrogram.max(axis=(1, 2), keepdims=True)
        db_spectrogram = np.clip(db_spectrogram, a_min=per_item_peak - db_range, a_max=None)

    return db_spectrogram
+
+
+### deprecated functions below this line ###
+
+
def get_mel_filter_banks(
    nb_frequency_bins: int,
    nb_mel_filters: int,
    frequency_min: float,
    frequency_max: float,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> np.array:
    """
    Deprecated wrapper around `mel_filter_bank`, kept for backward compatibility.

    Emits a `FutureWarning` and forwards every legacy argument (under its new name) to `mel_filter_bank`.
    """
    warnings.warn(
        "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    # Map the legacy argument names onto the `mel_filter_bank` signature.
    forwarded_args = {
        "num_frequency_bins": nb_frequency_bins,
        "num_mel_filters": nb_mel_filters,
        "min_frequency": frequency_min,
        "max_frequency": frequency_max,
        "sampling_rate": sample_rate,
        "norm": norm,
        "mel_scale": mel_scale,
    }
    return mel_filter_bank(**forwarded_args)
+
+
def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
    """
    In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
    segments called `frames`.

    The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
    defines the step between the beginning of each new frame.


    Args:
        waveform (`np.array` of shape `(sample_length,)`):
            The raw waveform which will be split into smaller chunks.
        hop_length (`int`, *optional*, defaults to 160):
            Step between each window of the waveform.
        fft_window_size (`int`, *optional*, defaults to 400):
            Defines the size of the window.
        center (`bool`, defaults to `True`):
            Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
            waveform on the left and on the right.

    Return:
        framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
            The framed waveforms that can be fed to `np.fft`.
    """
    # Deprecated: superseded by `spectrogram()`, which performs framing internally.
    warnings.warn(
        "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    frames = []
    # One frame per hop; the `+ 1` yields a final frame anchored at the very end of the waveform.
    for i in range(0, waveform.shape[0] + 1, hop_length):
        if center:
            # Centered framing: take `half_window` samples on each side of position `i`, reflecting the
            # waveform at the edges where the slice would run out of samples.
            half_window = (fft_window_size - 1) // 2 + 1
            start = i - half_window if i > half_window else 0
            end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
            frame = waveform[start:end]
            if start == 0:
                # Left edge: reflect-pad the missing samples before position 0.
                padd_width = (-i + half_window, 0)
                frame = np.pad(frame, pad_width=padd_width, mode="reflect")

            elif end == waveform.shape[0]:
                # Right edge: reflect-pad the missing samples past the end of the waveform.
                padd_width = (0, (i - waveform.shape[0] + half_window))
                frame = np.pad(frame, pad_width=padd_width, mode="reflect")

        else:
            # Uncentered framing: frame starts at `i`; zero-pad the tail frame up to `fft_window_size`.
            frame = waveform[i : i + fft_window_size]
            frame_width = frame.shape[0]
            # NOTE(review): this compares against the full waveform length; it looks like it was meant to be
            # `frame_width < fft_window_size`. It behaves the same whenever the waveform is longer than a
            # frame, but confirm before relying on short-input behavior. Kept as-is (deprecated code path).
            if frame_width < waveform.shape[0]:
                frame = np.lib.pad(
                    frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
                )
        frames.append(frame)

    frames = np.stack(frames, 0)
    return frames
+
+
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
    """
    Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
    as `torch.stft`.

    Args:
        frames (`np.array` of dimension `(num_frames, fft_window_size)`):
            A framed audio signal obtained using `audio_utils.fram_wav`.
        windowing_function (`np.array` of dimension `(frame_size,)`):
            An array representing the function that will be used to reduce the amplitude of the discontinuities at the
            boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function.
            For more information on the discontinuities, called *Spectral leakage*, refer to [this
            tutorial](https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf)
        fft_window_size (`int`, *optional*):
            Size of the window on which the Fourier transform is applied. This controls the frequency resolution of the
            spectrogram. 400 means that the Fourier transform is computed on windows of 400 samples. The number of
            frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to
            `(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculation time proportionally.

    Example:

    ```python
    >>> from transformers.audio_utils import stft, fram_wave
    >>> import numpy as np

    >>> audio = np.random.rand(50)
    >>> fft_window_size = 10
    >>> hop_length = 2
    >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
    >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
    ```

    Returns:
        spectrogram (`np.ndarray`):
            A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
    """
    # Deprecated: superseded by `spectrogram()`.
    warnings.warn(
        "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    frame_size = frames.shape[1]

    # The FFT buffer defaults to the frame size; a larger buffer zero-pads each frame.
    if fft_window_size is None:
        fft_window_size = frame_size

    if fft_window_size < frame_size:
        raise ValueError("FFT size must greater or equal the frame size")
    # number of FFT bins to store (one-sided spectrum of a real signal)
    nb_frequency_bins = (fft_window_size >> 1) + 1

    spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
    # Reusable FFT input buffer; only the first `frame_size` entries are ever written, so the tail
    # stays zero and acts as the zero-padding when fft_window_size > frame_size.
    fft_signal = np.zeros(fft_window_size)

    for f, frame in enumerate(frames):
        if windowing_function is not None:
            # NOTE(review): this requires `len(windowing_function) == frame_size` for the in-place
            # multiply to broadcast — confirm callers guarantee it.
            np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
        else:
            fft_signal[:frame_size] = frame
        # Full complex FFT, truncated to the one-sided bins.
        spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
    return spectrogram.T
diff --git a/.venv/lib/python3.11/site-packages/transformers/cache_utils.py b/.venv/lib/python3.11/site-packages/transformers/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f38fc8f9824d3b88eb00296451994d16347eab2a
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/cache_utils.py
@@ -0,0 +1,2148 @@
+import copy
+import importlib.metadata
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging import version
+
+from .configuration_utils import PretrainedConfig
+from .utils import (
+ is_hqq_available,
+ is_optimum_quanto_available,
+ is_torchdynamo_compiling,
+ logging,
+)
+from .utils.deprecation import deprecate_kwarg
+
+
+if is_hqq_available():
+ from hqq.core.quantize import Quantizer as HQQQuantizer
+
+logger = logging.get_logger(__name__)
+
+
+class Cache(torch.nn.Module):
+    """
+    Base, abstract class for all caches. The actual data structure is specific to each subclass.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+                cache to be created.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+    # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length"
+    # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
+    # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
+    # we change naming to be more explicit
+    def get_max_length(self) -> Optional[int]:
+        """Deprecated alias of `get_max_cache_shape()`; warns once and will raise an error from v4.48."""
+        logger.warning_once(
+            "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
+            "Calling `get_max_cache()` will raise error from v4.48"
+        )
+        return self.get_max_cache_shape()
+
+    def get_max_cache_shape(self) -> Optional[int]:
+        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
+        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
+        # length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_cache_shape()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        # Skipped layers are stored as a plain `[]` placeholder (see `DynamicCache.update`), hence
+        # the `!= []` comparison instead of a truthiness test, which is ambiguous on tensors.
+        for layer_idx in range(len(self.key_cache)):
+            if self.key_cache[layer_idx] != []:
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.value_cache[layer_idx] != []:
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    @property
+    def seen_tokens(self):
+        # Deprecated accessor kept for backward compatibility; returns None for cache classes that
+        # never tracked `_seen_tokens`.
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
+
+@dataclass
+class CacheConfig:
+    """
+    Base class for cache configs
+    """
+
+    # Name of the cache implementation this config belongs to; overridden by subclasses
+    # (e.g. "static" for `StaticCacheConfig`).
+    cache_implementation: None
+
+    @classmethod
+    def from_dict(cls, config_dict, **kwargs):
+        """
+        Constructs a CacheConfig instance from a dictionary of parameters.
+        Args:
+            config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
+            **kwargs: Additional keyword arguments to override dictionary values.
+
+        Returns:
+            CacheConfig: Instance of CacheConfig constructed from the dictionary.
+        """
+        config = cls(**config_dict)
+        to_remove = []
+        # kwargs matching existing attributes override the values built from `config_dict`.
+        for key, value in kwargs.items():
+            if hasattr(config, key):
+                setattr(config, key, value)
+                to_remove.append(key)
+        # Drop consumed keys from the local kwargs dict (the caller's dict is unaffected).
+        for key in to_remove:
+            kwargs.pop(key, None)
+        return config
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+            use_diff (`bool`, *optional*, defaults to `True`):
+                If set to `True`, only the difference between the config instance and the default
+                `QuantizationConfig()` is serialized to JSON file.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            config_dict = self.to_dict()
+            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+            writer.write(json_string)
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary. Returns:
+        `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        return copy.deepcopy(self.__dict__)
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
+    def __iter__(self):
+        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+        for attr, value in copy.deepcopy(self.__dict__).items():
+            yield attr, value
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    def to_json_string(self):
+        """
+        Serializes this instance to a JSON formatted string.
+        Returns:
+            str: JSON formatted string representing the configuration instance.
+        """
+        return json.dumps(self.__dict__, indent=2) + "\n"
+
+    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
+    def update(self, **kwargs):
+        """
+        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+        returning all the unused kwargs.
+
+        Args:
+            kwargs (`Dict[str, Any]`):
+                Dictionary of attributes to tentatively update this class.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+        """
+        to_remove = []
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+                to_remove.append(key)
+
+        # Remove all the attributes that were updated, without modifying the input dict
+        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+        return unused_kwargs
+
+
+@dataclass
+class QuantizedCacheConfig(CacheConfig):
+    """
+    Configuration class for quantized cache settings.
+
+    Attributes:
+        backend (`str`, *optional*, defaults to `"quanto"`):
+            Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`]
+        nbits (`Optional[int]`, *optional*, defaults to 4):
+            Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 4.
+        axis_key (`int`, *optional*, defaults to 0):
+            Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
+        axis_value (`int`, *optional*, defaults to 0):
+            Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
+        q_group_size (`Optional[int]`, *optional*, defaults to 64):
+            Size of the quantization group, should be a divisor of the model's hidden dimension.
+            Defaults to 64.
+        residual_length (`Optional[int]`, *optional*, defaults to 128):
+            Length of the residual cache which will always be stored in original precision.
+            Defaults to 128.
+        compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+            The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            Device on which to perform computations, should be same as the model's device.
+    """
+
+    def __init__(
+        self,
+        backend: str = "quanto",
+        nbits: Optional[int] = 4,
+        axis_key: Optional[int] = 0,
+        axis_value: Optional[int] = 0,
+        q_group_size: Optional[int] = 64,
+        residual_length: Optional[int] = 128,
+        compute_dtype: Optional[torch.dtype] = torch.float16,
+        device: Optional[str] = "cpu",
+    ):
+        self.backend = backend
+        self.nbits = nbits
+        self.axis_key = axis_key
+        self.axis_value = axis_value
+        self.q_group_size = q_group_size
+        self.residual_length = residual_length
+        self.compute_dtype = compute_dtype
+        self.device = device
+
+    def validate(self):
+        """Validates if the arguments passed are correct"""
+
+        incorrect_arg_msg = (
+            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
+            "but found {found_value}"
+        )
+        # Check that the values are reasonable in general (nbits, axis)
+        # Later in QuantizedCache init we check if they are supported for that particular backend
+        if self.nbits not in [1, 2, 3, 4, 8]:
+            # NOTE(review): the message text "2 or 4 or 8" omits 1 and 3, which are accepted by the
+            # check above (HQQ backend supports them) — confirm whether the message should be updated.
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="nbits",
+                    correct_value="2 or 4 or 8",
+                    found_value=self.nbits,
+                ),
+            )
+        if self.q_group_size <= 0:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="q_group_size",
+                    correct_value="a positive integer",
+                    found_value=self.q_group_size,
+                ),
+            )
+        # residual_length of 0 is allowed (everything is quantized immediately).
+        if self.residual_length < 0:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="residual_length",
+                    correct_value="a positive integer",
+                    found_value=self.residual_length,
+                ),
+            )
+
+        if self.axis_key not in [0, 1, -1]:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="axis_key",
+                    correct_value="`1` or `0`, `-1`",
+                    found_value=self.axis_key,
+                ),
+            )
+
+        if self.axis_value not in [0, 1, -1]:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="axis_value",
+                    correct_value="`1` or `0` or `-1`",
+                    found_value=self.axis_value,
+                ),
+            )
+
+
+@dataclass
+class StaticCacheConfig(CacheConfig):
+    """
+    Configuration class for static cache settings.
+    """
+
+    cache_implementation = "static"
+
+    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+        # batch_size: batch size the static cache is allocated for.
+        self.batch_size = batch_size
+        # max_cache_len: maximum number of tokens the cache can hold.
+        self.max_cache_len = max_cache_len
+        # device: device the cache tensors will live on.
+        self.device = device
+
+    def validate(self):
+        """Validates if the arguments passed are correct"""
+
+        incorrect_arg_msg = (
+            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
+            "but found {found_value}"
+        )
+
+        if self.batch_size <= 0:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="batch_size",
+                    correct_value="> 0",
+                    found_value=self.batch_size,
+                ),
+            )
+
+        if self.max_cache_len <= 0:
+            raise ValueError(
+                incorrect_arg_msg.format(
+                    key="max_cache_len",
+                    correct_value="> 0",
+                    found_value=self.max_cache_len,
+                ),
+            )
+
+
+class DynamicCache(Cache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> past_key_values = DynamicCache()
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values # access cache filled with key/values from generation
+        DynamicCache()
+        ```
+    """
+
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+        super().__init__()
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens; only counted once per forward pass (at layer 0).
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if key_states is not None:
+            if len(self.key_cache) <= layer_idx:
+                # There may be skipped layers, fill them with empty lists
+                for _ in range(len(self.key_cache), layer_idx):
+                    self.key_cache.append([])
+                    self.value_cache.append([])
+                self.key_cache.append(key_states)
+                self.value_cache.append(value_states)
+            elif (
+                len(self.key_cache[layer_idx]) == 0
+            ):  # fills previously skipped layers; checking for tensor causes errors
+                self.key_cache[layer_idx] = key_states
+                self.value_cache[layer_idx] = value_states
+            else:
+                # Normal decode path: append along the sequence dimension.
+                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        is_empty_layer = (
+            len(self.key_cache) == 0  # no cache in any layer
+            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+        )
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        return layer_seq_length
+
+    def get_max_cache_shape(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
+        return None
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
+        backward compatibility."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+    def crop(self, max_length: int):
+        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
+        # In case it is negative
+        if max_length < 0:
+            max_length = self.get_seq_length() - abs(max_length)
+
+        if self.get_seq_length() <= max_length:
+            return
+
+        self._seen_tokens = max_length
+        # Skipped layers hold a `[]` placeholder, hence the `!= []` check before slicing.
+        for idx in range(len(self.key_cache)):
+            if self.key_cache[idx] != []:
+                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def batch_split(
+        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
+    ) -> List["DynamicCache"]:
+        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+        `_split_model_inputs()` in `generation.utils`"""
+        out = []
+        for i in range(0, full_batch_size, split_size):
+            current_split = DynamicCache()
+            current_split._seen_tokens = self._seen_tokens
+            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
+            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
+            out.append(current_split)
+        return out
+
+    @classmethod
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+        `generation.utils`"""
+        cache = cls()
+        for idx in range(len(splits[0])):
+            # Gather this layer's tensors from every split, skipping `[]` placeholders.
+            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
+            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
+            if key_cache != []:
+                layer_keys = torch.cat(key_cache, dim=0)
+                layer_values = torch.cat(value_cache, dim=0)
+                cache.update(layer_keys, layer_values, idx)
+        return cache
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+        for layer_idx in range(len(self)):
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+        for layer_idx in range(len(self)):
+            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
+            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
+
+
+class OffloadedCache(DynamicCache):
+    """
+    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+    Useful for generating from models with very long context.
+
+    In addition to the default CUDA stream, where all forward() computations happen,
+    this class uses another stream, the prefetch stream, which it creates itself.
+    Since scheduling of operations on separate streams happens independently, this class uses
+    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
+    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
+    ensure the eviction is scheduled after all computations on that cache are finished.
+    """
+
+    def __init__(self) -> None:
+        if not torch.cuda.is_available():
+            raise RuntimeError("OffloadedCache can only be used with a GPU")
+        super().__init__()
+        # Original (GPU) device of each layer's tensors, recorded on first `update`.
+        self.original_device = []
+        # Dedicated stream for async CPU->GPU prefetches, overlapping with compute.
+        self.prefetch_stream = torch.cuda.Stream()
+        self.beam_idx = None  # used to delay beam search operations
+
+    def prefetch_layer(self, layer_idx: int):
+        "Starts prefetching the next layer cache"
+        if layer_idx < len(self):
+            with torch.cuda.stream(self.prefetch_stream):
+                # Prefetch next layer tensors to GPU
+                device = self.original_device[layer_idx]
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
+
+    def evict_previous_layer(self, layer_idx: int):
+        "Moves the previous layer cache to the CPU"
+        # NOTE(review): eviction is skipped when there are 2 or fewer layers — presumably because
+        # evicting would immediately conflict with the prefetch of the same layer; confirm intent.
+        if len(self) > 2:
+            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
+            prev_layer_idx = (layer_idx - 1) % len(self)
+            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
+            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
+        if layer_idx < len(self):
+            # Evict the previous layer if necessary
+            torch.cuda.current_stream().synchronize()
+            self.evict_previous_layer(layer_idx)
+            # Load current layer cache to its original device if not already there
+            original_device = self.original_device[layer_idx]
+            # Wait until the prefetch of this layer (issued during the previous layer) has landed.
+            self.prefetch_stream.synchronize()
+            key_tensor = self.key_cache[layer_idx]
+            value_tensor = self.value_cache[layer_idx]
+            # Now deal with beam search ops which were delayed
+            if self.beam_idx is not None:
+                self.beam_idx = self.beam_idx.to(original_device)
+                key_tensor = key_tensor.index_select(0, self.beam_idx)
+                value_tensor = value_tensor.index_select(0, self.beam_idx)
+            # Prefetch the next layer
+            self.prefetch_layer((layer_idx + 1) % len(self))
+            return (key_tensor, value_tensor)
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
+        # We delay this operation until the tensors are back to their original
+        # device because performing torch.index_select on the CPU is very slow
+        del self.beam_idx
+        self.beam_idx = beam_idx.clone()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens; only counted once per forward pass (at layer 0).
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if len(self.key_cache) < layer_idx:
+            raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
+        elif len(self.key_cache) == layer_idx:
+            # First time this layer is seen: store tensors and remember their device.
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            self.original_device.append(key_states.device)
+            self.evict_previous_layer(layer_idx)
+        else:
+            # `self[layer_idx]` handles prefetch/evict/beam reordering before concatenation.
+            key_tensor, value_tensor = self[layer_idx]
+            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
+    # if a method is not supposed to be supported in a subclass we should set it to None
+    from_legacy_cache = None
+
+    to_legacy_cache = None
+
+
+class QuantizedCache(DynamicCache):
+    """
+    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the
+    original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
+    quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
+
+    It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
+    Value in original precision states as a list of tensors, one for each layer. The size of each tensor
+    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
+    """
+
+    def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        super().__init__()
+        # Quantized storage, parallel to the residual `key_cache`/`value_cache` lists.
+        self._quantized_key_cache: List[torch.Tensor] = []
+        self._quantized_value_cache: List[torch.Tensor] = []
+
+        self.nbits = cache_config.nbits
+        self.residual_length = cache_config.residual_length
+        self.q_group_size = cache_config.q_group_size
+        self.axis_key = cache_config.axis_key
+        self.axis_value = cache_config.axis_value
+        self.compute_dtype = cache_config.compute_dtype
+        self.device = cache_config.device
+
+        # NOTE(review): `super().__init__()` was already called at the top of this method; this
+        # second call is redundant (it re-runs the parent init) but harmless.
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Update the number of seen tokens; only counted once per forward pass (at layer 0).
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        if len(self.key_cache) < layer_idx:
+            raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
+        elif len(self.key_cache) == layer_idx:
+            # First call for this layer: quantize the prefill immediately and keep an empty residual.
+            self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
+            self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
+            self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
+            self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
+            keys_to_return, values_to_return = key_states, value_states
+        else:
+            # Return dequantized history + residual + new states.
+            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
+            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
+            keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
+            values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
+
+            keys_to_return = torch.cat(keys_to_return, dim=-2)
+            values_to_return = torch.cat(values_to_return, dim=-2)
+            # When the residual is about to exceed `residual_length`, fold everything back into the
+            # quantized store and reset the residual to empty.
+            if (
+                self.key_cache[layer_idx].dim() == 4
+                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
+            ):
+                self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
+                self._quantized_value_cache[layer_idx] = self._quantize(
+                    values_to_return.contiguous(), axis=self.axis_value
+                )
+                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
+                self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
+            else:
+                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
+        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def _quantize(self, tensor, axis):
+        """Quantizes a key/value using a defined quantization method."""
+        raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")
+
+    def _dequantize(self, q_tensor):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
+
+
+class QuantoQuantizedCache(QuantizedCache):
+    """
+    Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
+
+    Parameters:
+        cache_config (`QuantizedCacheConfig`):
+            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
+
+    Example:
+
+        ```python
+        >>> # Run pip install quanto first if you don't have it yet
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
+
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> cache_config = QuantizedCacheConfig(nbits=4)
+        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values # access cache filled with key/values from generation
+        QuantoQuantizedCache()
+        ```
+    """
+
+    def __init__(self, cache_config: CacheConfig) -> None:
+        super().__init__(cache_config)
+
+        # NOTE(review): if optimum-quanto is NOT available, this block is skipped and the
+        # `qint2`/`qint4`/`MaxOptimizer` names below are unbound, which raises NameError later —
+        # confirm callers are expected to have the package installed.
+        if is_optimum_quanto_available():
+            optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
+            if optimum_quanto_version <= version.parse("0.2.5"):
+                raise ImportError(
+                    f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
+                )
+            from optimum.quanto import MaxOptimizer, qint2, qint4
+
+        if self.nbits not in [2, 4]:
+            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
+
+        if self.axis_key not in [0, -1]:
+            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
+
+        if self.axis_value not in [0, -1]:
+            raise ValueError(
+                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
+            )
+
+        self.qtype = qint4 if self.nbits == 4 else qint2
+        self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization
+
+    def _quantize(self, tensor, axis):
+        # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
+        # NOTE(review): implicitly returns None when optimum-quanto is unavailable — in practice
+        # `__init__` already requires the package; confirm this path is unreachable.
+        if is_optimum_quanto_available():
+            from optimum.quanto import quantize_weight
+
+            scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
+            qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
+            return qtensor
+
+    def _dequantize(self, qtensor):
+        # quanto qtensors expose `dequantize()` to recover the original-precision tensor.
+        return qtensor.dequantize()
+
+
class HQQQuantizedCache(QuantizedCache):
    """
    Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.

    Parameters:
        cache_config (`QuantizedCacheConfig`):
            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.

    Example:

        ```python
        >>> # Run pip install hqq first if you don't have it yet
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HQQQuantizedCache()
        ```
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        """Validate the HQQ-specific configuration and bind the HQQ quantizer."""
        super().__init__(cache_config)

        if self.nbits not in (1, 2, 3, 4, 8):
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        if self.axis_key not in (0, 1):
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

        if self.axis_value not in (0, 1):
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with HQQ; returns the `(qtensor, meta)` pair HQQ works with."""
        quant_kwargs = {
            "axis": axis,
            "device": self.device,
            "compute_dtype": self.compute_dtype,
            "nbits": self.nbits,
            "group_size": self.q_group_size,
        }
        qtensor, meta = self.quantizer.quantize(tensor, **quant_kwargs)
        meta["compute_dtype"] = self.compute_dtype
        # Move to device and cast to dtype
        self.quantizer.cuda(qtensor, meta=meta, device=self.device)
        return qtensor, meta

    def _dequantize(self, qtensor):
        """Invert `_quantize`: unpack the `(qtensor, meta)` pair back into a dense tensor."""
        packed, meta = qtensor
        return self.quantizer.dequantize(packed, meta)
+
+
class SinkCache(Cache):
    """
    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    SinkCache()
    ```
    """

    is_sliding = True

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Memoized (cos, sin) re-rotation tensors, keyed by incoming key length
        # (see `_get_rerotation_cos_sin`)
        self.cos_sin_rerotation_cache = {}
        # Accumulated sin/cos tables covering all positions seen so far (filled lazily in `update`)
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        """Rotates half the hidden dims of the input (standard RoPE helper)."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Applies the rotary embedding defined by `cos`/`sin` to `key_states`."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the (cos, sin) tensors needed to re-rotate the kept keys to their shifted positions
        after the window slides. Results are memoized per incoming key length.
        """
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            # Angle-difference identities: cos(a-b) = cos a cos b + sin a sin b, sin(a-b) = ...
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and layer_idx == 0:
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length:
                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                )
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` dims are rotated; the rest pass through
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
        layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
            Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
            You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
    >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    StaticCache()
    ```
    """

    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: torch.device = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        if batch_size is not None:
            logger.warning_once(
                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
            )

        self.max_batch_size = batch_size or max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
        # Use `self.max_batch_size` directly: reading the deprecated `batch_size` property here
        # would emit a spurious deprecation warning on every instantiation.
        cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            layer_device = layer_device_map[idx] if layer_device_map is not None else device
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            # Notes:
            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
            # it is not needed anyway)
            # 2. `torch.export()` requires mutations to be registered as buffers.
            if not is_torchdynamo_compiling():
                # Register the tensors we just allocated instead of allocating a second, identical
                # pair of zero tensors per layer (the originals would be discarded immediately).
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                self.register_buffer(f"value_cache_{idx}", new_layer_value_cache)
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know how where to write in the cache.

        Return:
            A tuple containing the updated key and value states.
        """

        cache_position = cache_kwargs.get("cache_position")

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        key_states = key_states.to(k_out.dtype)
        value_states = value_states.to(v_out.dtype)

        if cache_position is None:
            # No positions given: the incoming states cover the whole cache
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length (i.e. the pre-allocated capacity) of the cache."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()

    @property
    def batch_size(self):
        # Deprecated alias kept for backward compatibility; prefer `max_batch_size`.
        logger.warning_once(
            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
        )
        return self.max_batch_size
+
+
class SlidingWindowCache(StaticCache):
    """
    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
    Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
    if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
    we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.

    The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:

    indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
        55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

    We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
        layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
            Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
            You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache

    >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

    >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
    >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    SlidingWindowCache()
    ```
    """

    is_sliding = True

    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: torch.device = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        # `max_cache_len` may be None (the parent class supports that); resolve it here first,
        # otherwise `min(int, None)` below raises a TypeError.
        if max_cache_len is None:
            max_cache_len = config.max_position_embeddings
        # The cache never needs to be longer than the sliding window
        max_cache_len = min(config.sliding_window, max_cache_len)
        super().__init__(
            config=config,
            batch_size=batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
            max_batch_size=max_batch_size,
            layer_device_map=layer_device_map,
        )

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the sliding-window cache for `layer_idx` with `key_states`/`value_states`.
        `cache_kwargs` must contain `cache_position`. During a prefill longer than the window, the
        full (unwindowed) states are returned; otherwise the updated fixed-size window is returned.
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
        if cache_position.shape[0] > self.max_cache_len:
            k_out = key_states[:, :, -self.max_cache_len :, :]
            v_out = value_states[:, :, -self.max_cache_len :, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, self.max_cache_len - 1)
        to_shift = cache_position >= self.max_cache_len - 1
        # Roll the cache left by one slot once the window is full (see class docstring)
        indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len

        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        try:
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the window-bounded maximum sequence length of the cache."""
        return self.max_cache_len

    def reset(self):
        """Resets the cache values while preserving the objects (static addresses)."""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()
+
+
class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.

    Example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
        >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

        >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

        >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
        >>> self_attention_cache = DynamicCache()
        >>> cross_attention_cache = DynamicCache()
        >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        EncoderDecoderCache()
        ```

    """

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        super().__init__()
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

        # Per-layer flag marking whether the cross-attention cache has been filled already
        self.is_updated = {
            layer_idx: bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
            for layer_idx in range(len(cross_attention_cache.key_cache))
        }

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx >= len(self):
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
        return (
            self.self_attention_cache.key_cache[layer_idx],
            self.self_attention_cache.value_cache[layer_idx],
            self.cross_attention_cache.key_cache[layer_idx],
            self.cross_attention_cache.value_cache[layer_idx],
        )

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        # Without a cross-attention cache, the legacy format is just the self-attention one
        if len(self.cross_attention_cache) == 0:
            return self.self_attention_cache.to_legacy_cache()

        # Otherwise, each layer tuple is the self-attention entries followed by the cross-attention ones
        layers = []
        for self_attn, cross_attn in zip(
            self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
        ):
            layers.append(self_attn + cross_attn)
        return tuple(layers)

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(
            self_attention_cache=DynamicCache(),
            cross_attention_cache=DynamicCache(),
        )
        if past_key_values is not None:
            for layer_idx, layer in enumerate(past_key_values):
                key_states, value_states = layer[:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                # Legacy layers with more than 2 entries also carry the cross-attention key/values
                if len(layer) > 2:
                    key_states, value_states = layer[2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
        return self.self_attention_cache.get_seq_length(layer_idx)

    def reset(self):
        """Resets both sub-caches (when they support it) and clears the per-layer update flags."""
        can_reset_self = hasattr(self.self_attention_cache, "reset")
        can_reset_cross = hasattr(self.cross_attention_cache, "reset")
        if not can_reset_self and not can_reset_cross:
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        if can_reset_self:
            self.self_attention_cache.reset()
        if can_reset_cross:
            self.cross_attention_cache.reset()
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        """Raises unless both sub-caches are `DynamicCache` (required by the helpers below)."""
        self_is_dynamic = isinstance(self.self_attention_cache, DynamicCache)
        cross_is_dynamic = isinstance(self.cross_attention_cache, DynamicCache)
        if not (self_is_dynamic and cross_is_dynamic):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
    def batch_split(
        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
    ) -> "List[EncoderDecoderCache]":
        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        self.check_dynamic_cache(self.batch_split.__name__)
        self_splits = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_splits = self.cross_attention_cache.batch_split(full_batch_size, split_size)
        return [EncoderDecoderCache(self_attn, cross_attn) for self_attn, cross_attn in zip(self_splits, cross_splits)]

    @classmethod
    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
    def from_batch_splits(
        cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None
    ) -> "EncoderDecoderCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        self_attention_cache = DynamicCache()
        cross_attention_cache = DynamicCache()
        for idx in range(len(splits[0])):
            # Re-stack each layer's key/value tensors along the batch dimension
            self_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
            self_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
            self_attention_cache.update(self_keys, self_values, idx)

            cross_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
            cross_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
            cross_attention_cache.update(cross_keys, cross_values, idx)
        return cls(self_attention_cache, cross_attention_cache)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)
+
+
+class HybridCache(Cache):
+ """
+ Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
+ and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
+ and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class.
+
+ Parameters:
+ config (`PretrainedConfig):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ batch_size (`int`):
+ The batch size with which the model will be used. Note that a new instance must be instantiated if a
+ smaller batch size is used.
+ max_cache_len (`int`):
+ The maximum sequence length with which the model will be used.
+ device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
+ The device on which the cache should be initialized. Should be the same as the layer.
+ dtype (torch.dtype, *optional*, defaults to `torch.float32`):
+ The default `dtype` to use when initializing the layer.
+ layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
+ Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
+ You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
+
+ >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
+
+ >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+ >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ HybridCache()
+ ```
+ """
+
+ # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ batch_size: int = None,
+ max_cache_len: int = None,
+ device: Union[torch.device, str] = "cpu",
+ dtype: torch.dtype = torch.float32,
+ max_batch_size: Optional[int] = None,
+ layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+ ) -> None:
+ super().__init__()
+ if batch_size is not None:
+ logger.warning_once(
+ f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+ "v4.49. Use the more precisely named 'max_batch_size' argument instead."
+ )
+ if not hasattr(config, "sliding_window") or config.sliding_window is None:
+ raise ValueError(
+ "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
+ "sliding window attention, please check if there is a `sliding_window` field in the model "
+ "config and it's not set to None."
+ )
+ self.max_cache_len = max_cache_len
+ self.max_batch_size = batch_size or max_batch_size
+ # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ self.head_dim = (
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+ )
+
+ self.dtype = dtype
+ self.num_key_value_heads = (
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+ )
+ layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
+ self.is_sliding = torch.tensor(
+ [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
+ )
+ self.key_cache: List[torch.Tensor] = []
+ self.value_cache: List[torch.Tensor] = []
+ global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+ sliding_cache_shape = (
+ self.batch_size,
+ self.num_key_value_heads,
+ min(config.sliding_window, max_cache_len),
+ self.head_dim,
+ )
+ for i in range(config.num_hidden_layers):
+ if layer_device_map is not None:
+ layer_device = layer_device_map[i]
+ else:
+ layer_device = device
+ # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+ # breaks when updating the cache.
+ cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
+ new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+ new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+ torch._dynamo.mark_static_address(new_layer_key_cache)
+ torch._dynamo.mark_static_address(new_layer_value_cache)
+ self.key_cache.append(new_layer_key_cache)
+ self.value_cache.append(new_layer_value_cache)
+
+ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+ if cache_position.shape[0] > max_cache_len:
+ k_out = key_states[:, :, -max_cache_len:, :]
+ v_out = value_states[:, :, -max_cache_len:, :]
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
+ self.key_cache[layer_idx] += k_out
+ self.value_cache[layer_idx] += v_out
+ # we should return the whole states instead of k_out, v_out to take the whole prompt
+ # into consideration when building kv cache instead of just throwing away tokens outside of the window
+ return key_states, value_states
+
+ slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
+ cache_position = cache_position.clamp(0, max_cache_len - 1)
+ to_shift = cache_position >= max_cache_len - 1
+ indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
+ k_out = k_out[:, :, indices]
+ v_out = v_out[:, :, indices]
+
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+ # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+ self.key_cache[layer_idx] += k_out
+ self.value_cache[layer_idx] += v_out
+ return k_out, v_out
+
+ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+
+ self.key_cache[layer_idx] = k_out
+ self.value_cache[layer_idx] = v_out
+ return k_out, v_out
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor]:
+ cache_position = cache_kwargs.get("cache_position")
+ sliding_window = cache_kwargs.get("sliding_window")
+ k_out = self.key_cache[layer_idx]
+ v_out = self.value_cache[layer_idx]
+ if sliding_window:
+ update_fn = self._sliding_update
+ else:
+ update_fn = self._static_update
+
+ return update_fn(
+ cache_position,
+ layer_idx,
+ key_states,
+ value_states,
+ k_out,
+ v_out,
+ k_out.shape[2],
+ )
+
+ def get_max_cache_shape(self) -> Optional[int]:
+ return self.max_cache_len
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0):
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ # TODO: deprecate this function in favor of `cache_position`
+ if layer_idx != 0:
+ raise ValueError(
+ "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
+ "Using the `layer_idx` argument is not supported."
+ )
+ return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+ def reset(self):
+ """Resets the cache values while preserving the objects"""
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+ @property
+ def batch_size(self):
+ logger.warning_once(
+ f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
+ "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
+ )
+ return self.max_batch_size
+
+
+class MambaCache:
+    """
+    Cache for mamba model which does not have attention mechanism and key value states.
+
+    Arguments:
+        config (`PretrainedConfig):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used. Deprecated in favor of `max_batch_size`.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
+            The default `dtype` to use when initializing the layer.
+        device (`torch.device` or `str`, *optional*):
+            The device on which the cache should be initialized. Should be the same as the layer.
+
+    Attributes:
+        dtype: (`torch.dtype`):
+            The default `dtype` used to initialize the cache.
+        intermediate_size: (`int`):
+            Model's intermediate_size taken from config.
+        ssm_state_size: (`int`):
+            Model's state_size taken from config.
+        conv_kernel_size: (`int`):
+            Model's convolution kernel size taken from config.
+        conv_states: (`torch.Tensor`):
+            A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
+        ssm_states: (`torch.Tensor`):
+            A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states.
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
+
+        >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
+
+        >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> outputs.past_key_values
+        MambaCache()
+        ```
+    """
+
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        dtype: torch.dtype = torch.float16,
+        device: Optional[Union[torch.device, str]] = None,
+        max_batch_size: Optional[int] = None,
+    ):
+        # `batch_size` is deprecated in favor of `max_batch_size`; warn once when the old name is used.
+        if batch_size is not None:
+            logger.warning_once(
+                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
+            )
+        self.dtype = dtype
+        self.max_batch_size = batch_size or max_batch_size
+        self.intermediate_size = config.intermediate_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+
+        # Per-layer convolutional states, all allocated up-front as a single tensor.
+        self.conv_states: torch.Tensor = torch.zeros(
+            config.num_hidden_layers,
+            self.max_batch_size,
+            self.intermediate_size,
+            self.conv_kernel_size,
+            device=device,
+            dtype=dtype,
+        )
+        # Per-layer SSM states, same layout but with the state dimension last.
+        self.ssm_states: torch.Tensor = torch.zeros(
+            config.num_hidden_layers,
+            self.max_batch_size,
+            self.intermediate_size,
+            self.ssm_state_size,
+            device=device,
+            dtype=dtype,
+        )
+
+        # Tag the tensors as fixed data pointers, preventing cuda graph breaks when they are updated in-place.
+        torch._dynamo.mark_static_address(self.conv_states)
+        torch._dynamo.mark_static_address(self.ssm_states)
+
+    def update_conv_state(
+        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+    ) -> torch.Tensor:
+        """Rolls the convolution window of `layer_idx` left by one slot and writes `new_conv_state` at
+        `cache_position` (clamped into the window). Returns the updated per-layer conv state."""
+        conv_state = self.conv_states[layer_idx]
+        # Clamp so the write always lands inside the rolling convolution window.
+        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+        conv_state = conv_state.roll(shifts=-1, dims=-1)
+        conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
+        # `zero_()` followed by `+=` is an in-place equivalent of `=` that preserves the static address.
+        self.conv_states[layer_idx].zero_()
+        self.conv_states[layer_idx] += conv_state
+        return self.conv_states[layer_idx]
+
+    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
+        """Replaces the SSM state of `layer_idx` with `new_ssm_state` and returns it."""
+        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
+        return self.ssm_states[layer_idx]
+
+    def reset(self):
+        """Resets the cache values in-place while preserving the tensor objects."""
+        self.conv_states.zero_()
+        self.ssm_states.zero_()
+
+    @property
+    def batch_size(self):
+        # Deprecated alias kept for backwards compatibility; warns on every read.
+        logger.warning_once(
+            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
+            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
+        )
+        return self.max_batch_size
+
+
+class OffloadedStaticCache(StaticCache):
+    """
+    Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
+    another device.
+
+    Args:
+        config (`PretrainedConfig):
+            The configuration file defining the shape-related attributes required to initialize
+            the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`Union[str, torch.device]`):
+            The device on which the cache should be initialized. Should be the same as the
+            layer device.
+        dtype (`torch.dtype`, *optional*):
+            The default `dtype` to use when initializing the cache.
+        offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
+            The device to offload to. Defaults to CPU.
+        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
+            Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is split between different GPUs.
+            You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
+
+    Attributes:
+        key_cache (`List[torch.Tensor]`):
+            Off-loaded key cache tensors. First one will be on device, where-as the others are
+            off-loaded.
+        value_cache (`List[torch.Tensor]`):
+            Off-loaded value cache tensors. First one will be on device, where-as the others are
+            off-loaded.
+        max_batch_size (`int`):
+            The maximum batch size with which this cache can be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which this cache can be used.
+        device (`torch.device`):
+            The device on which the cache is used.
+        offload_device (`torch.device`):
+            The device used to offload to.
+        dtype (`torch.dtype`):
+            The `dtype` used to initialize the cache.
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
+
+        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+
+        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+        >>> max_generated_length = inputs.input_ids.shape[1] + 10
+        >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+        ```
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        max_cache_len: Optional[int],
+        device: Union[str, torch.device],
+        dtype: Optional[torch.dtype] = None,
+        offload_device: Union[str, torch.device] = torch.device("cpu"),
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+    ) -> None:
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        # With a manual device map, the on-device buffers live on the first layer's device.
+        self.device = torch.device(device) if layer_device_map is None else layer_device_map[0]
+        self.offload_device = torch.device(offload_device)
+        self.dtype = dtype if dtype is not None else torch.float32
+
+        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+
+        num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+
+        cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
+
+        # Create offloaded CPU tensors.
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+
+        for i in range(config.num_hidden_layers):
+            # First layer is always on-device.
+            device = self.device if i == 0 else self.offload_device
+
+            key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device)
+
+            self.key_cache.append(key_cache)
+            self.value_cache.append(value_cache)
+
+        # Create device tensors: two on-device K/V buffers used as a double buffer, so layer i+1
+        # can be prefetched while layer i is being consumed (buffers alternate via `layer_idx & 1`).
+        self._device_key_cache: List[torch.Tensor] = []
+        self._device_value_cache: List[torch.Tensor] = []
+
+        for i in range(2):
+            key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device)
+
+            self._device_key_cache.append(key_cache)
+            self._device_value_cache.append(value_cache)
+
+        # For backwards compatibility.
+        # TODO(gante): Remove this.
+        self._seen_tokens = 0
+
+        # Create new CUDA stream for parallel prefetching.
+        self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, *optional*):
+                Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the
+                `cache_position` input to know where to write in the cache.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        if layer_idx == 0:
+            # Update seen tokens.
+            # TODO(gante): Remove this.
+            self._seen_tokens += key_states.shape[-2]
+
+            # Always there.
+            k_out = self.key_cache[0]
+            v_out = self.value_cache[0]
+        else:
+            # Wait for prefetch stream.
+            if self._prefetch_stream is not None:
+                torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)
+
+            # Use the on-device double buffer for this layer (buffers alternate by layer parity).
+            k_out = self._device_key_cache[layer_idx & 1]
+            v_out = self._device_value_cache[layer_idx & 1]
+
+            # Kick off the prefetch of the next layer while this one is being used.
+            self._prefetch_layer(layer_idx + 1)
+
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        if cache_position is None:
+            # No positions given: overwrite the whole buffer.
+            k_out.copy_(key_states)
+            v_out.copy_(value_states)
+
+            # Copy the values to the offloaded device as well.
+            if layer_idx == 0:
+                self.key_cache[layer_idx].copy_(key_states.to(self.offload_device))
+                self.value_cache[layer_idx].copy_(value_states.to(self.offload_device))
+        else:
+            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
+            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does
+            # explicitly an in-place operation, that avoids copies and uses less memory.
+            try:
+                k_out.index_copy_(2, cache_position, key_states)
+                v_out.index_copy_(2, cache_position, value_states)
+            except NotImplementedError:
+                # The operator 'aten::index_copy.out' is not currently implemented for the MPS
+                # device.
+                k_out[:, :, cache_position] = key_states
+                v_out[:, :, cache_position] = value_states
+
+            # Copy the values to the offloaded device as well.
+            if layer_idx != 0:
+                cache_position = cache_position.to(self.offload_device)
+                key_states = key_states.to(self.offload_device)
+                value_states = value_states.to(self.offload_device)
+
+                try:
+                    self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
+                    self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
+                except NotImplementedError:
+                    # The operator 'aten::index_copy.out' is not currently implemented for the MPS
+                    # device.
+                    self.key_cache[layer_idx][:, :, cache_position] = key_states
+                    self.value_cache[layer_idx][:, :, cache_position] = value_states
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model."""
+
+        # TODO(gante): Remove this.
+        return self._seen_tokens
+
+    def get_max_cache_shape(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+
+        return self.max_cache_len
+
+    def reset(self) -> None:
+        """Resets the cache values while preserving the objects."""
+
+        # For backwards compatibility.
+        # TODO(gante): Remove this.
+        self._seen_tokens = 0
+
+        # Zero out cache.
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address.
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
+
+    @property
+    def seen_tokens(self) -> int:
+        # For backwards compatibility.
+        # TODO(gante): Remove this.
+        return self._seen_tokens
+
+    def _create_key_value_cache_tensors(
+        self, shape: Tuple[int, ...], device: torch.device
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static
+        addresses for non-CPU tensors.
+
+        Args:
+            shape (`Tuple[int, ...]`): Shape.
+            device (`torch.device`): Device.
+
+        Returns:
+            Key and value cache tensors as a tuple.
+        """
+
+        # Pinning CPU memory allows asynchronous (`non_blocking=True`) host-to-device copies during prefetch.
+        is_cpu_device = device == torch.device("cpu")
+
+        key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
+        value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
+
+        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+        # preventing compiled graph breaks when updating the cache.
+        torch._dynamo.mark_static_address(key_cache)
+        torch._dynamo.mark_static_address(value_cache)
+
+        return key_cache, value_cache
+
+    def _prefetch_layer(self, layer_idx: int) -> None:
+        """Prefetch a layer to the device. Needs to be called in order of layer indices."""
+
+        # Don't fetch layers that do not exist.
+        if layer_idx >= len(self.key_cache):
+            return
+
+        # Alternate between two on-device caches.
+        if self._prefetch_stream is not None:
+            # Run the copy on the dedicated stream so it overlaps with compute on the default stream.
+            with torch.cuda.stream(self._prefetch_stream):
+                self._prefetch_layer_in_context(layer_idx)
+        else:
+            self._prefetch_layer_in_context(layer_idx)
+
+    def _prefetch_layer_in_context(self, layer_idx: int) -> None:
+        """Performs the actual copy of the layer to device cache."""
+
+        self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
+        self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
diff --git a/.venv/lib/python3.11/site-packages/transformers/configuration_utils.py b/.venv/lib/python3.11/site-packages/transformers/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..648877c8dce962d0d0387924ced7320781d9f056
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/configuration_utils.py
@@ -0,0 +1,1187 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Configuration base class and utilities."""
+
+import copy
+import json
+import os
+import re
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from packaging import version
+
+from . import __version__
+from .dynamic_module_utils import custom_object_save
+from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
+from .utils import (
+ CONFIG_NAME,
+ PushToHubMixin,
+ add_model_info_to_auto_map,
+ add_model_info_to_custom_pipelines,
+ cached_file,
+ copy_func,
+ download_url,
+ extract_commit_hash,
+ is_remote_url,
+ is_torch_available,
+ logging,
+)
+from .utils.generic import is_timm_config_dict
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class PretrainedConfig(PushToHubMixin):
+ # no-format
+ r"""
+ Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
+ methods for loading/downloading/saving configurations.
+
+
+
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
+ initialize a model does **not** load the model weights. It only affects the model's configuration.
+
+
+
+ Class attributes (overridden by derived classes):
+
+ - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
+ the correct object in [`~transformers.AutoConfig`].
+ - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
+ config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
+ [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
+ - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
+ outputs of the model during inference.
+ - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
+ naming of attributes.
+ - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
+ parallel plan applied to the sub-module when `model.tensor_parallel` is called.
+
+ Common attributes (present in all subclasses):
+
+ - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
+ embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
+ - **hidden_size** (`int`) -- The hidden size of the model.
+ - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
+ model.
+ - **num_hidden_layers** (`int`) -- The number of blocks in the model.
+
+
+
+ Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
+ some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
+ them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
+ information about the individual parameters.
+
+
+
+    Args:
+ name_or_path (`str`, *optional*, defaults to `""`):
+ Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
+ [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
+ with such a method.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should return all hidden-states.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should returns all attentions.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
+ is_encoder_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as an encoder/decoder or not.
+ is_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as decoder or not (in which case it's used as an encoder).
+        cross_attention_hidden_size (`bool`, *optional*):
+ The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
+ setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
+ add_cross_attention (`bool`, *optional*, defaults to `False`):
+ Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
+ that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
+ in `AUTO_MODELS_FOR_CAUSAL_LM`.
+ tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
+ Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
+ and decoder model to have the exact same parameter names.
+ prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
+ Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
+ heads to prune in said layer.
+
+ For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
+ chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
+ The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
+ the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
+ sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
+ Forward Chunking work?](../glossary.html#feed-forward-chunking).
+
+ > Parameters for fine-tuning tasks
+
+ architectures (`List[str]`, *optional*):
+ Model architectures that can be used with the model pretrained weights.
+ finetuning_task (`str`, *optional*):
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
+ or PyTorch) checkpoint.
+ id2label (`Dict[int, str]`, *optional*):
+ A map from index (for instance prediction index, or target index) to label.
+ label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
+ num_labels (`int`, *optional*):
+ Number of labels to use in the last layer added to the model, typically for a classification task.
+ task_specific_params (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments to store for the current task.
+ problem_type (`str`, *optional*):
+ Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
+ `"single_label_classification"` or `"multi_label_classification"`.
+
+ > Parameters linked to the tokenizer
+
+ tokenizer_class (`str`, *optional*):
+ The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
+ model by default).
+ prefix (`str`, *optional*):
+ A specific prompt that should be added at the beginning of each text before calling the model.
+ bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
+ pad_token_id (`int`, *optional*): The id of the _padding_ token.
+ eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
+ decoder_start_token_id (`int`, *optional*):
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
+ sep_token_id (`int`, *optional*): The id of the _separation_ token.
+
+ > PyTorch specific parameters
+
+ torchscript (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should be used with Torchscript.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+ model has a output word embedding layer.
+ torch_dtype (`str`, *optional*):
+ The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
+ (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
+ model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
+ `float16` weights. Since the config object is stored in plain text, this attribute contains just the
+ floating type string without the `torch.` prefix. For example, for `torch.float16` ``torch_dtype` is the
+ `"float16"` string.
+
+ This attribute is currently not being used during model loading time, but this may change in the future
+ versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
+
+ > TensorFlow specific parameters
+
+ use_bfloat16 (`bool`, *optional*, defaults to `False`):
+ Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+ tf_legacy_loss (`bool`, *optional*, defaults to `False`):
+ Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
+ not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
+ v5.
+ loss_type (`str`, *optional*):
+ The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
+            be automatically inferred from the model architecture.
+ """
+
    # Identifier serialized into the JSON config; used by `AutoConfig` to recreate the
    # correct config class (see class docstring).
    model_type: str = ""
    # Key under which this config's dict is nested inside a composite config
    # (consulted in `from_pretrained`).
    base_config_key: str = ""
    # Mapping of sub-config attribute names to their config classes — presumably for
    # composite models; TODO confirm against composite config subclasses.
    sub_configs: Dict[str, "PretrainedConfig"] = {}
    # Whether the config class is composed of multiple sub-configs (see class docstring).
    is_composition: bool = False
    # Maps model-specific attribute aliases to standardized attribute names;
    # consulted by the `__getattribute__`/`__setattr__` overrides below.
    attribute_map: Dict[str, str] = {}
    # Optional tensor-parallel plan mapping sub-module FQNs of the base model to a plan
    # applied when `model.tensor_parallel` is called (see class docstring).
    base_model_tp_plan: Optional[Dict[str, Any]] = None
    # Name of the auto class, when registered — NOTE(review): set elsewhere; confirm
    # it is populated by `register_for_auto_class`.
    _auto_class: Optional[str] = None
+
+ def __setattr__(self, key, value):
+ if key in super().__getattribute__("attribute_map"):
+ key = super().__getattribute__("attribute_map")[key]
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
+ key = super().__getattribute__("attribute_map")[key]
+ return super().__getattribute__(key)
+
+ def __init__(self, **kwargs):
+ # Attributes with defaults
+ self.return_dict = kwargs.pop("return_dict", True)
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
+ self.output_attentions = kwargs.pop("output_attentions", False)
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
+ self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
+ self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
+ self.tie_word_embeddings = kwargs.pop(
+ "tie_word_embeddings", True
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
+
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
+ self.is_decoder = kwargs.pop("is_decoder", False)
+ self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
+
+ # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
+ # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
+ for parameter_name, default_value in self._get_global_generation_defaults().items():
+ setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
+
+ # Fine-tuning task arguments
+ self.architectures = kwargs.pop("architectures", None)
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
+ self.id2label = kwargs.pop("id2label", None)
+ self.label2id = kwargs.pop("label2id", None)
+ if self.label2id is not None and not isinstance(self.label2id, dict):
+ raise ValueError("Argument label2id should be a dictionary.")
+ if self.id2label is not None:
+ if not isinstance(self.id2label, dict):
+ raise ValueError("Argument id2label should be a dictionary.")
+ num_labels = kwargs.pop("num_labels", None)
+ if num_labels is not None and len(self.id2label) != num_labels:
+ logger.warning(
+ f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
+ f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
+ )
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
+ # Keys are always strings in JSON so convert ids to int here.
+ else:
+ self.num_labels = kwargs.pop("num_labels", 2)
+
+ if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
+ # we will start using self.torch_dtype in v5, but to be consistent with
+ # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
+ if is_torch_available():
+ import torch
+
+ self.torch_dtype = getattr(torch, self.torch_dtype)
+
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+ self.tokenizer_class = kwargs.pop("tokenizer_class", None)
+ self.prefix = kwargs.pop("prefix", None)
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
+ self.sep_token_id = kwargs.pop("sep_token_id", None)
+
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+
+ # task specific arguments
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
+
+ # regression / multi-label classification
+ self.problem_type = kwargs.pop("problem_type", None)
+ allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
+ if self.problem_type is not None and self.problem_type not in allowed_problem_types:
+ raise ValueError(
+ f"The config parameter `problem_type` was not understood: received {self.problem_type} "
+ "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
+ )
+
+ # TPU arguments
+ if kwargs.pop("xla_device", None) is not None:
+ logger.warning(
+ "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
+ "safely remove it from your `config.json` file."
+ )
+
+ # Name or path to the pretrained checkpoint
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
+ # Config hash
+ self._commit_hash = kwargs.pop("_commit_hash", None)
+
+ # Attention implementation to use, if relevant.
+ self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
+ self._attn_implementation_autoset = False
+
+ # Drop the transformers version info
+ self.transformers_version = kwargs.pop("transformers_version", None)
+
+ # Deal with gradient checkpointing
+ if kwargs.get("gradient_checkpointing", False):
+ warnings.warn(
+ "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
+ "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
+ "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
+ )
+
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ @property
+ def name_or_path(self) -> str:
+ return getattr(self, "_name_or_path", None)
+
+ @name_or_path.setter
+ def name_or_path(self, value):
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
+
+ @property
+ def use_return_dict(self) -> bool:
+ """
+ `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
+ """
+ # If torchscript is set, force `return_dict=False` to avoid jit errors
+ return self.return_dict and not self.torchscript
+
+ @property
+ def num_labels(self) -> int:
+ """
+ `int`: The number of labels for classification models.
+ """
+ return len(self.id2label)
+
+ @num_labels.setter
+ def num_labels(self, num_labels: int):
+ if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
+
+ @property
+ def _attn_implementation(self):
+ # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+ if hasattr(self, "_attn_implementation_internal"):
+ if self._attn_implementation_internal is None:
+ # `config.attn_implementation` should never be None, for backward compatibility.
+ return "eager"
+ else:
+ return self._attn_implementation_internal
+ else:
+ return "eager"
+
+ @_attn_implementation.setter
+ def _attn_implementation(self, value):
+ self._attn_implementation_internal = value
+
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        # Normalize `token`/`use_auth_token` into `kwargs["token"]`.
        self._set_token_in_kwargs(kwargs)

        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        # Generation parameters belong in a GenerationConfig, not the model config.
        non_default_generation_parameters = self._get_non_default_generation_parameters()
        if len(non_default_generation_parameters) > 0:
            # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
            warnings.warn(
                "Some non-default generation parameters are set in the model config. These should go into either a) "
                "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
                "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
                "This warning will become an exception in the future."
                f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
                UserWarning,
            )

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            # Snapshot timestamps *before* writing, so only modified files are uploaded below.
            files_timestamps = self._get_files_timestamps(save_directory)

        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        # `use_diff=True`: presumably only values differing from defaults are serialized — confirm in `to_json_file`.
        self.to_json_file(output_config_file, use_diff=True)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )
+
+ @staticmethod
+ def _set_token_in_kwargs(kwargs, token=None):
+ """Temporary method to deal with `token` and `use_auth_token`.
+
+ This method is to avoid apply the same changes in all model config classes that overwrite `from_pretrained`.
+
+ Need to clean up `use_auth_token` in a follow PR.
+ """
+ # Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
+ if token is None:
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ) -> "PretrainedConfig":
        r"""
        Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a configuration file saved using the
                  [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force to (re-)download the configuration files and override the cached versions if
                they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

                To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.

            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.

                If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                specify the folder name here.
            kwargs (`Dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.

        Examples:

        ```python
        # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
        # derived class: BertConfig
        config = BertConfig.from_pretrained(
            "google-bert/bert-base-uncased"
        )  # Download configuration from huggingface.co and cache.
        config = BertConfig.from_pretrained(
            "./test/saved_model/"
        )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
        config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
        config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
        assert config.output_attentions == True
        config, unused_kwargs = BertConfig.from_pretrained(
            "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
        )
        assert config.output_attentions == True
        assert unused_kwargs == {"foo": False}
        ```"""
        # Fold the named download arguments back into kwargs for `get_config_dict`.
        kwargs["cache_dir"] = cache_dir
        kwargs["force_download"] = force_download
        kwargs["local_files_only"] = local_files_only
        kwargs["revision"] = revision

        cls._set_token_in_kwargs(kwargs, token)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # Composite configs nest the relevant sub-config dict under `base_config_key`.
        if cls.base_config_key and cls.base_config_key in config_dict:
            config_dict = config_dict[cls.base_config_key]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            # sometimes the config has no `base_config_key` if the config is used in several composite models
            # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
            for k, v in config_dict.items():
                if isinstance(v, dict) and v.get("model_type") == cls.model_type:
                    config_dict = v

            # raise warning only if we still can't see a match in `model_type`
            if config_dict["model_type"] != cls.model_type:
                logger.warning(
                    f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                    f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(config_dict, **kwargs)
+
+ @classmethod
+ def get_config_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ [`PretrainedConfig`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+ Returns:
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
+
+ """
+ cls._set_token_in_kwargs(kwargs)
+
+ original_kwargs = copy.deepcopy(kwargs)
+ # Get config dict associated with the base config file
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+ if config_dict is None:
+ return {}, kwargs
+ if "_commit_hash" in config_dict:
+ original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
+
+ # That config file may point us toward another config file to use.
+ if "configuration_files" in config_dict:
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
+ config_dict, kwargs = cls._get_config_dict(
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
+ )
+
+ return config_dict, kwargs
+
    @classmethod
    def _get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Resolve and load a single configuration file (or GGUF checkpoint) into a dict.

        Pops all download/resolution-related arguments from `kwargs` and returns the
        loaded dictionary together with the remaining kwargs. Returns `(None, kwargs)`
        when `cached_file` resolves to nothing without raising.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        subfolder = kwargs.pop("subfolder", "")
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        commit_hash = kwargs.pop("_commit_hash", None)

        # Note: `get`, not `pop` — `gguf_file` intentionally stays in `kwargs`.
        gguf_file = kwargs.get("gguf_file", None)

        if trust_remote_code is True:
            logger.warning(
                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
                " ignored."
            )

        # Telemetry payload forwarded to the Hub through `cached_file`.
        user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            # Special case when pretrained_model_name_or_path is a local file
            resolved_config_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
            resolved_config_file = download_url(pretrained_model_name_or_path)
        else:
            configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file

            try:
                # Load from local folder or from cache or download from model Hub and cache
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    configuration_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder,
                    _commit_hash=commit_hash,
                )
                if resolved_config_file is None:
                    return None, kwargs
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            except EnvironmentError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                # the original exception.
                raise
            except Exception:
                # For any other exception, we throw a generic error.
                raise EnvironmentError(
                    f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
                    f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
                    f" containing a {configuration_file} file"
                )

        try:
            if gguf_file:
                # GGUF checkpoints embed the config; extract it without loading tensors.
                config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
            else:
                # Load config dict
                config_dict = cls._dict_from_json_file(resolved_config_file)

            config_dict["_commit_hash"] = commit_hash
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(
                f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
            )

        if is_local:
            logger.info(f"loading configuration file {resolved_config_file}")
        else:
            logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

        # For remote configs, qualify auto_map / custom pipeline entries with the repo id.
        if "auto_map" in config_dict and not is_local:
            config_dict["auto_map"] = add_model_info_to_auto_map(
                config_dict["auto_map"], pretrained_model_name_or_path
            )
        if "custom_pipelines" in config_dict and not is_local:
            config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
                config_dict["custom_pipelines"], pretrained_model_name_or_path
            )

        # timm models are not saved with the model_type in the config file
        if "model_type" not in config_dict and is_timm_config_dict(config_dict):
            config_dict["model_type"] = "timm_wrapper"

        return config_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
+ """
+ Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
+
+ Args:
+ config_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the configuration object.
+
+ Returns:
+ [`PretrainedConfig`]: The configuration object instantiated from those parameters.
+ """
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+ # Those arguments may be passed along for our internal telemetry.
+ # We remove them so they don't appear in `return_unused_kwargs`.
+ kwargs.pop("_from_auto", None)
+ kwargs.pop("_from_pipeline", None)
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
+
+ # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
+ config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
+
+ config = cls(**config_dict)
+
+ if hasattr(config, "pruned_heads"):
+ config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
+
+ # Update config with kwargs if needed
+ if "num_labels" in kwargs and "id2label" in kwargs:
+ num_labels = kwargs["num_labels"]
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
+ if len(id2label) != num_labels:
+ raise ValueError(
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
+ "one of them."
+ )
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ current_attr = getattr(config, key)
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
+ if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
+ value = current_attr.__class__(**value)
+ setattr(config, key, value)
+ if key != "torch_dtype":
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Model config {config}")
+ if return_unused_kwargs:
+ return config, kwargs
+ else:
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
+ """
+ Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
+
+ """
+ config_dict = cls._dict_from_json_file(json_file)
+ return cls(**config_dict)
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __eq__(self, other):
+ return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ def __iter__(self):
+ for attr in self.__dict__:
+ yield attr
+
+ def to_diff_dict(self) -> Dict[str, Any]:
+ """
+ Removes all attributes from config which correspond to the default config attributes for better readability and
+ serializes to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ config_dict = self.to_dict()
+
+ # get the default config dict
+ default_config_dict = PretrainedConfig().to_dict()
+
+ # get class specific config dict
+ class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
+
+ serializable_config_dict = {}
+
+ # only serialize values that differ from the default config
+ for key, value in config_dict.items():
+ if (
+ isinstance(getattr(self, key, None), PretrainedConfig)
+ and key in class_config_dict
+ and isinstance(class_config_dict[key], dict)
+ ):
+ # For nested configs we need to clean the diff recursively
+ diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
+ if "model_type" in value:
+ # Needs to be set even if it's not in the diff
+ diff["model_type"] = value["model_type"]
+ if len(diff) > 0:
+ serializable_config_dict[key] = diff
+ elif (
+ key not in default_config_dict
+ or key == "transformers_version"
+ or value != default_config_dict[key]
+ or (key in class_config_dict and value != class_config_dict[key])
+ ):
+ serializable_config_dict[key] = value
+
+ if hasattr(self, "quantization_config"):
+ serializable_config_dict["quantization_config"] = (
+ self.quantization_config.to_dict()
+ if not isinstance(self.quantization_config, dict)
+ else self.quantization_config
+ )
+
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+ _ = serializable_config_dict.pop("_pre_quantization_dtype", None)
+
+ self.dict_torch_dtype_to_str(serializable_config_dict)
+
+ if "_attn_implementation_internal" in serializable_config_dict:
+ del serializable_config_dict["_attn_implementation_internal"]
+ # Do not serialize `base_model_tp_plan` for now
+ if "base_model_tp_plan" in serializable_config_dict:
+ del serializable_config_dict["base_model_tp_plan"]
+
+ return serializable_config_dict
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ if hasattr(self.__class__, "model_type"):
+ output["model_type"] = self.__class__.model_type
+ if "_auto_class" in output:
+ del output["_auto_class"]
+ if "_commit_hash" in output:
+ del output["_commit_hash"]
+ if "_attn_implementation_internal" in output:
+ del output["_attn_implementation_internal"]
+ # Do not serialize `base_model_tp_plan` for now
+ if "base_model_tp_plan" in output:
+ del output["base_model_tp_plan"]
+
+ # Transformers version when serializing the model
+ output["transformers_version"] = __version__
+
+ for key, value in output.items():
+ # Deal with nested configs like CLIP
+ if isinstance(value, PretrainedConfig):
+ value = value.to_dict()
+ del value["transformers_version"]
+
+ output[key] = value
+
+ if hasattr(self, "quantization_config"):
+ output["quantization_config"] = (
+ self.quantization_config.to_dict()
+ if not isinstance(self.quantization_config, dict)
+ else self.quantization_config
+ )
+
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+ _ = output.pop("_pre_quantization_dtype", None)
+
+ self.dict_torch_dtype_to_str(output)
+
+ return output
+
+ def to_json_string(self, use_diff: bool = True) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Args:
+ use_diff (`bool`, *optional*, defaults to `True`):
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
+ is serialized to JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ if use_diff is True:
+ config_dict = self.to_diff_dict()
+ else:
+ config_dict = self.to_dict()
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ use_diff (`bool`, *optional*, defaults to `True`):
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
+ is serialized to JSON file.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string(use_diff=use_diff))
+
+ def update(self, config_dict: Dict[str, Any]):
+ """
+ Updates attributes of this class with attributes from `config_dict`.
+
+ Args:
+ config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
+ """
+ for key, value in config_dict.items():
+ setattr(self, key, value)
+
+ def update_from_string(self, update_str: str):
+ """
+ Updates attributes of this class with attributes from `update_str`.
+
+ The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+
+ The keys to change have to already exist in the config object.
+
+ Args:
+ update_str (`str`): String with attributes that should be updated for this class.
+
+ """
+
+ d = dict(x.split("=") for x in update_str.split(","))
+ for k, v in d.items():
+ if not hasattr(self, k):
+ raise ValueError(f"key {k} isn't in the original config dict")
+
+ old_v = getattr(self, k)
+ if isinstance(old_v, bool):
+ if v.lower() in ["true", "1", "y", "yes"]:
+ v = True
+ elif v.lower() in ["false", "0", "n", "no"]:
+ v = False
+ else:
+ raise ValueError(f"can't derive true or false from {v} (key {k})")
+ elif isinstance(old_v, int):
+ v = int(v)
+ elif isinstance(old_v, float):
+ v = float(v)
+ elif not isinstance(old_v, str):
+ raise TypeError(
+ f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
+ )
+
+ setattr(self, k, v)
+
+ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
+ """
+ Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
+ converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
+ string, which can then be stored in the json format.
+ """
+ if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
+ d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
+ for value in d.values():
+ if isinstance(value, dict):
+ self.dict_torch_dtype_to_str(value)
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoConfig"):
+ """
+ Register this class with a given auto class. This should only be used for custom configurations as the ones in
+ the library are already mapped with `AutoConfig`.
+
+
+
+ This API is experimental and may have some slight breaking changes in the next releases.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
+ The auto class to register this new configuration with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ @staticmethod
+ def _get_global_generation_defaults() -> Dict[str, Any]:
+ return {
+ "max_length": 20,
+ "min_length": 0,
+ "do_sample": False,
+ "early_stopping": False,
+ "num_beams": 1,
+ "num_beam_groups": 1,
+ "diversity_penalty": 0.0,
+ "temperature": 1.0,
+ "top_k": 50,
+ "top_p": 1.0,
+ "typical_p": 1.0,
+ "repetition_penalty": 1.0,
+ "length_penalty": 1.0,
+ "no_repeat_ngram_size": 0,
+ "encoder_no_repeat_ngram_size": 0,
+ "bad_words_ids": None,
+ "num_return_sequences": 1,
+ "output_scores": False,
+ "return_dict_in_generate": False,
+ "forced_bos_token_id": None,
+ "forced_eos_token_id": None,
+ "remove_invalid_values": False,
+ "exponential_decay_length_penalty": None,
+ "suppress_tokens": None,
+ "begin_suppress_tokens": None,
+ }
+
+ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
+ """
+ Gets the non-default generation parameters on the PretrainedConfig instance
+ """
+ non_default_generation_parameters = {}
+ decoder_attribute_name = None
+
+ # Composite models don't have a default config, use their decoder config as a fallback for default values
+ # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
+ try:
+ default_config = self.__class__()
+ except ValueError:
+ decoder_config = self.get_text_config(decoder=True)
+ if decoder_config is not self:
+ default_config = decoder_config.__class__()
+ else:
+ default_config = None
+
+ # If it is a composite model, we want to check the subconfig that will be used for generation
+ self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
+
+ for parameter_name, default_global_value in self._get_global_generation_defaults().items():
+ if hasattr(self_decoder_config, parameter_name):
+ is_default_in_config = is_default_generation_value = None
+ parameter_value = getattr(self_decoder_config, parameter_name)
+ # Three cases in which is okay for the model config to hold generation config parameters:
+ # 1. The parameter is set to `None`, effectivelly delegating its value to the generation config
+ if parameter_value is None:
+ continue
+ # 2. If we have a default config, then the instance should hold the same generation defaults
+ if default_config is not None:
+ is_default_in_config = parameter_value == getattr(default_config, parameter_name)
+ # 3. if we don't have a default config, then the instance should hold the global generation defaults
+ else:
+ is_default_generation_value = parameter_value == default_global_value
+
+ is_non_default = (is_default_in_config is False) or (
+ is_default_in_config is None and is_default_generation_value is False
+ )
+ if is_non_default:
+ non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)
+
+ return non_default_generation_parameters
+
+ def get_text_config(self, decoder=False) -> "PretrainedConfig":
+ """
+ Returns the config that is meant to be used with text IO. On most models, it is the original config instance
+ itself. On specific composite models, it is under a set of valid names.
+
+ If `decoder` is set to `True`, then only search for decoder config names.
+ """
+ decoder_possible_text_config_names = ("decoder", "generator", "text_config")
+ encoder_possible_text_config_names = ("text_encoder",)
+ if decoder:
+ possible_text_config_names = decoder_possible_text_config_names
+ else:
+ possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
+
+ valid_text_config_names = []
+ for text_config_name in possible_text_config_names:
+ if hasattr(self, text_config_name):
+ text_config = getattr(self, text_config_name, None)
+ if text_config is not None:
+ valid_text_config_names += [text_config_name]
+
+ if len(valid_text_config_names) > 1:
+ raise ValueError(
+ f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
+ "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
+ )
+ elif len(valid_text_config_names) == 1:
+ return getattr(self, valid_text_config_names[0])
+ return self
+
+
+def get_configuration_file(configuration_files: List[str]) -> str:
+ """
+ Get the configuration file to use for this version of transformers.
+
+ Args:
+ configuration_files (`List[str]`): The list of available configuration files.
+
+ Returns:
+ `str`: The configuration file to use.
+ """
+ configuration_files_map = {}
+ for file_name in configuration_files:
+ search = _re_configuration_file.search(file_name)
+ if search is not None:
+ v = search.groups()[0]
+ configuration_files_map[v] = file_name
+ available_versions = sorted(configuration_files_map.keys())
+
+ # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
+ configuration_file = CONFIG_NAME
+ transformers_version = version.parse(__version__)
+ for v in available_versions:
+ if version.parse(v) <= transformers_version:
+ configuration_file = configuration_files_map[v]
+ else:
+ # No point going further since the versions are sorted.
+ break
+
+ return configuration_file
+
+
+def recursive_diff_dict(dict_a, dict_b, config_obj=None):
+ """
+ Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
+ values from `dict_a` that are different from values in `dict_b`.
+ """
+ diff = {}
+ default = config_obj.__class__().to_dict() if config_obj is not None else {}
+ for key, value in dict_a.items():
+ obj_value = getattr(config_obj, str(key), None)
+ if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
+ diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
+ if len(diff_value) > 0:
+ diff[key] = diff_value
+ elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
+ diff[key] = value
+ return diff
+
+
+PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
+if PretrainedConfig.push_to_hub.__doc__ is not None:
+ PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
+ object="config", object_class="AutoConfig", object_files="configuration file"
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py b/.venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..051f1d148a84e29d3e706fc4cd42f3ca7d53db26
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/convert_graph_to_onnx.py
@@ -0,0 +1,551 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from argparse import ArgumentParser
+from os import listdir, makedirs
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from packaging.version import Version, parse
+
+from transformers.pipelines import Pipeline, pipeline
+from transformers.tokenization_utils import BatchEncoding
+from transformers.utils import ModelOutput, is_tf_available, is_torch_available
+
+
+# This is the minimal required version to
+# support some ONNX Runtime features
+ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
+
+
+SUPPORTED_PIPELINES = [
+ "feature-extraction",
+ "ner",
+ "sentiment-analysis",
+ "fill-mask",
+ "question-answering",
+ "text-generation",
+ "translation_en_to_fr",
+ "translation_en_to_de",
+ "translation_en_to_ro",
+]
+
+
+class OnnxConverterArgumentParser(ArgumentParser):
+ """
+ Wraps all the script arguments supported to export transformers models to ONNX IR
+ """
+
+ def __init__(self):
+ super().__init__("ONNX Converter")
+
+ self.add_argument(
+ "--pipeline",
+ type=str,
+ choices=SUPPORTED_PIPELINES,
+ default="feature-extraction",
+ )
+ self.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="Model's id or path (ex: google-bert/bert-base-cased)",
+ )
+ self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
+ self.add_argument(
+ "--framework",
+ type=str,
+ choices=["pt", "tf"],
+ help="Framework for loading the model",
+ )
+ self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
+ self.add_argument(
+ "--check-loading",
+ action="store_true",
+ help="Check ONNX is able to load the model",
+ )
+ self.add_argument(
+ "--use-external-format",
+ action="store_true",
+ help="Allow exporting model >= than 2Gb",
+ )
+ self.add_argument(
+ "--quantize",
+ action="store_true",
+ help="Quantize the neural network to be run with int8",
+ )
+ self.add_argument("output")
+
+
+def generate_identified_filename(filename: Path, identifier: str) -> Path:
+ """
+ Append a string-identifier at the end (before the extension, if any) to the provided filepath
+
+ Args:
+ filename: pathlib.Path The actual path object we would like to add an identifier suffix
+ identifier: The suffix to add
+
+ Returns: String with concatenated identifier at the end of the filename
+ """
+ return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)
+
+
+def check_onnxruntime_requirements(minimum_version: Version):
+ """
+ Check onnxruntime is installed and if the installed version match is recent enough
+
+ Raises:
+ ImportError: If onnxruntime is not installed or too old version is found
+ """
+ try:
+ import onnxruntime
+
+ # Parse the version of the installed onnxruntime
+ ort_version = parse(onnxruntime.__version__)
+
+ # We require 1.4.0 minimum
+ if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
+ raise ImportError(
+ f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
+ f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
+ "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
+ )
+
+ except ImportError:
+ raise ImportError(
+ "onnxruntime doesn't seem to be currently installed. "
+ "Please install the onnxruntime by running `pip install onnxruntime`"
+ " and relaunch the conversion."
+ )
+
+
+def ensure_valid_input(model, tokens, input_names):
+ """
+ Ensure inputs are presented in the correct order, without any Non
+
+ Args:
+ model: The model used to forward the input data
+ tokens: BatchEncoding holding the input data
+ input_names: The name of the inputs
+
+ Returns: Tuple
+
+ """
+ print("Ensuring inputs are in correct order")
+
+ model_args_name = model.forward.__code__.co_varnames
+ model_args, ordered_input_names = [], []
+ for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
+ if arg_name in input_names:
+ ordered_input_names.append(arg_name)
+ model_args.append(tokens[arg_name])
+ else:
+ print(f"{arg_name} is not present in the generated input list.")
+ break
+
+ print(f"Generated inputs order: {ordered_input_names}")
+ return ordered_input_names, tuple(model_args)
+
+
+def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
+ """
+ Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model
+
+ Args:
+ nlp: The pipeline object holding the model to be exported
+ framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)
+
+ Returns:
+
+ - List of the inferred input variable names
+ - List of the inferred output variable names
+ - Dictionary with input/output variables names as key and shape tensor as value
+ - a BatchEncoding reference which was used to infer all the above information
+ """
+
+ def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
+ if isinstance(tensor, (tuple, list)):
+ return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
+
+ else:
+ # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
+ axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
+ if is_input:
+ if len(tensor.shape) == 2:
+ axes[1] = "sequence"
+ else:
+ raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
+ else:
+ seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
+ axes.update({dim: "sequence" for dim in seq_axes})
+
+ print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
+ return axes
+
+ tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
+ seq_len = tokens.input_ids.shape[-1]
+ outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
+ if isinstance(outputs, ModelOutput):
+ outputs = outputs.to_tuple()
+ if not isinstance(outputs, (list, tuple)):
+ outputs = (outputs,)
+
+ # Generate input names & axes
+ input_vars = list(tokens.keys())
+ input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
+
+ # flatten potentially grouped outputs (past for gpt2, attentions)
+ outputs_flat = []
+ for output in outputs:
+ if isinstance(output, (tuple, list)):
+ outputs_flat.extend(output)
+ else:
+ outputs_flat.append(output)
+
+ # Generate output names & axes
+ output_names = [f"output_{i}" for i in range(len(outputs_flat))]
+ output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
+
+ # Create the aggregated axes representation
+ dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
+ return input_vars, output_names, dynamic_axes, tokens
+
+
+def load_graph_from_args(
+ pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
+) -> Pipeline:
+ """
+ Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model
+
+ Args:
+ pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
+ framework: The actual model to convert the pipeline from ("pt" or "tf")
+ model: The model name which will be loaded by the pipeline
+ tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value
+
+ Returns: Pipeline object
+
+ """
+ # If no tokenizer provided
+ if tokenizer is None:
+ tokenizer = model
+
+ # Check the wanted framework is available
+ if framework == "pt" and not is_torch_available():
+ raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
+ if framework == "tf" and not is_tf_available():
+ raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
+
+ print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")
+
+ # Allocate tokenizer and model
+ return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
+
+
+def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
+ """
+ Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
+
+ Args:
+ nlp: The pipeline to be exported
+ opset: The actual version of the ONNX operator set to use
+ output: Path where will be stored the generated ONNX model
+ use_external_format: Split the model definition from its parameters to allow model bigger than 2GB
+
+ Returns:
+
+ """
+ if not is_torch_available():
+ raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
+
+ import torch
+ from torch.onnx import export
+
+ print(f"Using framework PyTorch: {torch.__version__}")
+
+ with torch.no_grad():
+ input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
+ ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
+
+ export(
+ nlp.model,
+ model_args,
+ f=output.as_posix(),
+ input_names=ordered_input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ do_constant_folding=True,
+ opset_version=opset,
+ )
+
+
+def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
+ """
+ Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
+
+ Args:
+ nlp: The pipeline to be exported
+ opset: The actual version of the ONNX operator set to use
+ output: Path where will be stored the generated ONNX model
+
+ Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow
+
+ """
+ if not is_tf_available():
+ raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
+
+ print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
+
+ try:
+ import tensorflow as tf
+ import tf2onnx
+ from tf2onnx import __version__ as t2ov
+
+ print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
+
+ # Build
+ input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
+
+ # Forward
+ nlp.model.predict(tokens.data)
+ input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
+ model_proto, _ = tf2onnx.convert.from_keras(
+ nlp.model, input_signature, opset=opset, output_path=output.as_posix()
+ )
+
+ except ImportError as e:
+ raise Exception(
+ f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
+ )
+
+
+def convert(
+ framework: str,
+ model: str,
+ output: Path,
+ opset: int,
+ tokenizer: Optional[str] = None,
+ use_external_format: bool = False,
+ pipeline_name: str = "feature-extraction",
+ **model_kwargs,
+):
+ """
+ Convert the pipeline object to the ONNX Intermediate Representation (IR) format
+
+ Args:
+ framework: The framework the pipeline is backed by ("pt" or "tf")
+ model: The name of the model to load for the pipeline
+ output: The path where the ONNX graph will be stored
+ opset: The actual version of the ONNX operator set to use
+ tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided
+ use_external_format:
+ Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)
+ pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
+ model_kwargs: Keyword arguments to be forwarded to the model constructor
+
+ Returns:
+
+ """
+ warnings.warn(
+ "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
+ " Transformers",
+ FutureWarning,
+ )
+ print(f"ONNX opset version set to: {opset}")
+
+ # Load the pipeline
+ nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)
+
+ if not output.parent.exists():
+ print(f"Creating folder {output.parent}")
+ makedirs(output.parent.as_posix())
+ elif len(listdir(output.parent.as_posix())) > 0:
+ raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")
+
+ # Export the graph
+ if framework == "pt":
+ convert_pytorch(nlp, opset, output, use_external_format)
+ else:
+ convert_tensorflow(nlp, opset, output)
+
+
+def optimize(onnx_model_path: Path) -> Path:
+ """
+ Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
+ optimizations possible
+
+ Args:
+ onnx_model_path: filepath where the model binary description is stored
+
+ Returns: Path where the optimized model binary description has been saved
+
+ """
+ from onnxruntime import InferenceSession, SessionOptions
+
+ # Generate model name with suffix "optimized"
+ opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
+ sess_option = SessionOptions()
+ sess_option.optimized_model_filepath = opt_model_path.as_posix()
+ _ = InferenceSession(onnx_model_path.as_posix(), sess_option)
+
+ print(f"Optimized model has been written at {opt_model_path}: \N{HEAVY CHECK MARK}")
+ print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")
+
+ return opt_model_path
+
+
+def quantize(onnx_model_path: Path) -> Path:
+ """
+ Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU
+
+ Args:
+ onnx_model_path: Path to location the exported ONNX model is stored
+
+ Returns: The Path generated for the quantized
+ """
+ import onnx
+ import onnxruntime
+ from onnx.onnx_pb import ModelProto
+ from onnxruntime.quantization import QuantizationMode
+ from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
+ from onnxruntime.quantization.registry import IntegerOpsRegistry
+
+ # Load the ONNX model
+ onnx_model = onnx.load(onnx_model_path.as_posix())
+
+ if parse(onnx.__version__) < parse("1.5.0"):
+ print(
+ "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
+ "Please upgrade to onnxruntime >= 1.5.0."
+ )
+
+ # Copy it
+ copy_model = ModelProto()
+ copy_model.CopyFrom(onnx_model)
+
+ # Construct quantizer
+ # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
+ # check the onnxruntime version to ensure backward compatibility.
+ # See also: https://github.com/microsoft/onnxruntime/pull/12873
+ if parse(onnxruntime.__version__) < parse("1.13.1"):
+ quantizer = ONNXQuantizer(
+ model=copy_model,
+ per_channel=False,
+ reduce_range=False,
+ mode=QuantizationMode.IntegerOps,
+ static=False,
+ weight_qType=True,
+ input_qType=False,
+ tensors_range=None,
+ nodes_to_quantize=None,
+ nodes_to_exclude=None,
+ op_types_to_quantize=list(IntegerOpsRegistry),
+ )
+ else:
+ quantizer = ONNXQuantizer(
+ model=copy_model,
+ per_channel=False,
+ reduce_range=False,
+ mode=QuantizationMode.IntegerOps,
+ static=False,
+ weight_qType=True,
+ activation_qType=False,
+ tensors_range=None,
+ nodes_to_quantize=None,
+ nodes_to_exclude=None,
+ op_types_to_quantize=list(IntegerOpsRegistry),
+ )
+
+ # Quantize and export
+ quantizer.quantize_model()
+
+ # Append "-quantized" at the end of the model's name
+ quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")
+
+ # Save model
+ print(f"Quantized model has been written at {quantized_model_path}: \N{HEAVY CHECK MARK}")
+ onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())
+
+ return quantized_model_path
+
+
+def verify(path: Path):
+ from onnxruntime import InferenceSession, SessionOptions
+ from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
+
+ print(f"Checking ONNX model loading from: {path} ...")
+ try:
+ onnx_options = SessionOptions()
+ _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
+ print(f"Model {path} correctly loaded: \N{HEAVY CHECK MARK}")
+ except RuntimeException as re:
+ print(f"Error while loading the model {re}: \N{HEAVY BALLOT X}")
+
+
+if __name__ == "__main__":
+ parser = OnnxConverterArgumentParser()
+ args = parser.parse_args()
+
+ # Make sure output is absolute path
+ args.output = Path(args.output).absolute()
+
+ try:
+ print("\n====== Converting model to ONNX ======")
+ # Convert
+ convert(
+ args.framework,
+ args.model,
+ args.output,
+ args.opset,
+ args.tokenizer,
+ args.use_external_format,
+ args.pipeline,
+ )
+
+ if args.quantize:
+ # Ensure requirements for quantization on onnxruntime is met
+ check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)
+
+ # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch
+ if args.framework == "tf":
+ print(
+ "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
+ "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
+ "\t For more information, please refer to the onnxruntime documentation:\n"
+ "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
+ )
+
+ print("\n====== Optimizing ONNX model ======")
+
+ # Quantization works best when using the optimized version of the model
+ args.optimized_output = optimize(args.output)
+
+ # Do the quantization on the right graph
+ args.quantized_output = quantize(args.optimized_output)
+
+ # And verify
+ if args.check_loading:
+ print("\n====== Check exported ONNX model(s) ======")
+ verify(args.output)
+
+ if hasattr(args, "optimized_output"):
+ verify(args.optimized_output)
+
+ if hasattr(args, "quantized_output"):
+ verify(args.quantized_output)
+
+ except Exception as e:
+ print(f"Error while converting the model: {e}")
+ exit(1)
diff --git a/.venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py b/.venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3431ad5b2e0ac6e0969e24b3a00922edb382116
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -0,0 +1,446 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert pytorch checkpoints to TensorFlow"""
+
+import argparse
+import os
+
+from . import (
+ AlbertConfig,
+ BartConfig,
+ BertConfig,
+ CamembertConfig,
+ CTRLConfig,
+ DistilBertConfig,
+ DPRConfig,
+ ElectraConfig,
+ FlaubertConfig,
+ GPT2Config,
+ LayoutLMConfig,
+ LxmertConfig,
+ OpenAIGPTConfig,
+ RobertaConfig,
+ T5Config,
+ TFAlbertForPreTraining,
+ TFBartForConditionalGeneration,
+ TFBartForSequenceClassification,
+ TFBertForPreTraining,
+ TFBertForQuestionAnswering,
+ TFBertForSequenceClassification,
+ TFCamembertForMaskedLM,
+ TFCTRLLMHeadModel,
+ TFDistilBertForMaskedLM,
+ TFDistilBertForQuestionAnswering,
+ TFDPRContextEncoder,
+ TFDPRQuestionEncoder,
+ TFDPRReader,
+ TFElectraForPreTraining,
+ TFFlaubertWithLMHeadModel,
+ TFGPT2LMHeadModel,
+ TFLayoutLMForMaskedLM,
+ TFLxmertForPreTraining,
+ TFLxmertVisualFeatureEncoder,
+ TFOpenAIGPTLMHeadModel,
+ TFRobertaForCausalLM,
+ TFRobertaForMaskedLM,
+ TFRobertaForSequenceClassification,
+ TFT5ForConditionalGeneration,
+ TFTransfoXLLMHeadModel,
+ TFWav2Vec2Model,
+ TFXLMRobertaForMaskedLM,
+ TFXLMWithLMHeadModel,
+ TFXLNetLMHeadModel,
+ TransfoXLConfig,
+ Wav2Vec2Config,
+ Wav2Vec2Model,
+ XLMConfig,
+ XLMRobertaConfig,
+ XLNetConfig,
+ is_torch_available,
+ load_pytorch_checkpoint_in_tf2_model,
+)
+from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
+
+
+if is_torch_available():
+ import numpy as np
+ import torch
+
+ from . import (
+ AlbertForPreTraining,
+ BartForConditionalGeneration,
+ BertForPreTraining,
+ BertForQuestionAnswering,
+ BertForSequenceClassification,
+ CamembertForMaskedLM,
+ CTRLLMHeadModel,
+ DistilBertForMaskedLM,
+ DistilBertForQuestionAnswering,
+ DPRContextEncoder,
+ DPRQuestionEncoder,
+ DPRReader,
+ ElectraForPreTraining,
+ FlaubertWithLMHeadModel,
+ GPT2LMHeadModel,
+ LayoutLMForMaskedLM,
+ LxmertForPreTraining,
+ LxmertVisualFeatureEncoder,
+ OpenAIGPTLMHeadModel,
+ RobertaForMaskedLM,
+ RobertaForSequenceClassification,
+ T5ForConditionalGeneration,
+ TransfoXLLMHeadModel,
+ XLMRobertaForMaskedLM,
+ XLMWithLMHeadModel,
+ XLNetLMHeadModel,
+ )
+
+
+logging.set_verbosity_info()
+
+MODEL_CLASSES = {
+ "bart": (
+ BartConfig,
+ TFBartForConditionalGeneration,
+ TFBartForSequenceClassification,
+ BartForConditionalGeneration,
+ ),
+ "bert": (
+ BertConfig,
+ TFBertForPreTraining,
+ BertForPreTraining,
+ ),
+ "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
+ BertConfig,
+ TFBertForQuestionAnswering,
+ BertForQuestionAnswering,
+ ),
+ "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
+ BertConfig,
+ TFBertForQuestionAnswering,
+ BertForQuestionAnswering,
+ ),
+ "google-bert/bert-base-cased-finetuned-mrpc": (
+ BertConfig,
+ TFBertForSequenceClassification,
+ BertForSequenceClassification,
+ ),
+ "dpr": (
+ DPRConfig,
+ TFDPRQuestionEncoder,
+ TFDPRContextEncoder,
+ TFDPRReader,
+ DPRQuestionEncoder,
+ DPRContextEncoder,
+ DPRReader,
+ ),
+ "openai-community/gpt2": (
+ GPT2Config,
+ TFGPT2LMHeadModel,
+ GPT2LMHeadModel,
+ ),
+ "xlnet": (
+ XLNetConfig,
+ TFXLNetLMHeadModel,
+ XLNetLMHeadModel,
+ ),
+ "xlm": (
+ XLMConfig,
+ TFXLMWithLMHeadModel,
+ XLMWithLMHeadModel,
+ ),
+ "xlm-roberta": (
+ XLMRobertaConfig,
+ TFXLMRobertaForMaskedLM,
+ XLMRobertaForMaskedLM,
+ ),
+ "transfo-xl": (
+ TransfoXLConfig,
+ TFTransfoXLLMHeadModel,
+ TransfoXLLMHeadModel,
+ ),
+ "openai-community/openai-gpt": (
+ OpenAIGPTConfig,
+ TFOpenAIGPTLMHeadModel,
+ OpenAIGPTLMHeadModel,
+ ),
+ "roberta": (
+ RobertaConfig,
+ TFRobertaForCausalLM,
+ TFRobertaForMaskedLM,
+ RobertaForMaskedLM,
+ ),
+ "layoutlm": (
+ LayoutLMConfig,
+ TFLayoutLMForMaskedLM,
+ LayoutLMForMaskedLM,
+ ),
+ "FacebookAI/roberta-large-mnli": (
+ RobertaConfig,
+ TFRobertaForSequenceClassification,
+ RobertaForSequenceClassification,
+ ),
+ "camembert": (
+ CamembertConfig,
+ TFCamembertForMaskedLM,
+ CamembertForMaskedLM,
+ ),
+ "flaubert": (
+ FlaubertConfig,
+ TFFlaubertWithLMHeadModel,
+ FlaubertWithLMHeadModel,
+ ),
+ "distilbert": (
+ DistilBertConfig,
+ TFDistilBertForMaskedLM,
+ DistilBertForMaskedLM,
+ ),
+ "distilbert-base-distilled-squad": (
+ DistilBertConfig,
+ TFDistilBertForQuestionAnswering,
+ DistilBertForQuestionAnswering,
+ ),
+ "lxmert": (
+ LxmertConfig,
+ TFLxmertForPreTraining,
+ LxmertForPreTraining,
+ ),
+ "lxmert-visual-feature-encoder": (
+ LxmertConfig,
+ TFLxmertVisualFeatureEncoder,
+ LxmertVisualFeatureEncoder,
+ ),
+ "Salesforce/ctrl": (
+ CTRLConfig,
+ TFCTRLLMHeadModel,
+ CTRLLMHeadModel,
+ ),
+ "albert": (
+ AlbertConfig,
+ TFAlbertForPreTraining,
+ AlbertForPreTraining,
+ ),
+ "t5": (
+ T5Config,
+ TFT5ForConditionalGeneration,
+ T5ForConditionalGeneration,
+ ),
+ "electra": (
+ ElectraConfig,
+ TFElectraForPreTraining,
+ ElectraForPreTraining,
+ ),
+ "wav2vec2": (
+ Wav2Vec2Config,
+ TFWav2Vec2Model,
+ Wav2Vec2Model,
+ ),
+}
+
+
+def convert_pt_checkpoint_to_tf(
+ model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
+):
+ if model_type not in MODEL_CLASSES:
+ raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")
+
+ config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
+
+ # Initialise TF model
+ if config_file in aws_config_map:
+ config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
+ config = config_class.from_json_file(config_file)
+ config.output_hidden_states = True
+ config.output_attentions = True
+ print(f"Building TensorFlow model from configuration: {config}")
+ tf_model = model_class(config)
+
+ # Load weights from tf checkpoint
+ if pytorch_checkpoint_path in aws_config_map.keys():
+ pytorch_checkpoint_path = cached_file(
+ pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
+ )
+ # Load PyTorch checkpoint in tf2 model:
+ tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
+
+ if compare_with_pt_model:
+ tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
+
+ weights_only_kwarg = {"weights_only": True}
+ state_dict = torch.load(
+ pytorch_checkpoint_path,
+ map_location="cpu",
+ **weights_only_kwarg,
+ )
+ pt_model = pt_model_class.from_pretrained(
+ pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+ )
+
+ with torch.no_grad():
+ pto = pt_model(**pt_model.dummy_inputs)
+
+ np_pt = pto[0].numpy()
+ np_tf = tfo[0].numpy()
+ diff = np.amax(np.abs(np_pt - np_tf))
+ print(f"Max absolute difference between models outputs {diff}")
+ assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"
+
+ # Save pytorch-model
+ print(f"Save TensorFlow model to {tf_dump_path}")
+ tf_model.save_weights(tf_dump_path, save_format="h5")
+
+
+def convert_all_pt_checkpoints_to_tf(
+ args_model_type,
+ tf_dump_path,
+ model_shortcut_names_or_path=None,
+ config_shortcut_names_or_path=None,
+ compare_with_pt_model=False,
+ use_cached_models=False,
+ remove_cached_files=False,
+ only_convert_finetuned_models=False,
+):
+ if args_model_type is None:
+ model_types = list(MODEL_CLASSES.keys())
+ else:
+ model_types = [args_model_type]
+
+ for j, model_type in enumerate(model_types, start=1):
+ print("=" * 100)
+ print(f" Converting model type {j}/{len(model_types)}: {model_type}")
+ print("=" * 100)
+ if model_type not in MODEL_CLASSES:
+ raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")
+
+ config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+
+ if model_shortcut_names_or_path is None:
+ model_shortcut_names_or_path = list(aws_model_maps.keys())
+ if config_shortcut_names_or_path is None:
+ config_shortcut_names_or_path = model_shortcut_names_or_path
+
+ for i, (model_shortcut_name, config_shortcut_name) in enumerate(
+ zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
+ ):
+ print("-" * 100)
+ if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
+ if not only_convert_finetuned_models:
+ print(f" Skipping finetuned checkpoint {model_shortcut_name}")
+ continue
+ model_type = model_shortcut_name
+ elif only_convert_finetuned_models:
+ print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
+ continue
+ print(
+ f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
+ )
+ print("-" * 100)
+
+ if config_shortcut_name in aws_config_map:
+ config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
+ else:
+ config_file = config_shortcut_name
+
+ if model_shortcut_name in aws_model_maps:
+ model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
+ else:
+ model_file = model_shortcut_name
+
+ if os.path.isfile(model_shortcut_name):
+ model_shortcut_name = "converted_model"
+
+ convert_pt_checkpoint_to_tf(
+ model_type=model_type,
+ pytorch_checkpoint_path=model_file,
+ config_file=config_file,
+ tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
+ compare_with_pt_model=compare_with_pt_model,
+ )
+ if remove_cached_files:
+ os.remove(config_file)
+ os.remove(model_file)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
+ )
+ parser.add_argument(
+ "--model_type",
+ default=None,
+ type=str,
+ help=(
+ f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
+ "convert all the models from AWS."
+ ),
+ )
+ parser.add_argument(
+ "--pytorch_checkpoint_path",
+ default=None,
+ type=str,
+ help=(
+ "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
+ "If not given, will download and convert all the checkpoints from AWS."
+ ),
+ )
+ parser.add_argument(
+ "--config_file",
+ default=None,
+ type=str,
+ help=(
+ "The config json file corresponding to the pre-trained model. \n"
+ "This specifies the model architecture. If not given and "
+ "--pytorch_checkpoint_path is not given or is a shortcut name "
+ "use the configuration associated to the shortcut name on the AWS"
+ ),
+ )
+ parser.add_argument(
+ "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
+ )
+ parser.add_argument(
+ "--use_cached_models",
+ action="store_true",
+ help="Use cached models if possible instead of updating to latest checkpoint versions.",
+ )
+ parser.add_argument(
+ "--remove_cached_files",
+ action="store_true",
+ help="Remove pytorch models after conversion (save memory when converting in batches).",
+ )
+ parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
+ args = parser.parse_args()
+
+ # if args.pytorch_checkpoint_path is not None:
+ # convert_pt_checkpoint_to_tf(args.model_type.lower(),
+ # args.pytorch_checkpoint_path,
+ # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
+ # args.tf_dump_path,
+ # compare_with_pt_model=args.compare_with_pt_model,
+ # use_cached_models=args.use_cached_models)
+ # else:
+ convert_all_pt_checkpoints_to_tf(
+ args.model_type.lower() if args.model_type is not None else None,
+ args.tf_dump_path,
+ model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
+ if args.pytorch_checkpoint_path is not None
+ else None,
+ config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
+ compare_with_pt_model=args.compare_with_pt_model,
+ use_cached_models=args.use_cached_models,
+ remove_cached_files=args.remove_cached_files,
+ only_convert_finetuned_models=args.only_convert_finetuned_models,
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py b/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..030e3a666436308a5bcbf199e466934a50258767
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py
@@ -0,0 +1,1642 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utilities to convert slow tokenizers in their fast tokenizers counterparts.
+
+All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
+allow to make our dependency on SentencePiece optional.
+"""
+
+import warnings
+from typing import Dict, List, Tuple
+
+from packaging import version
+from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers.models import BPE, Unigram, WordPiece
+
+from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
+from .utils.import_utils import PROTOBUF_IMPORT_ERROR
+
+
+logger = logging.get_logger(__name__)
+
+
+def import_protobuf(error_message=""):
+ if is_sentencepiece_available():
+ from sentencepiece import sentencepiece_model_pb2
+
+ return sentencepiece_model_pb2
+ if is_protobuf_available():
+ import google.protobuf
+
+ if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
+ from transformers.utils import sentencepiece_model_pb2
+ else:
+ from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
+ return sentencepiece_model_pb2
+ else:
+ raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
+
+
+def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
+ if add_prefix_space:
+ prepend_scheme = "always"
+ if not getattr(original_tokenizer, "legacy", True):
+ prepend_scheme = "first"
+ else:
+ prepend_scheme = "never"
+ return prepend_scheme
+
+
+def generate_merges(vocab, vocab_scores):
+ reverse = vocab_scores is not None
+ vocab_scores = dict(vocab_scores) if reverse else vocab
+
+ merges = []
+ for merge, piece_score in vocab_scores.items():
+ local = []
+ for index in range(1, len(merge)):
+ piece_l, piece_r = merge[:index], merge[index:]
+ if piece_l in vocab and piece_r in vocab:
+ local.append((piece_l, piece_r, piece_score))
+ local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
+ merges.extend(local)
+
+ merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
+ merges = [(val[0], val[1]) for val in merges]
+ return merges
+
+
+class SentencePieceExtractor:
+ """
+ Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
+ """
+
+ def __init__(self, model: str):
+ requires_backends(self, "sentencepiece")
+ from sentencepiece import SentencePieceProcessor
+
+ self.sp = SentencePieceProcessor()
+ self.sp.Load(model)
+
+ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
+ """
+ By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
+ order the merges with respect to the piece scores instead.
+ """
+ sp = self.sp
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
+
+ merges = generate_merges(vocab, vocab_scores)
+
+ return vocab, merges
+
+
+class GemmaSentencePieceExtractor(SentencePieceExtractor):
+ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
+ """
+ By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
+ order the merges with respect to the piece scores instead.
+ """
+ sp = self.sp
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
+
+ # there is a missing token in the vocab. We have to do this to support merges
+ # "<0x09>" is the bytefallback for `\t`
+ vocab["\t"] = vocab.get("<0x09>")
+
+ merges = generate_merges(vocab, vocab_scores)
+ return vocab, merges
+
+
+def check_number_comma(piece: str) -> bool:
+ return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
+
+
+class Converter:
+ def __init__(self, original_tokenizer):
+ self.original_tokenizer = original_tokenizer
+
+ def converted(self) -> Tokenizer:
+ raise NotImplementedError()
+
+
+class BertConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class SplinterConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ question = str(self.original_tokenizer.question_token)
+ dot = "."
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+ question_token_id = self.original_tokenizer.question_token_id
+ dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
+
+ if self.original_tokenizer.padding_side == "right":
+ pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
+ else:
+ pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=pair,
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ (question, question_token_id),
+ (dot, dot_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class FunnelConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
+ pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class MPNetConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.vocab
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+ tokenize_chinese_chars = False
+ strip_accents = False
+ do_lower_case = False
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+ tokenizer.normalizer = normalizers.BertNormalizer(
+ clean_text=True,
+ handle_chinese_chars=tokenize_chinese_chars,
+ strip_accents=strip_accents,
+ lowercase=do_lower_case,
+ )
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+ cls = str(self.original_tokenizer.cls_token)
+ sep = str(self.original_tokenizer.sep_token)
+ cls_token_id = self.original_tokenizer.cls_token_id
+ sep_token_id = self.original_tokenizer.sep_token_id
+
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{cls}:0 $A:0 {sep}:0",
+ pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
+ special_tokens=[
+ (cls, cls_token_id),
+ (sep, sep_token_id),
+ ],
+ )
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+ return tokenizer
+
+
+class OpenAIGPTConverter(Converter):
+ def converted(self) -> Tokenizer:
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+ unk_token = self.original_tokenizer.unk_token
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ unk_token=str(unk_token),
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ if tokenizer.token_to_id(str(unk_token)) is not None:
+ tokenizer.add_special_tokens([str(unk_token)])
+
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+ tokenizer.decoder = decoders.BPEDecoder(suffix="")
+
+ return tokenizer
+
+
+class GPT2Converter(Converter):
+ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
+ if not vocab:
+ vocab = self.original_tokenizer.encoder
+ if not merges:
+ merges = list(self.original_tokenizer.bpe_ranks)
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ )
+ )
+
+ add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
+ tokenizer.decoder = decoders.ByteLevel()
+ if getattr(self.original_tokenizer, "add_bos_token", False):
+ bos = self.original_tokenizer.bos_token
+ bos_token_id = self.original_tokenizer.bos_token_id
+ tokenizer.post_processor = processors.TemplateProcessing(
+ single=f"{bos}:0 $A:0",
+ pair=f"{bos}:0 $A:0 $B:1",
+ special_tokens=[
+ (bos, bos_token_id),
+ ],
+ )
+ else:
+ # XXX trim_offsets=False actually means this post_processor doesn't
+ # really do anything.
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+ return tokenizer
+
+
+class HerbertConverter(Converter):
+ def converted(self) -> Tokenizer:
+ tokenizer_info_str = "#version:"
+ token_suffix = ""
+
+ vocab = self.original_tokenizer.encoder
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+ if tokenizer_info_str in merges[0][0]:
+ merges = merges[1:]
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab,
+ merges,
+ dropout=None,
+ unk_token=self.original_tokenizer.unk_token,
+ end_of_word_suffix=token_suffix,
+ )
+ )
+
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+ tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
+ tokenizer.post_processor = processors.BertProcessing(
+ sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
+ cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
+ )
+
+ return tokenizer
+
+
+class Qwen2Converter(Converter):
+ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
+ if not vocab:
+ vocab = self.original_tokenizer.encoder
+ if not merges:
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
+
+ tokenizer = Tokenizer(
+ BPE(
+ vocab=vocab,
+ merges=merges,
+ dropout=None,
+ unk_token=None,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ byte_fallback=False,
+ )
+ )
+
+ tokenizer.normalizer = normalizers.NFC()
+
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.Split(
+ Regex(
+ r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+ ),
+ behavior="isolated",
+ invert=False,
+ ),
+ pre_tokenizers.ByteLevel(
+ add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
+ use_regex=False,
+ ),
+ ]
+ )
+
+ tokenizer.decoder = decoders.ByteLevel()
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+
+ return tokenizer
+
+
class RobertaConverter(Converter):
    """Builds a fast tokenizer equivalent of a slow RoBERTa byte-level BPE tokenizer."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        # Wrap sequences with the classifier/separator tokens via the dedicated
        # Roberta post-processor.
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(slow.sep_token, slow.sep_token_id),
            cls=(slow.cls_token, slow.cls_token_id),
            add_prefix_space=slow.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )

        return tokenizer
+
+
class RoFormerConverter(Converter):
    """Builds a fast RoFormer tokenizer: WordPiece with a custom jieba-based pre-tokenizer."""

    def converted(self) -> Tokenizer:
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        slow = self.original_tokenizer
        vocab = slow.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(slow.unk_token)))

        # Mirror the slow BasicTokenizer's options when one is attached.
        strip_accents = False
        do_lower_case = False
        if hasattr(slow, "basic_tokenizer"):
            strip_accents = slow.basic_tokenizer.strip_accents
            do_lower_case = slow.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        # Custom (Python-side) pre-tokenizer backed by jieba segmentation.
        tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls_token = str(slow.cls_token)
        sep_token = str(slow.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
+
+
class DebertaConverter(Converter):
    """Builds a fast tokenizer equivalent of the slow DeBERTa byte-level BPE tokenizer."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        # [CLS] ... [SEP] wrapping; second segment of a pair gets type id 1.
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", slow.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", slow.convert_tokens_to_ids("[SEP]")),
            ],
        )

        return tokenizer
+
+
class SpmConverter(Converter):
    """
    Base converter for slow tokenizers backed by a serialized SentencePiece model.

    Subclasses override `vocab`, `unk_id`, `normalizer`, `pre_tokenizer`,
    `post_processor` and/or `decoder` to implement model-specific behavior.
    """

    # Whether the fast model should emulate sentencepiece's byte-fallback option.
    handle_byte_fallback = False
    # Extractor used to recover BPE merges from the sentencepiece file.
    SpmExtractor = SentencePieceExtractor
    # Pieces that must be registered as *special* AddedTokens; set by subclasses.
    special_tokens = {}

    def __init__(self, *args):
        requires_backends(self, "protobuf")

        super().__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        # Parse the serialized sentencepiece model so its trainer/normalizer
        # specs and pieces are available to the conversion methods below.
        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
            warnings.warn(
                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                "unknown tokens into a sequence of byte tokens matching the original piece of text."
            )

    def vocab(self, proto):
        """Return the (piece, score) list used to build the fast model's vocabulary."""
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        """Return the id of the unknown token (from the trainer spec by default)."""
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        """Build the core `Tokenizer` (Unigram or BPE) from the parsed proto."""
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 1:
            # sentencepiece model_type 1 == Unigram.
            tokenizer = Tokenizer(
                Unigram(
                    vocab_scores,
                    unk_id=self.unk_id(proto),
                    byte_fallback=self.handle_byte_fallback,
                )
            )

        elif model_type == 2:
            # sentencepiece model_type 2 == BPE: merges are not stored in the
            # proto and must be re-extracted from the model file.
            _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=self.handle_byte_fallback,
                    dropout=None,
                )
            )

        else:
            # Fixed grammar of the original message ("you're file" -> "your file").
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
            for id, p in enumerate(proto.pieces)
            if p.type in [3, 4]
        ]
        # Register in id order so token ids stay aligned with the proto.
        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=special)
                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
            ]
        )

        return tokenizer

    def normalizer(self, proto):
        """Default normalizer: optional precompiled charsmap, strip, whitespace fold."""
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Default pre-tokenizer: Metaspace with the scheme inferred from the slow tokenizer."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def post_processor(self):
        """No post-processor by default; subclasses add model-specific templates."""
        return None

    def decoder(self, replacement, add_prefix_space):
        """Default decoder mirrors the Metaspace pre-tokenizer."""
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the pieces defined above."""
        tokenizer = self.tokenizer(self.proto)

        # Tokenizer assemble
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        add_prefix_space = True
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
+
+
class AlbertConverter(SpmConverter):
    """SentencePiece converter for ALBERT."""

    def vocab(self, proto):
        # Down-rank (score - 100) every piece rejected by `check_number_comma`
        # so the model effectively never selects it.
        entries = []
        for piece in proto.pieces:
            score = piece.score if check_number_comma(piece.piece) else piece.score - 100
            entries.append((piece.piece, score))
        return entries

    def normalizer(self, proto):
        # Quote unification, optional accent-stripping/lowercasing, then the
        # model's precompiled charsmap and whitespace collapsing.
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            steps.append(normalizers.NFKD())
            steps.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))

        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def post_processor(self):
        # [CLS] ... [SEP] wrapping; second segment of a pair gets type id 1.
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", to_id("[CLS]")),
                ("[SEP]", to_id("[SEP]")),
            ],
        )
+
+
class BarthezConverter(SpmConverter):
    # SentencePiece converter for BARThez.

    def unk_id(self, proto):
        # Hard-coded: the unknown token sits at id 3 for this model, overriding
        # the proto's trainer_spec.unk_id.
        unk_id = 3
        return unk_id

    def post_processor(self):
        # NOTE(review): the special-token strings below are empty in this copy —
        # they look like stripped markup (angle-bracket tokens such as "<s>"/"</s>").
        # Confirm against the upstream source before relying on this template.
        return processors.TemplateProcessing(
            single=" $A ",
            pair=" $A $B ",
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class CamembertConverter(SpmConverter):
    # SentencePiece converter for CamemBERT.

    def vocab(self, proto):
        # NOTE(review): several token strings below are empty or truncated in
        # this copy ("NOTUSED", "") — they look like stripped angle-bracket
        # markup (e.g. "<s>NOTUSED", "<pad>"). Verify against upstream.
        vocab = [
            ("NOTUSED", 0.0),
            ("", 0.0),
            ("NOTUSED", 0.0),
            ("", 0.0),
            ("NOTUSED", -100),
        ]
        # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        vocab += [("", 0.0)]
        return vocab

    def unk_id(self, proto):
        # See vocab unk position
        return 3

    def post_processor(self):
        # NOTE(review): empty special-token strings here likewise appear
        # stripped; confirm before use.
        return processors.TemplateProcessing(
            single=" $A ",
            pair=" $A $B ",
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class DebertaV2Converter(SpmConverter):
    """SentencePiece converter for DeBERTa-v2."""

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Optionally isolate punctuation first, then the usual Metaspace step.
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        components = []
        if self.original_tokenizer.split_by_punct:
            components.append(pre_tokenizers.Punctuation(behavior="isolated"))
        components.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
        return pre_tokenizers.Sequence(components)

    def normalizer(self, proto):
        # Optional lowercasing, strip, precompiled charsmap, whitespace fold.
        steps = []
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())
        steps.append(normalizers.Strip())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))

        return normalizers.Sequence(steps)

    def post_processor(self):
        # [CLS] ... [SEP] wrapping; second segment of a pair gets type id 1.
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", to_id("[CLS]")),
                ("[SEP]", to_id("[SEP]")),
            ],
        )
+
+
class MBartConverter(SpmConverter):
    # SentencePiece converter for mBART, with its 25 language-code tokens.

    def vocab(self, proto):
        # NOTE(review): the first four token strings are empty in this copy —
        # they look like stripped angle-bracket specials (e.g. "<s>", "<pad>").
        # Confirm against upstream.
        vocab = [
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        # Language codes appended after the sentencepiece vocabulary.
        vocab += [
            ("ar_AR", 0.0),
            ("cs_CZ", 0.0),
            ("de_DE", 0.0),
            ("en_XX", 0.0),
            ("es_XX", 0.0),
            ("et_EE", 0.0),
            ("fi_FI", 0.0),
            ("fr_XX", 0.0),
            ("gu_IN", 0.0),
            ("hi_IN", 0.0),
            ("it_IT", 0.0),
            ("ja_XX", 0.0),
            ("kk_KZ", 0.0),
            ("ko_KR", 0.0),
            ("lt_LT", 0.0),
            ("lv_LV", 0.0),
            ("my_MM", 0.0),
            ("ne_NP", 0.0),
            ("nl_XX", 0.0),
            ("ro_RO", 0.0),
            ("ru_RU", 0.0),
            ("si_LK", 0.0),
            ("tr_TR", 0.0),
            ("vi_VN", 0.0),
            ("zh_CN", 0.0),
        ]
        vocab += [("", 0.0)]
        return vocab

    def unk_id(self, proto):
        # Hard-coded unknown-token id for this vocab layout.
        return 3

    def post_processor(self):
        # mBART appends the language code after the sequence.
        # NOTE(review): the empty special-token strings appear stripped
        # (likely "</s>"); verify before use.
        return processors.TemplateProcessing(
            single="$A en_XX",
            pair="$A $B en_XX",
            special_tokens=[
                ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class MBart50Converter(SpmConverter):
    # SentencePiece converter for mBART-50 (52 language-code tokens).

    def vocab(self, proto):
        # NOTE(review): the first four token strings are empty in this copy —
        # likely stripped angle-bracket specials. Confirm against upstream.
        vocab = [
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]  # fmt: skip
        vocab += [("", 0.0)]
        return vocab

    def unk_id(self, proto):
        # Hard-coded unknown-token id for this vocab layout.
        return 3

    def post_processor(self):
        # mBART-50 prepends the language code before the sequence.
        # NOTE(review): empty special-token strings appear stripped; verify.
        return processors.TemplateProcessing(
            single="en_XX $A ",
            pair="en_XX $A $B ",
            special_tokens=[
                ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class NllbConverter(SpmConverter):
    # SentencePiece converter for NLLB.

    def vocab(self, proto):
        # NOTE(review): the first four token strings are empty in this copy —
        # likely stripped angle-bracket specials. Confirm against upstream.
        vocab = [
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        # Hard-coded unknown-token id for this vocab layout.
        return 3

    def post_processor(self):
        # NLLB prepends the source-language code (English Latin by default).
        # NOTE(review): empty special-token strings appear stripped; verify.
        return processors.TemplateProcessing(
            single="eng_Latn $A ",
            pair="eng_Latn $A $B ",
            special_tokens=[
                ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class SeamlessM4TConverter(SpmConverter):
    # SentencePiece converter for SeamlessM4T.

    def vocab(self, proto):
        # NOTE(review): the first four token strings are empty in this copy —
        # likely stripped angle-bracket specials. Confirm against upstream.
        vocab = [
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        # Delegate to the slow tokenizer's configured unk id.
        return self.original_tokenizer.unk_token_id

    def post_processor(self):
        # Prepends the "__eng__" language token.
        # NOTE(review): empty special-token strings appear stripped; verify.
        return processors.TemplateProcessing(
            single="__eng__ $A ",
            pair="__eng__ $A $B ",
            special_tokens=[
                ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class XLMRobertaConverter(SpmConverter):
    # SentencePiece converter for XLM-RoBERTa.

    def vocab(self, proto):
        # NOTE(review): the special-token strings are empty in this copy —
        # likely stripped angle-bracket specials (e.g. "<s>", "<pad>",
        # "<mask>"). Confirm against upstream.
        vocab = [
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [("", 0.0)]
        return vocab

    def unk_id(self, proto):
        # Hard-coded unknown-token id for this vocab layout.
        unk_id = 3
        return unk_id

    def post_processor(self):
        # NOTE(review): empty special-token strings appear stripped; verify.
        return processors.TemplateProcessing(
            single=" $A ",
            pair=" $A $B ",
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class XLNetConverter(SpmConverter):
    # SentencePiece converter for XLNet.

    def vocab(self, proto):
        # Down-rank (score - 100) pieces rejected by `check_number_comma` so
        # they are effectively never selected.
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        # Quote unification, optional accent stripping / lowercasing, then the
        # model's precompiled charsmap and whitespace collapsing.
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))

        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        # NOTE(review): the templates contain bare ":0"/":2" entries and empty
        # special-token strings — these look like stripped angle-bracket tokens
        # (XLNet appends "<sep>"/"<cls>"). Confirm against upstream before use.
        return processors.TemplateProcessing(
            single="$A:0 :0 :2",
            pair="$A:0 :0 $B:1 :1 :2",
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class ReformerConverter(SpmConverter):
    # Reformer uses the default SpmConverter behavior unchanged.
    pass
+
+
class RemBertConverter(SpmConverter):
    # Inspired from AlbertConverter
    def normalizer(self, proto):
        # Quote unification and whitespace folding, plus optional
        # accent-stripping / lowercasing and the precompiled charsmap.
        steps = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        if not self.original_tokenizer.keep_accents:
            steps.append(normalizers.NFKD())
            steps.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            steps.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))

        return normalizers.Sequence(steps)

    def post_processor(self):
        # [CLS] ... [SEP] wrapping; second segment of a pair gets type id 1.
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", to_id("[CLS]")),
                ("[SEP]", to_id("[SEP]")),
            ],
        )
+
+
class BertGenerationConverter(SpmConverter):
    # BertGeneration uses the default SpmConverter behavior unchanged.
    pass
+
+
class PegasusConverter(SpmConverter):
    # SentencePiece converter for Pegasus, which reserves `offset` extra ids
    # at the front of the vocabulary.

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
        ]

        if self.original_tokenizer.mask_token_sent is not None:
            vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]

        # Only include the mask token when it falls inside the reserved offset.
        if (
            self.original_tokenizer.mask_token is not None
            and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
        ):
            vocab += [(self.original_tokenizer.mask_token, 0.0)]

        # NOTE(review): the f-string below has no placeholder although `i` is
        # looped over — it looks like a stripped token pattern (upstream uses
        # unk-style placeholders such as f"<unk_{i}>"). Confirm against upstream.
        vocab += [(f"", -100.0) for i in range(2, self.original_tokenizer.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto):
        # Shift the proto's unk id past the reserved offset tokens.
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Whitespace split first, then Metaspace.
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
            ]
        )

    def post_processor(self):
        # Pegasus only appends EOS after each sequence.
        eos = self.original_tokenizer.eos_token
        special_tokens = [
            (eos, self.original_tokenizer.eos_token_id),
        ]
        return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
+
+
class T5Converter(SpmConverter):
    # SentencePiece converter for T5, which appends `_extra_ids` sentinel
    # tokens after the sentencepiece vocabulary.

    def vocab(self, proto):
        num_extra_ids = self.original_tokenizer._extra_ids
        vocab = [(piece.piece, piece.score) for piece in proto.pieces]
        # NOTE(review): the f-string below has no placeholder although `i` is
        # looped over — it looks like a stripped token pattern (upstream uses
        # f"<extra_id_{i}>"). Confirm against upstream.
        vocab += [(f"", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
        return vocab

    def post_processor(self):
        # Appends EOS after each sequence.
        # NOTE(review): empty special-token strings appear stripped (likely
        # "</s>"); verify before use.
        return processors.TemplateProcessing(
            single=["$A", ""],
            pair=["$A", "", "$B", ""],
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class UdopConverter(SpmConverter):
    # SentencePiece converter for UDOP; only the post-processor differs from
    # the SpmConverter defaults.

    def post_processor(self):
        # NOTE(review): empty special-token strings appear stripped (likely
        # "</s>"); verify against upstream before use.
        return processors.TemplateProcessing(
            single=["$A", ""],
            pair=["$A", "", "$B", ""],
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class WhisperConverter(Converter):
    """Builds a fast tokenizer equivalent of the slow Whisper byte-level BPE tokenizer."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        # Rebuild Whisper's template: its prefix tokens before the sequence,
        # EOS after it.
        prefix_token_ids = slow.prefix_tokens
        prefixes = slow.convert_ids_to_tokens(prefix_token_ids)
        eos = slow.eos_token
        prefix_template = " ".join(f"{token}:0" for token in prefixes)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{prefix_template} $A:0 {eos}:0",
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            special_tokens=[(eos, slow.eos_token_id)] + list(zip(prefixes, prefix_token_ids)),
        )

        return tokenizer
+
+
class BigBirdConverter(SpmConverter):
    """SentencePiece converter for BigBird; only the post-processor is customized."""

    def post_processor(self):
        # [CLS] ... [SEP] wrapping; second segment of a pair gets type id 1.
        to_id = self.original_tokenizer.convert_tokens_to_ids
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", to_id("[CLS]")),
                ("[SEP]", to_id("[SEP]")),
            ],
        )
+
+
class CLIPConverter(Converter):
    """Builds a fast tokenizer equivalent of the slow CLIP BPE tokenizer."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                unk_token=str(slow.unk_token),
            )
        )

        # NFC, collapse whitespace runs to one space, lowercase.
        tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        )
        clip_pattern = Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""")
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(clip_pattern, behavior="removed", invert=True),
                pre_tokenizers.ByteLevel(add_prefix_space=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()

        # Hack to have a ByteLevel and TemplaceProcessor
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(slow.eos_token, slow.eos_token_id),
            cls=(slow.bos_token, slow.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer
+
+
class LayoutLMv2Converter(Converter):
    """Builds a fast WordPiece tokenizer for LayoutLMv2 (BERT-style pipeline)."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer
        tokenizer = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Defaults used when the slow tokenizer carries no BasicTokenizer.
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = True
        if hasattr(slow, "basic_tokenizer"):
            tokenize_chinese_chars = slow.basic_tokenizer.tokenize_chinese_chars
            strip_accents = slow.basic_tokenizer.strip_accents
            do_lower_case = slow.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls_token = str(slow.cls_token)
        sep_token = str(slow.sep_token)
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token}:0 $A:0 {sep_token}:0",
            pair=f"{cls_token}:0 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer
+
+
class BlenderbotConverter(Converter):
    """Builds a fast tokenizer equivalent of the slow Blenderbot byte-level BPE tokenizer."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        # Blenderbot only appends EOS after the (single) sequence.
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {slow.eos_token}:0",
            special_tokens=[(slow.eos_token, slow.eos_token_id)],
        )

        return tokenizer
+
+
class XGLMConverter(SpmConverter):
    # SentencePiece converter for XGLM.

    def vocab(self, proto):
        # NOTE(review): the token strings below are empty in this copy — they
        # look like stripped angle-bracket specials (e.g. "<s>", "<madeupword...>").
        # Confirm against upstream.
        vocab = [
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
            ("", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0)]  # fmt: skip
        return vocab

    def unk_id(self, proto):
        # Hard-coded unknown-token id for this vocab layout.
        unk_id = 3
        return unk_id

    def post_processor(self):
        # NOTE(review): empty special-token strings appear stripped; verify.
        return processors.TemplateProcessing(
            single=" $A",
            pair=" $A $B",
            special_tokens=[
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
                ("", self.original_tokenizer.convert_tokens_to_ids("")),
            ],
        )
+
+
class GemmaConverter(SpmConverter):
    # SentencePiece converter for Gemma (byte-fallback enabled, custom extractor).
    handle_byte_fallback = True
    SpmExtractor = GemmaSentencePieceExtractor
    # start and end of turn tokens must be marked as special
    # NOTE(review): this set contains two empty strings (which collapse to one)
    # — the token names look stripped (likely "<start_of_turn>"/"<end_of_turn>").
    # Confirm against upstream.
    special_tokens = {"", ""}

    # NOTE(review): stray quadruple-quote literal below (""""); it still parses
    # as a harmless string expression, but it is likely a typo for a docstring.
    """"
    split_by_unicode_script: true
    split_by_number: true
    split_by_whitespace: true
    treat_whitespace_as_suffix: false
    allow_whitespace_only_pieces: true
    split_digits: true
    byte_fallback: true
    """

    def normalizer(self, proto):
        # Only map spaces to the metaspace marker; no other normalization.
        return normalizers.Replace(" ", "▁")

    def vocab(self, proto):
        # First three ids come from the slow tokenizer's pad/eos/bos.
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
            (self.original_tokenizer.bos_token, 0.0),
        ]
        for piece in proto.pieces[3:]:
            # The raw byte piece for TAB is replaced by a literal tab character.
            if piece.piece == "<0x09>":
                vocab += [("\t", piece.score)]
            else:
                vocab += [(piece.piece, piece.score)]
        # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Split on single spaces, keeping the space glued to the previous piece.
        return pre_tokenizers.Split(" ", "merged_with_previous")

    def unk_id(self, proto):
        # Hard-coded unknown-token id for this vocab layout.
        unk_id = 3
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        # Undo the metaspace mapping, then resolve byte-fallback tokens.
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
            ]
        )
+
+
class LlamaConverter(SpmConverter):
    """SentencePiece converter for Llama-family tokenizers (byte-fallback enabled)."""

    handle_byte_fallback = True

    def vocab(self, proto):
        # The first three entries are taken from the slow tokenizer itself so
        # any user overrides of those tokens are preserved.
        leading = [(self.original_tokenizer.convert_ids_to_tokens(i), 0.0) for i in range(3)]
        return leading + [(piece.piece, piece.score) for piece in proto.pieces[3:]]

    def unk_id(self, proto):
        # The unknown token sits at id 0 for this vocab layout.
        return 0

    def decoder(self, replacement, add_prefix_space):
        # Undo metaspace mapping, resolve byte-fallback tokens, then optionally
        # drop the single leading space introduced by the prefix-space scheme.
        steps = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def normalizer(self, proto):
        if getattr(self.original_tokenizer, "legacy", True):
            steps = []
            if getattr(self.original_tokenizer, "add_prefix_space", True):
                steps.append(normalizers.Prepend(prepend="▁"))
            steps.append(normalizers.Replace(pattern=" ", content="▁"))
            return normalizers.Sequence(steps)
        return None  # non-legacy, no normalizer

    def pre_tokenizer(self, replacement, add_prefix_space):
        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
        return None

    def post_processor(self):
        # the processor is defined in the LlamaTokenizerFast class.
        return None
+
+
class MarkupLMConverter(Converter):
    """Builds a fast tokenizer equivalent of the slow MarkupLM byte-level BPE tokenizer."""

    def converted(self) -> Tokenizer:
        slow = self.original_tokenizer

        tokenizer = Tokenizer(
            BPE(
                vocab=slow.encoder,
                merges=list(slow.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                unk_token=slow.unk_token,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls_token = str(slow.cls_token)
        sep_token = str(slow.sep_token)
        # Wrap sequences with CLS/SEP; no explicit type ids in the template.
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls_token} $A {sep_token}",
            pair=f"{cls_token} $A {sep_token} $B {sep_token}",
            special_tokens=[
                (cls_token, slow.cls_token_id),
                (sep_token, slow.sep_token_id),
            ],
        )

        return tokenizer
+
+
class MoshiConverter(SpmConverter):
    # SentencePiece converter for Moshi (byte-fallback enabled).
    handle_byte_fallback = True

    def __init__(self, vocab_file, model_max_length=None, **kwargs):
        # Deliberately bypasses SpmConverter.__init__ (calls Converter.__init__
        # directly): it is constructed from a vocab file path, not from a slow
        # tokenizer instance.
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        # Parse the serialized sentencepiece model.
        m = model_pb2.ModelProto()
        with open(vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

    def normalizer(self, proto):
        # Map spaces to the metaspace marker, optionally preceded by the
        # model's precompiled charsmap.
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Replace(" ", "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def decoder(self, replacement, add_prefix_space):
        # Undo the metaspace mapping and resolve byte-fallback tokens; drop the
        # leading space when a prefix space was added.
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    def pre_tokenizer(self, replacement, add_prefix_space):
        # Metaspace with a fixed "first" prepend scheme and no splitting.
        prepend_scheme = "first"
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
+ characters the bpe code barfs on.
+
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+ tables between utf-8 bytes and unicode strings.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
class TikTokenConverter:
    """
    A general tiktoken converter.

    Args:
        vocab_file: path or URL of the serialized tiktoken BPE-ranks file.
        pattern: pre-tokenization split regex (tiktoken-style).
        add_prefix_space: whether the ByteLevel pre-tokenizer prepends a space.
        additional_special_tokens: optional special tokens to register on the
            fast tokenizer.
    """

    def __init__(
        self,
        vocab_file=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args)
        self.vocab_file = vocab_file
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = additional_special_tokens

    def extract_vocab_merges_from_model(self, tiktoken_url: str):
        """Load tiktoken BPE ranks and reconstruct a (vocab, merges) pair."""
        try:
            from tiktoken.load import load_tiktoken_bpe
        except Exception:
            raise ValueError(
                "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`."
            )

        bpe_ranks = load_tiktoken_bpe(tiktoken_url)
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            # Map raw bytes through the GPT-2 byte->unicode table.
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
        for token, rank in bpe_ranks.items():
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            # Recover candidate merges for this token: every split point whose
            # two halves are themselves known tokens.
            local = []
            for index in range(1, len(token)):
                piece_l, piece_r = token[:index], token[index:]
                if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
                    local.append((piece_l, piece_r, rank))
            local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
            merges.extend(local)
        # Global ordering by the merged token's rank.
        merges = sorted(merges, key=lambda val: val[2], reverse=False)
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        """Build the bare BPE `Tokenizer` from the tiktoken file."""
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        """Assemble the full fast tokenizer (pre-tokenizer, decoder, processor)."""
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        # Bug fix: only register extra special tokens when some were provided;
        # the attribute defaults to None, which previously crashed here.
        if self.additional_special_tokens:
            tokenizer.add_special_tokens(self.additional_special_tokens)

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer
+
+
# Maps each slow tokenizer class name to the converter able to build the
# equivalent fast (`tokenizers`-backed) tokenizer.
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "RealmTokenizer": BertConverter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RetriBertTokenizer": BertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConverter,
    "Phi3Tokenizer": LlamaConverter,
}
+
+
def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
    """
    Utilities to convert a slow tokenizer instance in a fast tokenizer instance.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert in the backend tokenizer for
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].
        from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece.
            Defaults to False.

    Return:
        A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]
    """

    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
        converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
        return converter_class(transformer_tokenizer).converted()

    else:
        try:
            logger.info("Converting from Tiktoken")
            return TikTokenConverter(
                vocab_file=transformer_tokenizer.vocab_file,
                additional_special_tokens=transformer_tokenizer.additional_special_tokens,
            ).converted()
        except Exception as exc:
            # Chain the original error so the tiktoken failure cause is not
            # lost (previously the traceback context was discarded). Also
            # dropped the pointless `f` prefixes on the placeholder-less parts.
            raise ValueError(
                "Converting from Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
                "with a SentencePiece tokenizer.model file."
                f"Currently available slow->fast convertors: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
            ) from exc
diff --git a/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py b/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b93e4c53ff891e70a5ce33a8868237c430b1b18
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/convert_slow_tokenizers_checkpoints_to_fast.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)"""
+
+import argparse
+import os
+
+import transformers
+
+from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
+from .utils import logging
+
+
# Emit INFO-level progress logs while converting checkpoints.
logging.set_verbosity_info()

logger = logging.get_logger(__name__)


# Map each supported slow tokenizer class name to its fast tokenizer class,
# resolved dynamically from the top-level `transformers` namespace.
TOKENIZER_CLASSES = {
    # Phi3 uses Llama tokenizer
    name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
    for name in SLOW_TO_FAST_CONVERTERS
}
+
+
def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
    """
    Convert one or all slow tokenizer checkpoints to the fast (`tokenizers`-backed) serialization format.

    Args:
        tokenizer_name (`str`, *optional*):
            Name of a tokenizer class listed in `TOKENIZER_CLASSES`. If `None`, all supported tokenizers
            are converted.
        checkpoint_name (`str`, *optional*):
            A specific checkpoint to convert. If `None`, every canonical checkpoint of each tokenizer
            class is converted.
        dump_path (`str`):
            Directory where the generated `tokenizer.json` files are written.
        force_download (`bool`):
            Whether to re-download checkpoints even if they are already cached.

    Raises:
        ValueError: If `tokenizer_name` is given but not present in `TOKENIZER_CLASSES`.
    """
    if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
        raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")

    if tokenizer_name is None:
        tokenizer_names = TOKENIZER_CLASSES
    else:
        tokenizer_names = {tokenizer_name: getattr(transformers, tokenizer_name + "Fast")}

    logger.info(f"Loading tokenizer classes: {tokenizer_names}")

    for tokenizer_name in tokenizer_names:
        tokenizer_class = TOKENIZER_CLASSES[tokenizer_name]

        add_prefix = True
        if checkpoint_name is None:
            checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
        else:
            checkpoint_names = [checkpoint_name]

        # `tokenizer_class` is a class, not an instance, so use `__name__` directly:
        # `tokenizer_class.__class__.__name__` would always log "type".
        logger.info(f"For tokenizer {tokenizer_class.__name__} loading checkpoints: {checkpoint_names}")

        for checkpoint in checkpoint_names:
            logger.info(f"Loading {tokenizer_class.__name__} {checkpoint}")

            # Load tokenizer
            tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)

            # Save fast tokenizer
            logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")

            # For organization names we create sub-directories
            if "/" in checkpoint:
                checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
                dump_path_full = os.path.join(dump_path, checkpoint_directory)
            elif add_prefix:
                checkpoint_prefix_name = checkpoint
                dump_path_full = dump_path
            else:
                # NOTE(review): unreachable while `add_prefix` is hard-coded to True above.
                checkpoint_prefix_name = None
                dump_path_full = dump_path

            logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
                # If the vocab file URL embeds the checkpoint name as a directory component,
                # mirror that layout locally and drop the filename prefix.
                file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
                next_char = file_path.split(checkpoint)[-1][0]
                if next_char == "/":
                    dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
                    checkpoint_prefix_name = None

                logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")

            file_names = tokenizer.save_pretrained(
                dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
            )
            logger.info(f"=> File names {file_names}")

            # Keep only the unified `tokenizer.json`; remove every other emitted file.
            for file_name in file_names:
                if not file_name.endswith("tokenizer.json"):
                    os.remove(file_name)
                    logger.info(f"=> removing {file_name}")
+
+
# CLI entry point: parse arguments and run the slow->fast checkpoint conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files."
    )
    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help=(
            f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
            "download and convert all the checkpoints from AWS."
        ),
    )
    parser.add_argument(
        "--checkpoint_name",
        default=None,
        type=str,
        help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
    )
    parser.add_argument(
        "--force_download",
        action="store_true",
        help="Re-download checkpoints.",
    )
    args = parser.parse_args()

    convert_slow_checkpoint_to_fast(args.tokenizer_name, args.checkpoint_name, args.dump_path, args.force_download)
diff --git a/.venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py b/.venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ccb033b3df1de87a29bfd608090386c16593c5f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Seq2Seq TF Hub checkpoint."""
+
+import argparse
+
+from . import (
+ BertConfig,
+ BertGenerationConfig,
+ BertGenerationDecoder,
+ BertGenerationEncoder,
+ load_tf_weights_in_bert_generation,
+ logging,
+)
+
+
+logging.set_verbosity_info()
+
+
def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
    """
    Convert a seq2seq BertGeneration TF Hub checkpoint to a PyTorch model and save it.

    Args:
        tf_hub_path: Path to the TensorFlow Hub checkpoint to load.
        pytorch_dump_path: Directory where the converted PyTorch model and config are saved.
        is_encoder_named_decoder: Passed through to `load_tf_weights_in_bert_generation`; whether TF
            variables named "decoder" should be loaded into the encoder.
        vocab_size: Vocabulary size used for the model configuration.
        is_encoder: If True, build a `BertGenerationEncoder`; otherwise a `BertGenerationDecoder`.
    """
    # Initialise PyTorch model
    bert_config = BertConfig.from_pretrained(
        "google-bert/bert-large-cased",
        vocab_size=vocab_size,
        max_position_embeddings=512,
        is_decoder=True,
        add_cross_attention=True,
    )
    # Remove `type_vocab_size` from the dict before building the BertGenerationConfig.
    bert_config_dict = bert_config.to_dict()
    del bert_config_dict["type_vocab_size"]
    config = BertGenerationConfig(**bert_config_dict)
    if is_encoder:
        model = BertGenerationEncoder(config)
    else:
        model = BertGenerationDecoder(config)
    print(f"Building PyTorch model from configuration: {config}")

    # Load weights from tf checkpoint
    load_tf_weights_in_bert_generation(
        model,
        tf_hub_path,
        model_class="bert",
        is_encoder_named_decoder=is_encoder_named_decoder,
        is_encoder=is_encoder,
    )

    # Save pytorch-model
    print(f"Save PyTorch model and config to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
+
+
# CLI entry point: parse arguments and run the TF Hub -> PyTorch conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--is_encoder_named_decoder",
        action="store_true",
        help="If decoder has to be renamed to encoder in PyTorch model.",
    )
    parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
    parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        args.tf_hub_path,
        args.pytorch_dump_path,
        args.is_encoder_named_decoder,
        args.vocab_size,
        is_encoder=args.is_encoder,
    )
diff --git a/.venv/lib/python3.11/site-packages/transformers/debug_utils.py b/.venv/lib/python3.11/site-packages/transformers/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbceb1d849076999c6821556accaea05e53a9ff9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/debug_utils.py
@@ -0,0 +1,346 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+
+from .utils import ExplicitEnum, is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
class DebugUnderflowOverflow:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` weight and activation elements.

    There are 2 working modes:

    1. Underflow/overflow detection (default)
    2. Specific batch absolute min/max tracing without detection

    Mode 1: Underflow/overflow detection

    To activate the underflow/overflow detection, initialize the object with the model :

    ```python
    debug_overflow = DebugUnderflowOverflow(model)
    ```

    then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
    elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
    each frame reporting

    1. the fully qualified module name plus the class name whose `forward` was run
    2. the absolute min and max value of all elements for each module weights, and the inputs and output

    For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
    mixed precision :

    ```
    Detected inf/nan during batch_number=0
    Last 21 forward frames:
    abs min  abs max  metadata
    [...]
                      encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
    2.17e-07 4.50e+00 weight
    1.79e-06 4.65e+00 input[0]
    2.68e-06 3.70e+01 output
                      encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
    8.08e-07 2.66e+01 weight
    1.79e-06 4.65e+00 input[0]
    1.27e-04 2.37e+02 output
                      encoder.block.2.layer.1.DenseReluDense.wo Linear
    1.01e-06 6.44e+00 weight
    0.00e+00 9.74e+03 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
    1.79e-06 4.65e+00 input[0]
    3.18e-04 6.27e+04 output
                      encoder.block.2.layer.1.dropout Dropout
    3.18e-04 6.27e+04 input[0]
    0.00e+00 inf      output
    ```

    You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was
    around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
    renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
    64K, and we get an overflow.

    As you can see it's the previous frames that we need to look into when the numbers start going into very large for
    fp16 numbers.

    The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.

    By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
    ```

    To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
    may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
    the next section.


    Mode 2. Specific batch absolute min/max tracing without detection

    The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.

    Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
    given batch, and only do that for batches 1 and 3. Then you instantiate this class as :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
    ```

    And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.

    This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
    fast-forward right to that area.


    Early stopping:

    You can also specify the batch number after which to stop the training, with :

    ```python
    debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
    ```

    This feature is mainly useful in the tracing mode, but you can use it for any mode.


    **Performance**:

    As this module measures absolute `min`/`max` of each weight of the model on every forward it'll slow the training
    down. Therefore remember to turn it off once the debugging needs have been met.

    Args:
        model (`nn.Module`):
            The model to debug.
        max_frames_to_save (`int`, *optional*, defaults to 21):
            How many frames back to record
        trace_batch_nums(`List[int]`, *optional*, defaults to `None`):
            Which batch numbers to trace (turns detection off); `None` means trace no batches.
        abort_after_batch_num (`int``, *optional*):
            Whether to abort after a certain batch number has finished
    """

    def __init__(self, model, max_frames_to_save=21, trace_batch_nums=None, abort_after_batch_num=None):
        self.model = model
        # Avoid the mutable-default-argument pitfall: `None` means "trace no batches".
        self.trace_batch_nums = trace_batch_nums if trace_batch_nums is not None else []
        self.abort_after_batch_num = abort_after_batch_num

        # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
        self.frames = collections.deque([], max_frames_to_save)
        self.frame = []
        self.batch_number = 0
        self.total_calls = 0
        self.detected_overflow = False
        self.prefix = "                 "

        self.analyse_model()

        self.register_forward_hook()

    def save_frame(self, frame=None):
        # Close the current frame: optionally append one last line, then push the
        # joined frame text into the bounded frame buffer.
        if frame is not None:
            self.expand_frame(frame)
        self.frames.append("\n".join(self.frame))
        self.frame = []  # start a new frame

    def expand_frame(self, line):
        # Append one report line to the frame currently being built.
        self.frame.append(line)

    def trace_frames(self):
        # Print everything collected so far (tracing mode) and reset the buffer.
        print("\n".join(self.frames))
        # Use `clear()` so the bounded deque (and its maxlen) is preserved,
        # instead of being replaced by an unbounded list.
        self.frames.clear()

    def reset_saved_frames(self):
        # `clear()` keeps the deque's maxlen so `max_frames_to_save` stays in effect.
        self.frames.clear()

    def dump_saved_frames(self):
        # Print the recorded frames that led up to the inf/nan event.
        print(f"\nDetected inf/nan during batch_number={self.batch_number}")
        print(f"Last {len(self.frames)} forward frames:")
        print(f"{'abs min':8} {'abs max':8} metadata")
        print("\n".join(self.frames))
        print("\n\n")
        self.frames.clear()

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}
        # self.longest_module_name = max(len(v) for v in self.module_names.values())

    def analyse_variable(self, var, ctx):
        # Record abs min/max for a tensor (and flag overflow); report non-tensors as such.
        if torch.is_tensor(var):
            self.expand_frame(get_abs_min_max(var, ctx))
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        elif var is None:
            self.expand_frame(f"{'None':>17} {ctx}")
        else:
            self.expand_frame(f"{'not a tensor':>17} {ctx}")

    def batch_start_frame(self):
        self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
        self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")

    def batch_end_frame(self):
        # NOTE(review): currently unused — its call site in `forward_hook` is commented out.
        self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")

    def create_frame(self, module, input, output):
        # Build one report frame for a single module's completed forward pass:
        # header line, then abs min/max of each parameter, input and output.
        self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")

        # params
        for name, p in module.named_parameters(recurse=False):
            self.analyse_variable(p, name)

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                self.analyse_variable(x, f"input[{i}]")
        else:
            self.analyse_variable(input, "input")

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        self.analyse_variable(y, f"output[{i}][{j}]")
                else:
                    self.analyse_variable(x, f"output[{i}]")
        else:
            self.analyse_variable(output, "output")

        self.save_frame()

    def register_forward_hook(self):
        # Attach `forward_hook` to every submodule of the model.
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        last_frame_of_batch = False

        # the membership test already yields a bool (was `True if ... else False`)
        trace_mode = self.batch_number in self.trace_batch_nums
        if trace_mode:
            self.reset_saved_frames()

        if self.total_calls == 0:
            self.batch_start_frame()
        self.total_calls += 1

        # count batch numbers - the very first forward hook of the batch will be called when the
        # batch completes - i.e. it gets called very last - we know this batch has finished
        if module == self.model:
            self.batch_number += 1
            last_frame_of_batch = True

        self.create_frame(module, input, output)

        # if last_frame_of_batch:
        #     self.batch_end_frame()

        if trace_mode:
            self.trace_frames()

        if last_frame_of_batch:
            self.batch_start_frame()

        if self.detected_overflow and not trace_mode:
            self.dump_saved_frames()

            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )
+
+
def get_abs_min_max(var, ctx):
    """Return a report line with the absolute min/max of `var` followed by the context label `ctx`."""
    magnitudes = var.abs()
    return f"{magnitudes.min():8.2e} {magnitudes.max():8.2e} {ctx}"
+
+
def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math that
    modified the tensor in question.

    This function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    has_nan = torch.isnan(var).any().item()
    if has_nan:
        print(f"{ctx} has nans")
    has_inf = torch.isinf(var).any().item()
    if has_inf:
        print(f"{ctx} has infs")

    # if needed to monitor large elements can enable the following
    if 0:  # and detected:
        for threshold in (100, 1000, 10000):
            big = var[torch.ge(var.abs(), threshold)]
            if big.numel() > 0:
                print(f"{ctx}: n{threshold}={big.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return has_nan or has_inf
+
+
class DebugOption(ExplicitEnum):
    # Available debug options: activate the DebugUnderflowOverflow detector,
    # or TPU metrics debugging.
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"
diff --git a/.venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py b/.venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..82d07850847ec357f36ff51088ddec36aceff093
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/dependency_versions_check.py
@@ -0,0 +1,63 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+
# Packages whose installed versions are validated at import time, in order
# (ordering matters — see the note above).
pkgs_to_check_at_runtime = [
    "python",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "tokenizers",
    "huggingface-hub",
    "safetensors",
    "accelerate",
    "pyyaml",
]

# Verify each runtime dependency against its pinned range from `dependency_versions_table.py`.
for pkg in pkgs_to_check_at_runtime:
    if pkg in deps:
        if pkg == "tokenizers":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_tokenizers_available

            if not is_tokenizers_available():
                continue  # not required, check version only if installed
        elif pkg == "accelerate":
            # must be loaded here, or else tqdm check may fail
            from .utils import is_accelerate_available

            # Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of
            # Transformers with PyTorch
            if not is_accelerate_available():
                continue  # not required, check version only if installed

        require_version_core(deps[pkg])
    else:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
def dep_version_check(pkg, hint=None):
    """Check that the installed version of `pkg` satisfies its pinned requirement in `deps`.

    Args:
        pkg: Key into the `deps` table (package name).
        hint: Optional message passed through to `require_version`.
    """
    require_version(deps[pkg], hint)
diff --git a/.venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py b/.venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..26500c22b167b1894d9038bddff08cb949154405
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/dependency_versions_table.py
@@ -0,0 +1,102 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update``
+deps = {
+ "Pillow": "Pillow>=10.0.1,<=15.0",
+ "accelerate": "accelerate>=0.26.0",
+ "av": "av==9.2.0",
+ "beautifulsoup4": "beautifulsoup4",
+ "blobfile": "blobfile",
+ "codecarbon": "codecarbon>=2.8.1",
+ "cookiecutter": "cookiecutter==1.7.3",
+ "dataclasses": "dataclasses",
+ "datasets": "datasets!=2.5.0",
+ "deepspeed": "deepspeed>=0.9.3",
+ "diffusers": "diffusers",
+ "dill": "dill<0.3.5",
+ "evaluate": "evaluate>=0.2.0",
+ "faiss-cpu": "faiss-cpu",
+ "fastapi": "fastapi",
+ "filelock": "filelock",
+ "flax": "flax>=0.4.1,<=0.7.0",
+ "fsspec": "fsspec<2023.10.0",
+ "ftfy": "ftfy",
+ "fugashi": "fugashi>=1.0",
+ "GitPython": "GitPython<3.1.19",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "huggingface-hub": "huggingface-hub>=0.24.0,<1.0",
+ "importlib_metadata": "importlib_metadata",
+ "ipadic": "ipadic>=1.0.0,<2.0",
+ "isort": "isort>=5.5.4",
+ "jax": "jax>=0.4.1,<=0.4.13",
+ "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
+ "jieba": "jieba",
+ "jinja2": "jinja2>=3.1.0",
+ "kenlm": "kenlm",
+ "keras": "keras>2.9,<2.16",
+ "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
+ "librosa": "librosa",
+ "nltk": "nltk<=3.8.1",
+ "natten": "natten>=0.14.6,<0.15.0",
+ "numpy": "numpy>=1.17",
+ "onnxconverter-common": "onnxconverter-common",
+ "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
+ "onnxruntime": "onnxruntime>=1.4.0",
+ "opencv-python": "opencv-python",
+ "optimum-benchmark": "optimum-benchmark>=0.3.0",
+ "optuna": "optuna",
+ "optax": "optax>=0.0.8,<=0.1.4",
+ "packaging": "packaging>=20.0",
+ "parameterized": "parameterized",
+ "phonemizer": "phonemizer",
+ "protobuf": "protobuf",
+ "psutil": "psutil",
+ "pyyaml": "pyyaml>=5.1",
+ "pydantic": "pydantic",
+ "pytest": "pytest>=7.2.0,<8.0.0",
+ "pytest-asyncio": "pytest-asyncio",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "python": "python>=3.9.0",
+ "ray[tune]": "ray[tune]>=2.7.0",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "rhoknp": "rhoknp>=1.1.0,<1.3.1",
+ "rjieba": "rjieba",
+ "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
+ "ruff": "ruff==0.5.1",
+ "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
+ "sacremoses": "sacremoses",
+ "safetensors": "safetensors>=0.4.1",
+ "sagemaker": "sagemaker>=2.31.0",
+ "schedulefree": "schedulefree>=1.2.6",
+ "scikit-learn": "scikit-learn",
+ "scipy": "scipy<1.13.0",
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
+ "sigopt": "sigopt",
+ "starlette": "starlette",
+ "sudachipy": "sudachipy>=0.6.6",
+ "sudachidict_core": "sudachidict_core>=20220729",
+ "tensorboard": "tensorboard",
+ "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16",
+ "tensorflow": "tensorflow>2.9,<2.16",
+ "tensorflow-text": "tensorflow-text<2.16",
+ "tensorflow-probability": "tensorflow-probability<0.24",
+ "tf2onnx": "tf2onnx",
+ "timeout-decorator": "timeout-decorator",
+ "tiktoken": "tiktoken",
+ "timm": "timm<=1.0.11",
+ "tokenizers": "tokenizers>=0.21,<0.22",
+ "torch": "torch>=2.0",
+ "torchaudio": "torchaudio",
+ "torchvision": "torchvision",
+ "pyctcdecode": "pyctcdecode>=0.4.0",
+ "tqdm": "tqdm>=4.27",
+ "unidic": "unidic>=1.0.2",
+ "unidic_lite": "unidic_lite>=1.0.7",
+ "urllib3": "urllib3<2.0.0",
+ "uvicorn": "uvicorn",
+ "pytest-rich": "pytest-rich",
+ "libcst": "libcst",
+ "rich": "rich",
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py b/.venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf44d4b427cf7b7bf76e1c550fa08b3dbc56b673
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/dynamic_module_utils.py
@@ -0,0 +1,685 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import filecmp
+import hashlib
+import importlib
+import importlib.util
+import os
+import re
+import shutil
+import signal
+import sys
+import threading
+import typing
+import warnings
+from pathlib import Path
+from types import ModuleType
+from typing import Any, Dict, List, Optional, Union
+
+from huggingface_hub import try_to_load_from_cache
+
+from .utils import (
+ HF_MODULES_CACHE,
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
+ cached_file,
+ extract_commit_hash,
+ is_offline_mode,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+_HF_REMOTE_CODE_LOCK = threading.Lock()
+
+
def init_hf_modules():
    """Ensure the HF modules cache is an importable package on `sys.path`.

    Idempotent: presence of `HF_MODULES_CACHE` on `sys.path` doubles as the
    "already initialized" flag. Otherwise the directory and its `__init__.py`
    are created and import caches are invalidated so the package is picked up
    immediately.
    """
    if HF_MODULES_CACHE in sys.path:
        return

    sys.path.append(HF_MODULES_CACHE)
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    package_marker = Path(HF_MODULES_CACHE) / "__init__.py"
    if not package_marker.exists():
        package_marker.touch()
    importlib.invalidate_caches()
+
+
def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
    """Create (recursively) a package directory under the HF modules cache.

    Args:
        name (`str` or `os.PathLike`):
            The name of the dynamic module to create.
    """
    init_hf_modules()
    dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()

    # Materialize any missing parent packages first, recursively.
    parent = dynamic_module_path.parent
    if not parent.exists():
        create_dynamic_module(parent)

    os.makedirs(dynamic_module_path, exist_ok=True)
    package_marker = dynamic_module_path / "__init__.py"
    if not package_marker.exists():
        package_marker.touch()
    # Invalidating import caches after touching files on disk is essential —
    # otherwise later imports can report these modules as missing. Same for all
    # other `invalidate_caches` calls in this file.
    importlib.invalidate_caches()
+
+
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
    """Return the deduplicated relative imports found in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The names imported relatively (``import .x`` / ``from .x import y``).
    """
    content = Path(module_file).read_text(encoding="utf-8")

    patterns = (
        r"^\s*import\s+\.(\S+)\s*$",     # `import .xxx`
        r"^\s*from\s+\.(\S+)\s+import",  # `from .xxx import yyy`
    )
    found = []
    for pattern in patterns:
        found += re.findall(pattern, content, flags=re.MULTILINE)
    # Unique-ify before returning.
    return list(set(found))
+
+
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
        of module files (with the `.py` suffix) a given module needs. `module_file` itself is not included.
    """
    # NOTE: imports at every depth are resolved against the *top-level* module's
    # directory, matching how the files are laid out in the dynamic-module cache.
    module_path = Path(module_file).parent
    files_to_check = [module_file]
    all_relative_imports = []

    # Breadth-first walk over relative imports until no new file shows up.
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        # BUGFIX: compare with the `.py` suffix attached. The previous code filtered
        # extension-less paths against the `.py`-suffixed entries of
        # `all_relative_imports`, so the "already seen" check never matched and any
        # circular relative import made this loop run forever.
        new_import_files = [f"{module_path / m}.py" for m in new_imports]
        files_to_check = [f for f in new_import_files if f not in all_relative_imports]
        all_relative_imports.extend(files_to_check)

    return all_relative_imports
+
+
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """List the top-level (non-relative) packages imported by a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The deduplicated top-level package names required to use the input module.
    """
    content = Path(filename).read_text(encoding="utf-8")

    # Strip try/except blocks so custom code may guard optional imports.
    content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)

    # Strip `flash_attn` imports guarded by an availability check, so the file can
    # be analyzed in CPU-only environments.
    content = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", content, flags=re.MULTILINE
    )

    modules = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)  # `import xxx`
    modules += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)  # `from xxx import yyy`
    # Keep only the top-level package name; drop relative imports entirely.
    top_level = {module.split(".")[0] for module in modules if not module.startswith(".")}
    return list(top_level)
+
+
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """Verify that every library imported by `filename` is installed in the environment.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.

    Raises:
        ImportError: If one or more imported packages are missing. Import errors that are
            *not* "No module named ..." (e.g. a dependency issue inside an installed
            package) are re-raised as-is so they are not hidden
            (see https://github.com/huggingface/transformers/issues/33604).
    """
    missing_packages = []
    for imp in get_imports(filename):
        try:
            importlib.import_module(imp)
        except ImportError as exception:
            logger.warning(f"Encountered exception while importing {imp}: {exception}")
            if "No module named" not in str(exception):
                # Not a missing package — surface the underlying failure.
                raise
            missing_packages.append(imp)

    if missing_packages:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
+
+
def get_class_in_module(
    class_name: str,
    module_path: Union[str, os.PathLike],
    *,
    force_reload: bool = False,
) -> typing.Type:
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import.
        force_reload (`bool`, *optional*, defaults to `False`):
            Whether to reload the dynamic module from file if it already exists in `sys.modules`.
            Otherwise, the module is only reloaded if the file has changed.

    Returns:
        `typing.Type`: The class looked for.
    """
    # Turn the file path into a dotted module name: strip ".py", replace path separators.
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    module_file: Path = Path(HF_MODULES_CACHE) / module_path
    # Serialize loading: concurrent imports of the same dynamic module would race on
    # `sys.modules` and on `exec_module` below.
    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: Optional[ModuleType] = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)

        # Hash the module file and all its relative imports to check if we need to reload it
        module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
        module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # insert it into sys.modules before any loading begins
            sys.modules[name] = module
        else:
            module = cached_module
        # reload in both cases, unless the module is already imported and the hash hits
        if getattr(module, "__transformers_module_hash__", "") != module_hash:
            module_spec.loader.exec_module(module)
            # Stamp the hash on the module so unchanged files are not re-executed next time.
            module.__transformers_module_hash__ = module_hash
        return getattr(module, class_name)
+
+
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> str:
    """
    Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
    Transformers module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    Passing `token=True` is required when you want to use a private model.

    Returns:
        `str`: The path to the module inside the cache.
    """
    # Deprecation shim: accept legacy `use_auth_token` but funnel it into `token`.
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        submodule = os.path.basename(pretrained_model_name_or_path)
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        # Peek at what is already in the cache so we can warn about newly downloaded files below.
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == os.path.basename(pretrained_model_name_or_path):
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
        # has changed since last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file with relative imports
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                # Recurse so each relatively-imported file is fetched and cached the same way.
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    # Only warn for unpinned downloads: pinning a revision is the user's opt-out.
    if len(new_files) > 0 and revision is None:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)
+
+
def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    code_revision: Optional[str] = None,
    **kwargs,
) -> typing.Type:
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    Warning: calling this function will execute the code in the module file found locally or downloaded from the Hub.
    It should therefore only be called on trusted repos.

    Args:
        class_reference (`str`):
            The full name of the class to load, including its module and optionally its repo.
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

            This is used when `class_reference` does not specify another repo.
        module_file (`str`):
            The name of the module file containing the class to look for.
        class_name (`str`):
            The name of the class to import in the module.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than the
            rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
            storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.

    Passing `token=True` is required when you want to use a private model.

    Returns:
        `typing.Type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    # Deprecation shim: accept legacy `use_auth_token` but funnel it into `token`.
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    # Catch the name of the repo if it's specified in `class_reference`
    # (format: "repo_id--module.ClassName").
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    else:
        repo_id = pretrained_model_name_or_path
    module_file, class_name = class_reference.split(".")

    # Reuse the model `revision` for the code only when the code lives in the model repo itself.
    if code_revision is None and pretrained_model_name_or_path == repo_id:
        code_revision = revision
    # And lastly we get the class inside our newly created module
    final_module = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=code_revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, final_module, force_reload=force_download)
+
+
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object.

    Returns:
        `List[str]`: The list of files saved. Returns `None` (with a warning) when `obj` was defined in
        `__main__`, since its source module cannot be recovered.
    """
    # Objects defined in __main__ have no importable module file we could copy.
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        # Register "<module>.<ClassName>" under `obj._auto_class` in the config's auto_map.
        module_name = obj.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        # Special handling for tokenizers: auto_map stores a (slow, fast) pair instead of one name.
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = getattr(obj, "slow_tokenizer_class")
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: no way to have the fast class
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    # NOTE(review): the returned entries are `pathlib.Path` objects, not plain `str`,
    # despite the annotation — confirm before relying on str-only operations.
    result.append(dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    return result
+
+
def _raise_timeout_error(signum, frame):
    # SIGALRM handler used by `resolve_trust_remote_code`; matches the
    # `signal.signal` handler signature. Both arguments are unused — the function
    # is also invoked directly with `(None, None)` when the timeout is set to 0.
    raise ValueError(
        "Loading this model requires you to execute custom code contained in the model repository on your local "
        "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
    )


# Number of seconds the user has to answer the trust-remote-code prompt before
# the SIGALRM handler above aborts it (CI sets this to 0 to skip prompting).
TIME_OUT_REMOTE_CODE = 15
+
+
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    """
    Resolve the final value of `trust_remote_code`.

    When the user did not set it (`None`) and only remote code is available, prompt interactively
    (bounded by `TIME_OUT_REMOTE_CODE` seconds via SIGALRM) before deciding.

    Args:
        trust_remote_code (`bool` or `None`): The value passed by the caller, or `None` when unset.
        model_name (`str`): The model repo name, used in prompt and error messages.
        has_local_code (`bool`): Whether an implementation is available locally in the library.
        has_remote_code (`bool`): Whether custom code is available for this model on the Hub.

    Returns:
        `bool`: The resolved `trust_remote_code`.

    Raises:
        ValueError: If the prompt times out, if the platform does not support `signal.SIGALRM`,
            or if remote code is required but was not trusted.
    """
    if trust_remote_code is None:
        if has_local_code:
            # A local implementation exists, so default to not running remote code.
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            prev_sig_handler = None
            try:
                # Arm SIGALRM so an unanswered prompt raises after TIME_OUT_REMOTE_CODE seconds.
                prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                # Loop until the user gives a recognizable yes/no answer (empty input means "no").
                while trust_remote_code is None:
                    answer = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    )
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                signal.alarm(0)
            except Exception:
                # OS which does not support signal.SIGALRM
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
            finally:
                # Restore whatever handler was installed before and cancel any pending alarm.
                if prev_sig_handler is not None:
                    signal.signal(signal.SIGALRM, prev_sig_handler)
                    signal.alarm(0)
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not has_local_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code
diff --git a/.venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py b/.venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74a3f0c40e28415644b2b2b4b81ad7ed9320a56
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/feature_extraction_sequence_utils.py
@@ -0,0 +1,372 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Sequence feature extraction class for common feature extractors to preprocess sequences.
+"""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
+
+
+logger = logging.get_logger(__name__)
+
+
+class SequenceFeatureExtractor(FeatureExtractionMixin):
+ """
+ This is a general feature extraction class for speech recognition.
+
+ Args:
+ feature_size (`int`):
+ The feature dimension of the extracted features.
+ sampling_rate (`int`):
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
+ padding_value (`float`):
+ The value that is used to fill the padding values / vectors.
+ """
+
+ def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+
+ # Side on which padding is applied ("right" by default) and whether `pad` returns an attention mask
+ # when the caller does not say otherwise.
+ self.padding_side = kwargs.pop("padding_side", "right")
+ self.return_attention_mask = kwargs.pop("return_attention_mask", True)
+
+ super().__init__(**kwargs)
+
+ def pad(
+ self,
+ processed_features: Union[
+ BatchFeature,
+ List[BatchFeature],
+ Dict[str, BatchFeature],
+ Dict[str, List[BatchFeature]],
+ List[Dict[str, BatchFeature]],
+ ],
+ padding: Union[bool, str, PaddingStrategy] = True,
+ max_length: Optional[int] = None,
+ truncation: bool = False,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ ) -> BatchFeature:
+ """
+ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
+ max sequence length in the batch.
+
+ Padding side (left/right) padding values are defined at the feature extractor level (with `self.padding_side`,
+ `self.padding_value`)
+
+
+
+ If the `processed_features` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
+ PyTorch tensors, you will lose the specific device of your tensors however.
+
+
+
+ Args:
+ processed_features ([`BatchFeature`], list of [`BatchFeature`], `Dict[str, List[float]]`, `Dict[str, List[List[float]]` or `List[Dict[str, List[float]]]`):
+ Processed inputs. Can represent one input ([`BatchFeature`] or `Dict[str, List[float]]`) or a batch of
+ input values / vectors (list of [`BatchFeature`], *Dict[str, List[List[float]]]* or *List[Dict[str,
+ List[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+ collate function.
+
+ Instead of `List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
+ see the note above for the return type.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ truncation (`bool`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default.
+
+ [What are attention masks?](../glossary#attention-mask)
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ """
+ # If we have a list of dicts, let's convert it in a dict of lists
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
+ if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
+ processed_features = {
+ key: [example[key] for example in processed_features] for key in processed_features[0].keys()
+ }
+
+ # The model's main input name, usually `input_values`, has to be passed for padding
+ if self.model_input_names[0] not in processed_features:
+ raise ValueError(
+ "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`"
+ f" to this method that includes {self.model_input_names[0]}, but you provided"
+ f" {list(processed_features.keys())}"
+ )
+
+ required_input = processed_features[self.model_input_names[0]]
+ return_attention_mask = (
+ return_attention_mask if return_attention_mask is not None else self.return_attention_mask
+ )
+
+ # Empty batch: nothing to pad; return early with an empty attention mask if one was requested.
+ if len(required_input) == 0:
+ if return_attention_mask:
+ processed_features["attention_mask"] = []
+ return processed_features
+
+ # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays
+ # and rebuild them afterwards if no return_tensors is specified
+ # Note that we lose the specific device the tensor may be on for PyTorch
+
+ first_element = required_input[0]
+ if isinstance(first_element, (list, tuple)):
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
+ index = 0
+ while len(required_input[index]) == 0:
+ index += 1
+ # NOTE(review): if every element is empty, the loop above walks past the end and raises
+ # IndexError before this bounds check is reached — confirm that input is impossible upstream.
+ if index < len(required_input):
+ first_element = required_input[index][0]
+
+ # Infer the output tensor type from the first element when the caller did not specify one.
+ if return_tensors is None:
+ if is_tf_tensor(first_element):
+ return_tensors = "tf"
+ elif is_torch_tensor(first_element):
+ return_tensors = "pt"
+ elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
+ return_tensors = "np"
+ else:
+ raise ValueError(
+ f"type of {first_element} unknown: {type(first_element)}. "
+ "Should be one of a python, numpy, pytorch or tensorflow object."
+ )
+
+ # Normalize every entry to numpy so the padding/truncation code below only deals with arrays.
+ for key, value in processed_features.items():
+ if isinstance(value[0], (int, float)):
+ processed_features[key] = to_numpy(value)
+ else:
+ processed_features[key] = [to_numpy(v) for v in value]
+
+ # Convert padding_strategy in PaddingStrategy
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
+
+ required_input = processed_features[self.model_input_names[0]]
+
+ batch_size = len(required_input)
+ if not all(len(v) == batch_size for v in processed_features.values()):
+ raise ValueError("Some items in the output dictionary have a different batch size than others.")
+
+ # First pass: truncate each example independently.
+ truncated_inputs = []
+ for i in range(batch_size):
+ inputs = {k: v[i] for k, v in processed_features.items()}
+ # truncation
+ inputs_slice = self._truncate(
+ inputs,
+ max_length=max_length,
+ pad_to_multiple_of=pad_to_multiple_of,
+ truncation=truncation,
+ )
+ truncated_inputs.append(inputs_slice)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ # make sure that `max_length` cannot be longer than the longest truncated length
+ max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs)
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+
+ # Second pass: pad each truncated example to the resolved max_length.
+ batch_outputs = {}
+ for i in range(batch_size):
+ # padding
+ outputs = self._pad(
+ truncated_inputs[i],
+ max_length=max_length,
+ padding_strategy=padding_strategy,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ # numpy promotes python floats to float64; downcast so models receive float32 inputs.
+ if value.dtype is np.dtype(np.float64):
+ value = value.astype(np.float32)
+ batch_outputs[key].append(value)
+
+ return BatchFeature(batch_outputs, tensor_type=return_tensors)
+
+ def _pad(
+ self,
+ processed_features: Union[Dict[str, np.ndarray], BatchFeature],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ processed_features (`Union[Dict[str, np.ndarray], BatchFeature]`):
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see below)
+ padding_strategy (`PaddingStrategy`, *optional*, default to `PaddingStrategy.DO_NOT_PAD`):
+ PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The feature_extractor padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of (`int`, *optional*):
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
+ which benefit from having sequence lengths be a multiple of 128.
+ return_attention_mask (`bool`, *optional*):
+ Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ required_input = processed_features[self.model_input_names[0]]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ # Round max_length up to the next multiple when pad_to_multiple_of is requested.
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length
+
+ # int32 mask of ones over the real (unpadded) positions; padded positions get zeros below.
+ if return_attention_mask and "attention_mask" not in processed_features:
+ processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+ if self.padding_side == "right":
+ if return_attention_mask:
+ processed_features["attention_mask"] = np.pad(
+ processed_features["attention_mask"], (0, difference)
+ )
+ # 2-D features pad along the time axis only; 1-D features pad their single axis.
+ padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference)
+ processed_features[self.model_input_names[0]] = np.pad(
+ required_input, padding_shape, "constant", constant_values=self.padding_value
+ )
+ elif self.padding_side == "left":
+ if return_attention_mask:
+ processed_features["attention_mask"] = np.pad(
+ processed_features["attention_mask"], (difference, 0)
+ )
+ padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0)
+ processed_features[self.model_input_names[0]] = np.pad(
+ required_input, padding_shape, "constant", constant_values=self.padding_value
+ )
+ else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
+
+ return processed_features
+
+ def _truncate(
+ self,
+ processed_features: Union[Dict[str, np.ndarray], BatchFeature],
+ max_length: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ truncation: Optional[bool] = None,
+ ):
+ """
+ Truncate inputs to predefined length or max length in the batch
+
+ Args:
+ processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
+ max_length (`int`, *optional*):
+ maximum length of the returned list and optionally padding length (see below)
+ pad_to_multiple_of (`int`, *optional*):
+ Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
+ which benefit from having sequence lengths be a multiple of 128.
+ truncation (`bool`, *optional*):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ """
+ if not truncation:
+ return processed_features
+ elif truncation and max_length is None:
+ raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")
+
+ required_input = processed_features[self.model_input_names[0]]
+
+ # find `max_length` that fits `pad_to_multiple_of`
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_truncated = len(required_input) > max_length
+
+ if needs_to_be_truncated:
+ processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
+ # Keep the attention mask aligned with the truncated input.
+ if "attention_mask" in processed_features:
+ processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]
+
+ return processed_features
+
+ def _get_padding_strategies(self, padding=False, max_length=None):
+ """
+ Find the correct padding strategy
+ """
+
+ # Get padding strategy: True -> LONGEST; a string -> the matching PaddingStrategy member;
+ # an existing PaddingStrategy -> used as-is; False -> DO_NOT_PAD.
+ if padding is not False:
+ if padding is True:
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
+ elif not isinstance(padding, PaddingStrategy):
+ padding_strategy = PaddingStrategy(padding)
+ elif isinstance(padding, PaddingStrategy):
+ padding_strategy = padding
+ else:
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+ # Set max length if needed
+ if max_length is None:
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
+ raise ValueError(
+ f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
+ )
+
+ # Test if we have a padding value
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
+ raise ValueError(
+ "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
+ " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
+ )
+
+ return padding_strategy
diff --git a/.venv/lib/python3.11/site-packages/transformers/hf_argparser.py b/.venv/lib/python3.11/site-packages/transformers/hf_argparser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d03ff7004f2b6c0f14af55efd1bd4b8336dde305
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/hf_argparser.py
@@ -0,0 +1,437 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import json
+import os
+import sys
+import types
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
+from copy import copy
+from enum import Enum
+from inspect import isclass
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
+
+import yaml
+
+
+# Readability-only aliases for the parser's signatures; both are plain `Any` at runtime.
+DataClass = NewType("DataClass", Any)
+DataClassType = NewType("DataClassType", Any)
+
+
+# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+def string_to_bool(v):
+ """Parse a truthy/falsy CLI string into a bool; bools pass through unchanged.
+
+ Raises `ArgumentTypeError` for anything outside yes/no, true/false, t/f, y/n, 1/0
+ (case insensitive) so argparse reports a clean usage error.
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise ArgumentTypeError(
+ f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
+ )
+
+
+def make_choice_type_function(choices: list) -> Callable[[str], Any]:
+ """
+ Creates a mapping function from each choices string representation to the actual value. Used to support multiple
+ value types for a single argument.
+
+ Args:
+ choices (list): List of choices.
+
+ Returns:
+ Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
+ """
+ str_to_choice = {str(choice): choice for choice in choices}
+ # Unknown strings fall through unchanged so argparse's own `choices` validation reports the error.
+ return lambda arg: str_to_choice.get(arg, arg)
+
+
+def HfArg(
+ *,
+ aliases: Union[str, List[str]] = None,
+ help: str = None,
+ default: Any = dataclasses.MISSING,
+ default_factory: Callable[[], Any] = dataclasses.MISSING,
+ metadata: dict = None,
+ **kwargs,
+) -> dataclasses.Field:
+ """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
+
+ Example comparing the use of `HfArg` and `dataclasses.field`:
+ ```
+ @dataclass
+ class Args:
+ regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
+ hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
+ ```
+
+ Args:
+ aliases (Union[str, List[str]], optional):
+ Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
+ Defaults to None.
+ help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
+ default (Any, optional):
+ Default value for the argument. If not default or default_factory is specified, the argument is required.
+ Defaults to dataclasses.MISSING.
+ default_factory (Callable[[], Any], optional):
+ The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
+ default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
+ Defaults to dataclasses.MISSING.
+ metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
+
+ Returns:
+ Field: A `dataclasses.Field` with the desired properties.
+ """
+ if metadata is None:
+ # Important, don't use as default param in function signature because dict is mutable and shared across function calls
+ metadata = {}
+ if aliases is not None:
+ metadata["aliases"] = aliases
+ if help is not None:
+ metadata["help"] = help
+
+ # aliases/help travel via field.metadata, which HfArgumentParser._parse_dataclass_field pops back out.
+ return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
+
+
+class HfArgumentParser(ArgumentParser):
+ """
+ This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
+
+ The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
+ arguments to the parser after initialization and you'll get the output back after parsing as an additional
+ namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
+ """
+
+ dataclass_types: Iterable[DataClassType]
+
+ def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
+ """
+ Args:
+ dataclass_types:
+ Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Passed to `argparse.ArgumentParser()` in the regular way.
+ """
+ # To make the default appear when using --help
+ if "formatter_class" not in kwargs:
+ kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
+ super().__init__(**kwargs)
+ # Accept a single dataclass as well as an iterable of them.
+ if dataclasses.is_dataclass(dataclass_types):
+ dataclass_types = [dataclass_types]
+ self.dataclass_types = list(dataclass_types)
+ for dtype in self.dataclass_types:
+ self._add_dataclass_arguments(dtype)
+
+ @staticmethod
+ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
+ """Register one resolved dataclass `field` as an argparse argument on `parser`."""
+ # Long-option strings are conventionally separated by hyphens rather
+ # than underscores, e.g., "--long-format" rather than "--long_format".
+ # Argparse converts hyphens to underscores so that the destination
+ # string is a valid attribute name. Hf_argparser should do the same.
+ long_options = [f"--{field.name}"]
+ if "_" in field.name:
+ long_options.append(f"--{field.name.replace('_', '-')}")
+
+ kwargs = field.metadata.copy()
+ # field.metadata is not used at all by Data Classes,
+ # it is provided as a third-party extension mechanism.
+ if isinstance(field.type, str):
+ raise RuntimeError(
+ "Unresolved type detected, which should have been done with the help of "
+ "`typing.get_type_hints` method by default"
+ )
+
+ aliases = kwargs.pop("aliases", [])
+ if isinstance(aliases, str):
+ aliases = [aliases]
+
+ # For generics like List[int] or Optional[str], __origin__ is the unsubscripted origin (list, Union, ...).
+ origin_type = getattr(field.type, "__origin__", field.type)
+ if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
+ # Only `Optional[X]` or a Union containing `str` is supported; argparse takes one type per argument.
+ if str not in field.type.__args__ and (
+ len(field.type.__args__) != 2 or type(None) not in field.type.__args__
+ ):
+ raise ValueError(
+ "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
+ " the argument parser only supports one type per argument."
+ f" Problem encountered in field '{field.name}'."
+ )
+ if type(None) not in field.type.__args__:
+ # filter `str` in Union
+ field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
+ origin_type = getattr(field.type, "__origin__", field.type)
+ elif bool not in field.type.__args__:
+ # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
+ field.type = (
+ field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
+ )
+ origin_type = getattr(field.type, "__origin__", field.type)
+
+ # A variable to store kwargs for a boolean field, if needed
+ # so that we can init a `no_*` complement argument (see below)
+ bool_kwargs = {}
+ if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
+ if origin_type is Literal:
+ kwargs["choices"] = field.type.__args__
+ else:
+ kwargs["choices"] = [x.value for x in field.type]
+
+ kwargs["type"] = make_choice_type_function(kwargs["choices"])
+
+ if field.default is not dataclasses.MISSING:
+ kwargs["default"] = field.default
+ else:
+ kwargs["required"] = True
+ elif field.type is bool or field.type == Optional[bool]:
+ # Copy the correct kwargs to use to instantiate a `no_*` complement argument below.
+ # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
+ bool_kwargs = copy(kwargs)
+
+ # Hack because type=bool in argparse does not behave as we want.
+ kwargs["type"] = string_to_bool
+ if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
+ # Default value is False if we have no default when of type bool.
+ default = False if field.default is dataclasses.MISSING else field.default
+ # This is the value that will get picked if we don't include --{field.name} in any way
+ kwargs["default"] = default
+ # This tells argparse we accept 0 or 1 value after --{field.name}
+ kwargs["nargs"] = "?"
+ # This is the value that will get picked if we do --{field.name} (without value)
+ kwargs["const"] = True
+ elif isclass(origin_type) and issubclass(origin_type, list):
+ # List fields become `nargs="+"` arguments typed by the element type.
+ kwargs["type"] = field.type.__args__[0]
+ kwargs["nargs"] = "+"
+ if field.default_factory is not dataclasses.MISSING:
+ kwargs["default"] = field.default_factory()
+ elif field.default is dataclasses.MISSING:
+ kwargs["required"] = True
+ else:
+ kwargs["type"] = field.type
+ if field.default is not dataclasses.MISSING:
+ kwargs["default"] = field.default
+ elif field.default_factory is not dataclasses.MISSING:
+ kwargs["default"] = field.default_factory()
+ else:
+ kwargs["required"] = True
+ parser.add_argument(*long_options, *aliases, **kwargs)
+
+ # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
+ # Order is important for arguments with the same destination!
+ # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
+ # here and we do not need those changes/additional keys.
+ if field.default is True and (field.type is bool or field.type == Optional[bool]):
+ bool_kwargs["default"] = False
+ parser.add_argument(
+ f"--no_{field.name}",
+ f"--no-{field.name.replace('_', '-')}",
+ action="store_false",
+ dest=field.name,
+ **bool_kwargs,
+ )
+
+ def _add_dataclass_arguments(self, dtype: DataClassType):
+ """Resolve `dtype`'s type hints and register an argparse argument for each init field."""
+ if hasattr(dtype, "_argument_group_name"):
+ parser = self.add_argument_group(dtype._argument_group_name)
+ else:
+ parser = self
+
+ try:
+ type_hints: Dict[str, type] = get_type_hints(dtype)
+ except NameError:
+ raise RuntimeError(
+ f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
+ "removing line of `from __future__ import annotations` which opts in Postponed "
+ "Evaluation of Annotations (PEP 563)"
+ )
+ except TypeError as ex:
+ # Remove this block when we drop Python 3.9 support
+ if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
+ python_version = ".".join(map(str, sys.version_info[:3]))
+ raise RuntimeError(
+ f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
+ "line of `from __future__ import annotations` which opts in union types as "
+ "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
+ "support Python versions that lower than 3.10, you need to use "
+ "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
+ "`X | None`."
+ ) from ex
+ raise
+
+ for field in dataclasses.fields(dtype):
+ if not field.init:
+ continue
+ # Replace the (possibly string) annotation with the resolved type before parsing the field.
+ field.type = type_hints[field.name]
+ self._parse_dataclass_field(parser, field)
+
+ def parse_args_into_dataclasses(
+ self,
+ args=None,
+ return_remaining_strings=False,
+ look_for_args_file=True,
+ args_filename=None,
+ args_file_flag=None,
+ ) -> Tuple[DataClass, ...]:
+ """
+ Parse command-line args into instances of the specified dataclass types.
+
+ This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
+ docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
+
+ Args:
+ args:
+ List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
+ return_remaining_strings:
+ If true, also return a list of remaining argument strings.
+ look_for_args_file:
+ If true, will look for a ".args" file with the same base name as the entry point script for this
+ process, and will append its potential content to the command line args.
+ args_filename:
+ If not None, will use this file instead of the ".args" file specified in the previous argument.
+ args_file_flag:
+ If not None, will look for a file in the command-line args specified with this flag. The flag can be
+ specified multiple times and precedence is determined by the order (last one wins).
+
+ Returns:
+ Tuple consisting of:
+
+ - the dataclass instances in the same order as they were passed to the initializer.
+ - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
+ after initialization.
+ - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
+ """
+
+ if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
+ args_files = []
+
+ if args_filename:
+ args_files.append(Path(args_filename))
+ elif look_for_args_file and len(sys.argv):
+ # Default: sibling ".args" file next to the entry-point script.
+ args_files.append(Path(sys.argv[0]).with_suffix(".args"))
+
+ # args files specified via command line flag should overwrite default args files so we add them last
+ if args_file_flag:
+ # Create special parser just to extract the args_file_flag values
+ args_file_parser = ArgumentParser()
+ args_file_parser.add_argument(args_file_flag, type=str, action="append")
+
+ # Use only remaining args for further parsing (remove the args_file_flag)
+ cfg, args = args_file_parser.parse_known_args(args=args)
+ cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
+
+ if cmd_args_file_paths:
+ args_files.extend([Path(p) for p in cmd_args_file_paths])
+
+ file_args = []
+ for args_file in args_files:
+ if args_file.exists():
+ file_args += args_file.read_text().split()
+
+ # in case of duplicate arguments the last one has precedence
+ # args specified via the command line should overwrite args from files, so we add them last
+ args = file_args + args if args is not None else file_args + sys.argv[1:]
+ namespace, remaining_args = self.parse_known_args(args=args)
+ outputs = []
+ for dtype in self.dataclass_types:
+ keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+ inputs = {k: v for k, v in vars(namespace).items() if k in keys}
+ # Remove consumed attributes so leftover namespace entries can be detected below.
+ for k in keys:
+ delattr(namespace, k)
+ obj = dtype(**inputs)
+ outputs.append(obj)
+ if len(namespace.__dict__) > 0:
+ # additional namespace.
+ outputs.append(namespace)
+ if return_remaining_strings:
+ return (*outputs, remaining_args)
+ else:
+ if remaining_args:
+ raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
+
+ return (*outputs,)
+
+ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
+ """
+ Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
+ types.
+
+ Args:
+ args (`dict`):
+ dict containing config values
+ allow_extra_keys (`bool`, *optional*, defaults to `False`):
+ Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
+
+ Returns:
+ Tuple consisting of:
+
+ - the dataclass instances in the same order as they were passed to the initializer.
+ """
+ unused_keys = set(args.keys())
+ outputs = []
+ for dtype in self.dataclass_types:
+ keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+ inputs = {k: v for k, v in args.items() if k in keys}
+ unused_keys.difference_update(inputs.keys())
+ obj = dtype(**inputs)
+ outputs.append(obj)
+ if not allow_extra_keys and unused_keys:
+ raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
+ return tuple(outputs)
+
+ def parse_json_file(
+ self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False
+ ) -> Tuple[DataClass, ...]:
+ """
+ Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
+ dataclass types.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ File name of the json file to parse
+ allow_extra_keys (`bool`, *optional*, defaults to `False`):
+ Defaults to False. If False, will raise an exception if the json file contains keys that are not
+ parsed.
+
+ Returns:
+ Tuple consisting of:
+
+ - the dataclass instances in the same order as they were passed to the initializer.
+ """
+ with open(Path(json_file), encoding="utf-8") as open_json_file:
+ data = json.loads(open_json_file.read())
+ outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
+ return tuple(outputs)
+
+ def parse_yaml_file(
+ self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False
+ ) -> Tuple[DataClass, ...]:
+ """
+ Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
+ dataclass types.
+
+ Args:
+ yaml_file (`str` or `os.PathLike`):
+ File name of the yaml file to parse
+ allow_extra_keys (`bool`, *optional*, defaults to `False`):
+ Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
+ parsed.
+
+ Returns:
+ Tuple consisting of:
+
+ - the dataclass instances in the same order as they were passed to the initializer.
+ """
+ outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
+ return tuple(outputs)
diff --git a/.venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py b/.venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..c14165165ca1f92fb28e27b718c8bd81e1bc3a93
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/hyperparameter_search.py
@@ -0,0 +1,141 @@
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .integrations import (
+ is_optuna_available,
+ is_ray_tune_available,
+ is_sigopt_available,
+ is_wandb_available,
+ run_hp_search_optuna,
+ run_hp_search_ray,
+ run_hp_search_sigopt,
+ run_hp_search_wandb,
+)
+from .trainer_utils import (
+ HPSearchBackend,
+ default_hp_space_optuna,
+ default_hp_space_ray,
+ default_hp_space_sigopt,
+ default_hp_space_wandb,
+)
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HyperParamSearchBackendBase:
+ name: str
+ pip_package: str = None
+
+ @staticmethod
+ def is_available():
+ raise NotImplementedError
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ raise NotImplementedError
+
+ def default_hp_space(self, trial):
+ raise NotImplementedError
+
+ def ensure_available(self):
+ if not self.is_available():
+ raise RuntimeError(
+ f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
+ )
+
+ @classmethod
+ def pip_install(cls):
+ return f"`pip install {cls.pip_package or cls.name}`"
+
+
+class OptunaBackend(HyperParamSearchBackendBase):
+ name = "optuna"
+
+ @staticmethod
+ def is_available():
+ return is_optuna_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_optuna(trial)
+
+
+class RayTuneBackend(HyperParamSearchBackendBase):
+ name = "ray"
+ pip_package = "'ray[tune]'"
+
+ @staticmethod
+ def is_available():
+ return is_ray_tune_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_ray(trial)
+
+
+class SigOptBackend(HyperParamSearchBackendBase):
+ name = "sigopt"
+
+ @staticmethod
+ def is_available():
+ return is_sigopt_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_sigopt(trial)
+
+
+class WandbBackend(HyperParamSearchBackendBase):
+ name = "wandb"
+
+ @staticmethod
+ def is_available():
+ return is_wandb_available()
+
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
+ return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
+
+ def default_hp_space(self, trial):
+ return default_hp_space_wandb(trial)
+
+
+ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
+ HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
+}
+
+
+def default_hp_search_backend() -> str:
+ available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
+ if len(available_backends) > 0:
+ name = available_backends[0].name
+ if len(available_backends) > 1:
+ logger.info(
+ f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
+ )
+ return name
+ raise RuntimeError(
+ "No hyperparameter search backend available.\n"
+ + "\n".join(
+ f" - To install {backend.name} run {backend.pip_install()}"
+ for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
+ )
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/image_processing_base.py b/.venv/lib/python3.11/site-packages/transformers/image_processing_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ce7af3fa8076958d4a4ac87ca3ab13716c7955
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/image_processing_base.py
@@ -0,0 +1,559 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+import json
+import os
+import warnings
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+
+import numpy as np
+import requests
+
+from .dynamic_module_utils import custom_object_save
+from .feature_extraction_utils import BatchFeature as BaseBatchFeature
+from .utils import (
+ IMAGE_PROCESSOR_NAME,
+ PushToHubMixin,
+ add_model_info_to_auto_map,
+ add_model_info_to_custom_pipelines,
+ cached_file,
+ copy_func,
+ download_url,
+ is_offline_mode,
+ is_remote_url,
+ is_vision_available,
+ logging,
+)
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
+# We override the class string here, but logic is the same.
+class BatchFeature(BaseBatchFeature):
+ r"""
+ Holds the output of the image processor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`):
+ Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
+ initialization.
+ """
+
+
+# TODO: (Amy) - factor out the common parts of this and the feature extractor
+class ImageProcessingMixin(PushToHubMixin):
+ """
+ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
+ extractors.
+ """
+
+ _auto_class = None
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
+ # `XXXImageProcessor`, this attribute and its value are misleading.
+ kwargs.pop("feature_extractor_type", None)
+ # Pop "processor_class" as it should be saved as private attribute
+ self._processor_class = kwargs.pop("processor_class", None)
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _set_processor_class(self, processor_class: str):
+ """Sets processor class as an attribute."""
+ self._processor_class = processor_class
+
+ @classmethod
+ def from_pretrained(
+ cls: Type[ImageProcessorType],
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ) -> ImageProcessorType:
+ r"""
+ Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a image processor file saved using the
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
+ `./my_model_directory/`.
+ - a path or url to a saved image processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the image processor files and override the cached versions if
+ they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+
+
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+
+
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final image processor object. If `True`, then this
+ functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ kwargs (`Dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are image processor attributes will be used to override the
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
+ controlled by the `return_unused_kwargs` keyword parameter.
+
+ Returns:
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
+
+ Examples:
+
+ ```python
+ # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
+ # derived class: *CLIPImageProcessor*
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32"
+ ) # Download image_processing_config from huggingface.co and cache.
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "./test/saved_model/"
+ ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
+ image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
+ image_processor = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False
+ )
+ assert image_processor.do_normalize is False
+ image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
+ )
+ assert image_processor.do_normalize is False
+ assert unused_kwargs == {"foo": False}
+ ```"""
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ if token is not None:
+ kwargs["token"] = token
+
+ image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(image_processor_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the image processor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ use_auth_token = kwargs.pop("use_auth_token", None)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if kwargs.get("token", None) is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ kwargs["token"] = use_auth_token
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = self._create_repo(repo_id, **kwargs)
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
+
+ self.to_json_file(output_image_processor_file)
+ logger.info(f"Image processor saved in {output_image_processor_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_image_processor_file]
+
+ @classmethod
+ def get_image_processor_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ image processor of type [`~image_processor_utils.ImageProcessingMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ image_processor_filename (`str`, *optional*, defaults to `"config.json"`):
+ The name of the file in the model directory to use for the image processor config.
+
+ Returns:
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", "")
+ image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_image_processor_file = pretrained_model_name_or_path
+ is_local = True
+ elif is_remote_url(pretrained_model_name_or_path):
+ image_processor_file = pretrained_model_name_or_path
+ resolved_image_processor_file = download_url(pretrained_model_name_or_path)
+ else:
+ image_processor_file = image_processor_filename
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_image_processor_file = cached_file(
+ pretrained_model_name_or_path,
+ image_processor_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ )
+ except EnvironmentError:
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {image_processor_filename} file"
+ )
+
+ try:
+ # Load image_processor dict
+ with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+
+ except json.JSONDecodeError:
+ raise EnvironmentError(
+ f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_image_processor_file}")
+ else:
+ logger.info(
+ f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
+ )
+ if "auto_map" in image_processor_dict:
+ image_processor_dict["auto_map"] = add_model_info_to_auto_map(
+ image_processor_dict["auto_map"], pretrained_model_name_or_path
+ )
+ if "custom_pipelines" in image_processor_dict:
+ image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
+ image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
+ )
+
+ return image_processor_dict, kwargs
+
+ @classmethod
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
+ """
+ Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
+
+ Args:
+ image_processor_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
+ retrieved from a pretrained checkpoint by leveraging the
+ [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the image processor object.
+
+ Returns:
+ [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
+ parameters.
+ """
+ image_processor_dict = image_processor_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+ # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
+ # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
+ if "size" in kwargs and "size" in image_processor_dict:
+ image_processor_dict["size"] = kwargs.pop("size")
+ if "crop_size" in kwargs and "crop_size" in image_processor_dict:
+ image_processor_dict["crop_size"] = kwargs.pop("crop_size")
+
+ image_processor = cls(**image_processor_dict)
+
+ # Update image_processor with kwargs if needed
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(image_processor, key):
+ setattr(image_processor, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info(f"Image processor {image_processor}")
+ if return_unused_kwargs:
+ return image_processor, kwargs
+ else:
+ return image_processor
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["image_processor_type"] = self.__class__.__name__
+
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
+ """
+ Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
+ file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
+ instantiated from that JSON file.
+ """
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ image_processor_dict = json.loads(text)
+ return cls(**image_processor_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ # make sure private name "_processor_class" is correctly
+ # saved as "processor_class"
+ _processor_class = dictionary.pop("_processor_class", None)
+ if _processor_class is not None:
+ dictionary["processor_class"] = _processor_class
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
+ """
+ Register this class with a given auto class. This should only be used for custom image processors as the ones
+ in the library are already mapped with `AutoImageProcessor `.
+
+
+
+ This API is experimental and may have some slight breaking changes in the next releases.
+
+
+
+ Args:
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`):
+ The auto class to register this new image processor with.
+ """
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
+
+ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
+ """
+ Convert a single or a list of urls into the corresponding `PIL.Image` objects.
+
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
+ returned.
+ """
+ headers = {
+ "User-Agent": (
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
+ " Safari/537.36"
+ )
+ }
+ if isinstance(image_url_or_urls, list):
+ return [self.fetch_images(x) for x in image_url_or_urls]
+ elif isinstance(image_url_or_urls, str):
+ response = requests.get(image_url_or_urls, stream=True, headers=headers)
+ response.raise_for_status()
+ return Image.open(BytesIO(response.content))
+ else:
+ raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+
+
+ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
+if ImageProcessingMixin.push_to_hub.__doc__ is not None:
+ ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
+ object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
+ )
diff --git a/.venv/lib/python3.11/site-packages/transformers/image_processing_utils.py b/.venv/lib/python3.11/site-packages/transformers/image_processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0279f26a963e35fc0d3f74a3b669b8a5e1ccf422
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/image_processing_utils.py
@@ -0,0 +1,287 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Iterable, Optional, Union
+
+import numpy as np
+
+from .image_processing_base import BatchFeature, ImageProcessingMixin
+from .image_transforms import center_crop, normalize, rescale
+from .image_utils import ChannelDimension
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+INIT_SERVICE_KWARGS = [
+ "processor_class",
+ "image_processor_type",
+]
+
+
+class BaseImageProcessor(ImageProcessingMixin):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def __call__(self, images, **kwargs) -> BatchFeature:
+ """Preprocess an image or a batch of images."""
+ return self.preprocess(images, **kwargs)
+
+ def preprocess(self, images, **kwargs) -> BatchFeature:
+ raise NotImplementedError("Each image processor must implement its own preprocess method")
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Rescale an image by a scale factor. image = image * scale.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`float`):
+ The scaling factor to rescale pixel values by.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+ return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
+
+ def normalize(
+ self,
+ image: np.ndarray,
+ mean: Union[float, Iterable[float]],
+ std: Union[float, Iterable[float]],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Normalize an image. image = (image - image_mean) / image_std.
+
+ Args:
+ image (`np.ndarray`):
+ Image to normalize.
+ mean (`float` or `Iterable[float]`):
+ Image mean to use for normalization.
+ std (`float` or `Iterable[float]`):
+ Image standard deviation to use for normalization.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+ Returns:
+ `np.ndarray`: The normalized image.
+ """
+ return normalize(
+ image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+
+ def center_crop(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+ any edge, the image is padded with 0's and then center cropped.
+
+ Args:
+ image (`np.ndarray`):
+ Image to center crop.
+ size (`Dict[str, int]`):
+ Size of the output image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ size = get_size_dict(size)
+ if "height" not in size or "width" not in size:
+ raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+ return center_crop(
+ image,
+ size=(size["height"], size["width"]),
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def to_dict(self):
+ encoder_dict = super().to_dict()
+ encoder_dict.pop("_valid_processor_keys", None)
+ return encoder_dict
+
+
+VALID_SIZE_DICT_KEYS = (
+ {"height", "width"},
+ {"shortest_edge"},
+ {"shortest_edge", "longest_edge"},
+ {"longest_edge"},
+ {"max_height", "max_width"},
+)
+
+
+def is_valid_size_dict(size_dict):
+ if not isinstance(size_dict, dict):
+ return False
+
+ size_dict_keys = set(size_dict.keys())
+ for allowed_keys in VALID_SIZE_DICT_KEYS:
+ if size_dict_keys == allowed_keys:
+ return True
+ return False
+
+
+def convert_to_size_dict(
+ size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
+):
+ # By default, if size is an int we assume it represents a tuple of (size, size).
+ if isinstance(size, int) and default_to_square:
+ if max_size is not None:
+ raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
+ return {"height": size, "width": size}
+ # In other configs, if size is an int and default_to_square is False, size represents the length of
+ # the shortest edge after resizing.
+ elif isinstance(size, int) and not default_to_square:
+ size_dict = {"shortest_edge": size}
+ if max_size is not None:
+ size_dict["longest_edge"] = max_size
+ return size_dict
+ # Otherwise, if size is a tuple it's either (height, width) or (width, height)
+ elif isinstance(size, (tuple, list)) and height_width_order:
+ return {"height": size[0], "width": size[1]}
+ elif isinstance(size, (tuple, list)) and not height_width_order:
+ return {"height": size[1], "width": size[0]}
+ elif size is None and max_size is not None:
+ if default_to_square:
+ raise ValueError("Cannot specify both default_to_square=True and max_size")
+ return {"longest_edge": max_size}
+
+ raise ValueError(f"Could not convert size input to size dict: {size}")
+
+
+def get_size_dict(
+ size: Union[int, Iterable[int], Dict[str, int]] = None,
+ max_size: Optional[int] = None,
+ height_width_order: bool = True,
+ default_to_square: bool = True,
+ param_name="size",
+) -> dict:
+ """
+ Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
+ compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
+ width) or (width, height) format.
+
+ - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
+ size[0]}` if `height_width_order` is `False`.
+ - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
+ - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
+ is set, it is added to the dict as `{"longest_edge": max_size}`.
+
+ Args:
+ size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
+ The `size` parameter to be cast into a size dictionary.
+ max_size (`Optional[int]`, *optional*):
+ The `max_size` parameter to be cast into a size dictionary.
+ height_width_order (`bool`, *optional*, defaults to `True`):
+ If `size` is a tuple, whether it's in (height, width) or (width, height) order.
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ If `size` is an int, whether to default to a square image or not.
+ """
+ if not isinstance(size, dict):
+ size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
+ logger.info(
+ f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
+ f" Converted to {size_dict}.",
+ )
+ else:
+ size_dict = size
+
+ if not is_valid_size_dict(size_dict):
+ raise ValueError(
+ f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
+ )
+ return size_dict
+
+
+def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ This is done by calculating the effective and wasted resolution for each possible resolution.
+
+ The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
+
+ Args:
+ original_size (tuple):
+ The original size of the image in the format (height, width).
+ possible_resolutions (list):
+ A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (height, width).
+ """
+ original_height, original_width = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float("inf")
+
+ for height, width in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (
+ effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
+ ):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (height, width)
+
+ return best_fit
diff --git a/.venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py b/.venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c1be325b7eb304060e3ed8aa28981619c677129
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/image_processing_utils_fast.py
@@ -0,0 +1,133 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Optional, Tuple
+
+from .image_processing_utils import BaseImageProcessor
+from .utils.import_utils import is_torch_available, is_torchvision_available
+
+
+if is_torchvision_available():
+ from torchvision.transforms import Compose
+
+if is_torch_available():
+ import torch
+
+
+@dataclass(frozen=True)
+class SizeDict:
+ """
+ Hashable dictionary to store image size information.
+ """
+
+ height: int = None
+ width: int = None
+ longest_edge: int = None
+ shortest_edge: int = None
+ max_height: int = None
+ max_width: int = None
+
+ def __getitem__(self, key):
+ if hasattr(self, key):
+ return getattr(self, key)
+ raise KeyError(f"Key {key} not found in SizeDict.")
+
+
+class BaseImageProcessorFast(BaseImageProcessor):
+ _transform_params = None
+
+ def _build_transforms(self, **kwargs) -> "Compose":
+ """
+ Given the input settings e.g. do_resize, build the image transforms.
+ """
+ raise NotImplementedError
+
+ def _validate_params(self, **kwargs) -> None:
+ for k, v in kwargs.items():
+ if k not in self._transform_params:
+ raise ValueError(f"Invalid transform parameter {k}={v}.")
+
+ @functools.lru_cache(maxsize=1)
+ def get_transforms(self, **kwargs) -> "Compose":
+ self._validate_params(**kwargs)
+ return self._build_transforms(**kwargs)
+
+ def to_dict(self):
+ encoder_dict = super().to_dict()
+ encoder_dict.pop("_transform_params", None)
+ return encoder_dict
+
+
+def get_image_size_for_max_height_width(
+ image_size: Tuple[int, int],
+ max_height: int,
+ max_width: int,
+) -> Tuple[int, int]:
+ """
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
+ Important, even if image_height < max_height and image_width < max_width, the image will be resized
+ to at least one of the edges be equal to max_height or max_width.
+
+ For example:
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
+
+ Args:
+ image_size (`Tuple[int, int]`):
+ The image to resize.
+ max_height (`int`):
+ The maximum allowed height.
+ max_width (`int`):
+ The maximum allowed width.
+ """
+ height, width = image_size
+ height_scale = max_height / height
+ width_scale = max_width / width
+ min_scale = min(height_scale, width_scale)
+ new_height = int(height * min_scale)
+ new_width = int(width * min_scale)
+ return new_height, new_width
+
+
+def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
+ """
+ Squeezes a tensor, but only if the axis specified has dim 1.
+ """
+ if axis is None:
+ return tensor.squeeze()
+
+ try:
+ return tensor.squeeze(axis=axis)
+ except ValueError:
+ return tensor
+
+
+def max_across_indices(values: Iterable[Any]) -> List[Any]:
+ """
+ Return the maximum value across all indices of an iterable of values.
+ """
+ return [max(values_i) for values_i in zip(*values)]
+
+
+def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]:
+ """
+ Get the maximum height and width across all images in a batch.
+ """
+
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
+
+ return (max_height, max_width)
diff --git a/.venv/lib/python3.11/site-packages/transformers/image_transforms.py b/.venv/lib/python3.11/site-packages/transformers/image_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7d3a5abb7a8db634ef1f1f19ea57219f14457b4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/image_transforms.py
@@ -0,0 +1,860 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from math import ceil
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from .image_utils import (
+ ChannelDimension,
+ ImageInput,
+ get_channel_dimension_axis,
+ get_image_size,
+ infer_channel_dimension_format,
+)
+from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
+from .utils.import_utils import (
+ is_flax_available,
+ is_tf_available,
+ is_torch_available,
+ is_torchvision_available,
+ is_torchvision_v2_available,
+ is_vision_available,
+ requires_backends,
+)
+
+
+if is_vision_available():
+ import PIL
+
+ from .image_utils import PILImageResampling
+
+if is_torch_available():
+ import torch
+
+if is_tf_available():
+ import tensorflow as tf
+
+if is_flax_available():
+ import jax.numpy as jnp
+
+if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+elif is_torchvision_available():
+ from torchvision.transforms import functional as F
+
+
+def to_channel_dimension_format(
+ image: np.ndarray,
+ channel_dim: Union[ChannelDimension, str],
+ input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
+) -> np.ndarray:
+ """
+ Converts `image` to the channel dimension format specified by `channel_dim`.
+
+ Args:
+ image (`numpy.ndarray`):
+ The image to have its channel dimension set.
+ channel_dim (`ChannelDimension`):
+ The channel dimension format to use.
+ input_channel_dim (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+
+ Returns:
+ `np.ndarray`: The image with the channel dimension set to `channel_dim`.
+ """
+ if not isinstance(image, np.ndarray):
+ raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+ if input_channel_dim is None:
+ input_channel_dim = infer_channel_dimension_format(image)
+
+ target_channel_dim = ChannelDimension(channel_dim)
+ if input_channel_dim == target_channel_dim:
+ return image
+
+ if target_channel_dim == ChannelDimension.FIRST:
+ image = image.transpose((2, 0, 1))
+ elif target_channel_dim == ChannelDimension.LAST:
+ image = image.transpose((1, 2, 0))
+ else:
+ raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
+
+ return image
+
+
+def rescale(
+ image: np.ndarray,
+ scale: float,
+ data_format: Optional[ChannelDimension] = None,
+ dtype: np.dtype = np.float32,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Rescales `image` by `scale`.
+
+ Args:
+ image (`np.ndarray`):
+ The image to rescale.
+ scale (`float`):
+ The scale to use for rescaling the image.
+ data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+ The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
+ extractors.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
+
+ Returns:
+ `np.ndarray`: The rescaled image.
+ """
+ if not isinstance(image, np.ndarray):
+ raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+ rescaled_image = image.astype(np.float64) * scale # Numpy type promotion has changed, so always upcast first
+ if data_format is not None:
+ rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
+
+ rescaled_image = rescaled_image.astype(dtype) # Finally downcast to the desired dtype at the end
+
+ return rescaled_image
+
+
+def _rescale_for_pil_conversion(image):
+ """
+ Detects whether or not the image needs to be rescaled before being converted to a PIL image.
+
+ The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be
+ rescaled.
+ """
+ if image.dtype == np.uint8:
+ do_rescale = False
+ elif np.allclose(image, image.astype(int)):
+ if np.all(0 <= image) and np.all(image <= 255):
+ do_rescale = False
+ else:
+ raise ValueError(
+ "The image to be converted to a PIL image contains values outside the range [0, 255], "
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
+ )
+ elif np.all(0 <= image) and np.all(image <= 1):
+ do_rescale = True
+ else:
+ raise ValueError(
+ "The image to be converted to a PIL image contains values outside the range [0, 1], "
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
+ )
+ return do_rescale
+
+
+def to_pil_image(
+ image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
+ do_rescale: Optional[bool] = None,
+ image_mode: Optional[str] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> "PIL.Image.Image":
+ """
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
+ needed.
+
+ Args:
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
+ The image to convert to the `PIL.Image` format.
+ do_rescale (`bool`, *optional*):
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
+ to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
+ and `False` otherwise.
+ image_mode (`str`, *optional*):
+ The mode to use for the PIL image. If unset, will use the default mode for the input image type.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+
+ Returns:
+ `PIL.Image.Image`: The converted image.
+ """
+ requires_backends(to_pil_image, ["vision"])
+
+ if isinstance(image, PIL.Image.Image):
+ return image
+
+ # Convert all tensors to numpy arrays before converting to PIL image
+ if is_torch_tensor(image) or is_tf_tensor(image):
+ image = image.numpy()
+ elif is_jax_tensor(image):
+ image = np.array(image)
+ elif not isinstance(image, np.ndarray):
+ raise ValueError("Input image type not supported: {}".format(type(image)))
+
+ # If the channel has been moved to first dim, we put it back at the end.
+ image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
+
+ # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
+ image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
+
+ # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
+ do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale
+
+ if do_rescale:
+ image = rescale(image, 255)
+
+ image = image.astype(np.uint8)
+ return PIL.Image.fromarray(image, mode=image_mode)
+
+
+# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
+def get_resize_output_image_size(
+ input_image: np.ndarray,
+ size: Union[int, Tuple[int, int], List[int], Tuple[int]],
+ default_to_square: bool = True,
+ max_size: Optional[int] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> tuple:
+ """
+ Find the target (height, width) dimension of the output image after resizing given the input image and the desired
+ size.
+
+ Args:
+ input_image (`np.ndarray`):
+ The image to resize.
+ size (`int` or `Tuple[int, int]` or List[int] or `Tuple[int]`):
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
+ this.
+
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
+ `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
+ number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
+ default_to_square (`bool`, *optional*, defaults to `True`):
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
+ (`size`,`size`). If set to `False`, will replicate
+ [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
+ with support for resizing only the smallest edge and providing an optional `max_size`.
+ max_size (`int`, *optional*):
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
+ than `max_size` after being resized according to `size`, then the image is resized again so that the longer
+ edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
+ than `size`. Only used if `default_to_square` is `False`.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+
+ Returns:
+ `tuple`: The target (height, width) dimension of the output image after resizing.
+ """
+ if isinstance(size, (tuple, list)):
+ if len(size) == 2:
+ return tuple(size)
+ elif len(size) == 1:
+ # Perform same logic as if size was an int
+ size = size[0]
+ else:
+ raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
+
+ if default_to_square:
+ return (size, size)
+
+ height, width = get_image_size(input_image, input_data_format)
+ short, long = (width, height) if width <= height else (height, width)
+ requested_new_short = size
+
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+ if max_size is not None:
+ if max_size <= requested_new_short:
+ raise ValueError(
+ f"max_size = {max_size} must be strictly greater than the requested "
+ f"size for the smaller edge size = {size}"
+ )
+ if new_long > max_size:
+ new_short, new_long = int(max_size * new_short / new_long), max_size
+
+ return (new_long, new_short) if width <= height else (new_short, new_long)
+
+
+def resize(
+ image: np.ndarray,
+ size: Tuple[int, int],
+ resample: "PILImageResampling" = None,
+ reducing_gap: Optional[int] = None,
+ data_format: Optional[ChannelDimension] = None,
+ return_numpy: bool = True,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+ """
+ Resizes `image` to `(height, width)` specified by `size` using the PIL library.
+
+ Args:
+ image (`np.ndarray`):
+ The image to resize.
+ size (`Tuple[int, int]`):
+ The size to use for resizing the image.
+ resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ The filter to user for resampling.
+ reducing_gap (`int`, *optional*):
+ Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
+ the fair resampling. See corresponding Pillow documentation for more details.
+ data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the output image. If unset, will use the inferred format from the input.
+ return_numpy (`bool`, *optional*, defaults to `True`):
+ Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
+ returned.
+ input_data_format (`ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ requires_backends(resize, ["vision"])
+
+ resample = resample if resample is not None else PILImageResampling.BILINEAR
+
+ if not len(size) == 2:
+ raise ValueError("size must have 2 elements")
+
+ # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+ # The resized image from PIL will always have channels last, so find the input format first.
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+ data_format = input_data_format if data_format is None else data_format
+
+ # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
+ # the pillow library to resize the image and then convert back to numpy
+ do_rescale = False
+ if not isinstance(image, PIL.Image.Image):
+ do_rescale = _rescale_for_pil_conversion(image)
+ image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
+ height, width = size
+ # PIL images are in the format (width, height)
+ resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
+
+ if return_numpy:
+ resized_image = np.array(resized_image)
+ # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
+ # so we need to add it back if necessary.
+ resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
+ # The image is always in channels last format after converting from a PIL image
+ resized_image = to_channel_dimension_format(
+ resized_image, data_format, input_channel_dim=ChannelDimension.LAST
+ )
+ # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
+ # rescale it back to the original range.
+ resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
+ return resized_image
+
+
def normalize(
    image: np.ndarray,
    mean: Union[float, Iterable[float]],
    std: Union[float, Iterable[float]],
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.

    image = (image - mean) / std

    Args:
        image (`np.ndarray`):
            The image to normalize.
        mean (`float` or `Iterable[float]`):
            The mean to use for normalization. A scalar is broadcast to every channel; an iterable must have
            exactly one value per channel.
        std (`float` or `Iterable[float]`):
            The standard deviation to use for normalization. Same broadcasting rules as `mean`.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The normalized image.

    Raises:
        ValueError: If `image` is not a numpy array, or if an iterable `mean`/`std` length does not match the
            number of channels.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError("image must be a numpy array")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = image.shape[channel_axis]

    # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
    # We preserve the original dtype if it is a float type to prevent upcasting float16.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    if isinstance(mean, Iterable):
        if len(mean) != num_channels:
            raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
    else:
        mean = [mean] * num_channels
    # Matching the image dtype keeps float16 images in float16 after the arithmetic.
    mean = np.array(mean, dtype=image.dtype)

    if isinstance(std, Iterable):
        if len(std) != num_channels:
            raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
    else:
        std = [std] * num_channels
    std = np.array(std, dtype=image.dtype)

    if input_data_format == ChannelDimension.LAST:
        # Channels-last: the (num_channels,) mean/std broadcast directly against the trailing axis.
        image = (image - mean) / std
    else:
        # Channels-first: `.T` reverses all axes so channels become the trailing axis, letting
        # mean/std broadcast; the second `.T` restores the original layout.
        image = ((image.T - mean) / std).T

    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
+
+
def center_crop(
    image: np.ndarray,
    size: Tuple[int, int],
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    return_numpy: Optional[bool] = None,
) -> np.ndarray:
    """
    Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
    the size given, it will be padded (so the returned result will always be of size `size`).

    Args:
        image (`np.ndarray`):
            The image to crop.
        size (`Tuple[int, int]`):
            The target size for the cropped image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        return_numpy (`bool`, *optional*):
            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
            previous ImageFeatureExtractionMixin method.
            - Unset: will return the same type as the input image.
            - `True`: will return a numpy array.
            - `False`: will return a `PIL.Image.Image` object.
    Returns:
        `np.ndarray`: The cropped image.

    Raises:
        TypeError: If `image` is not a numpy array.
        ValueError: If `size` does not have exactly two elements.
    """
    requires_backends(center_crop, ["vision"])

    # Passing `return_numpy` at all (even `True`) is deprecated; it only exists for backwards
    # compatibility with the old ImageFeatureExtractionMixin method.
    if return_numpy is not None:
        warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)

    return_numpy = True if return_numpy is None else return_numpy

    if not isinstance(image, np.ndarray):
        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")

    if not isinstance(size, Iterable) or len(size) != 2:
        raise ValueError("size must have 2 elements representing the height and width of the output image")

    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    output_data_format = data_format if data_format is not None else input_data_format

    # We perform the crop in (C, H, W) format and then convert to the output format
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)

    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
    crop_height, crop_width = size
    crop_height, crop_width = int(crop_height), int(crop_width)

    # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
    top = (orig_height - crop_height) // 2
    bottom = top + crop_height
    # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
    left = (orig_width - crop_width) // 2
    right = left + crop_width

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        image = image[..., top:bottom, left:right]
        image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
        # NOTE(review): this fast path ignores `return_numpy`, so a `return_numpy=False` caller
        # still receives an ndarray when the crop fits — confirm this is intended.
        return image

    # Otherwise, we may need to pad if the image is too small. Oh joy...
    new_height = max(crop_height, orig_height)
    new_width = max(crop_width, orig_width)
    new_shape = image.shape[:-2] + (new_height, new_width)
    new_image = np.zeros_like(image, shape=new_shape)

    # If the image is too small, pad it with zeros
    top_pad = ceil((new_height - orig_height) / 2)
    bottom_pad = top_pad + orig_height
    left_pad = ceil((new_width - orig_width) / 2)
    right_pad = left_pad + orig_width
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # Shift the crop window into the padded image's coordinate system before slicing.
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
    new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)

    if not return_numpy:
        new_image = to_pil_image(new_image)

    return new_image
+
+
+def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
+ center_x, center_y, width, height = bboxes_center.unbind(-1)
+ bbox_corners = torch.stack(
+ # top left x, top left y, bottom right x, bottom right y
+ [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
+ dim=-1,
+ )
+ return bbox_corners
+
+
+def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
+ center_x, center_y, width, height = bboxes_center.T
+ bboxes_corners = np.stack(
+ # top left x, top left y, bottom right x, bottom right y
+ [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
+ axis=-1,
+ )
+ return bboxes_corners
+
+
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
    """Convert (center_x, center_y, width, height) boxes to (x0, y0, x1, y1) corners (TensorFlow)."""
    cx, cy, w, h = tf.unstack(bboxes_center, axis=-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    # Order: top-left x, top-left y, bottom-right x, bottom-right y.
    return tf.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=-1)
+
+
+# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
+def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
+ """
+ Converts bounding boxes from center format to corners format.
+
+ center format: contains the coordinate for the center of the box and its width, height dimensions
+ (center_x, center_y, width, height)
+ corners format: contains the coodinates for the top-left and bottom-right corners of the box
+ (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
+ """
+ # Function is used during model forward pass, so we use the input framework if possible, without
+ # converting to numpy
+ if is_torch_tensor(bboxes_center):
+ return _center_to_corners_format_torch(bboxes_center)
+ elif isinstance(bboxes_center, np.ndarray):
+ return _center_to_corners_format_numpy(bboxes_center)
+ elif is_tf_tensor(bboxes_center):
+ return _center_to_corners_format_tf(bboxes_center)
+
+ raise ValueError(f"Unsupported input type {type(bboxes_center)}")
+
+
+def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
+ b = [
+ (top_left_x + bottom_right_x) / 2, # center x
+ (top_left_y + bottom_right_y) / 2, # center y
+ (bottom_right_x - top_left_x), # width
+ (bottom_right_y - top_left_y), # height
+ ]
+ return torch.stack(b, dim=-1)
+
+
+def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
+ bboxes_center = np.stack(
+ [
+ (top_left_x + bottom_right_x) / 2, # center x
+ (top_left_y + bottom_right_y) / 2, # center y
+ (bottom_right_x - top_left_x), # width
+ (bottom_right_y - top_left_y), # height
+ ],
+ axis=-1,
+ )
+ return bboxes_center
+
+
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
    """Convert (x0, y0, x1, y1) corner boxes to (center_x, center_y, width, height) format (TensorFlow)."""
    x0, y0, x1, y1 = tf.unstack(bboxes_corners, axis=-1)
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    return tf.stack([center_x, center_y, x1 - x0, y1 - y0], axis=-1)
+
+
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Converts bounding boxes from corners format to center format.

    corners format: (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
    center format: (center_x, center_y, width, height) — box center plus its dimensions.

    Inverse of `center_to_corners_format`; dispatches on the input framework so tensors stay in
    their native framework.
    """
    if is_torch_tensor(bboxes_corners):
        return _corners_to_center_format_torch(bboxes_corners)
    if isinstance(bboxes_corners, np.ndarray):
        return _corners_to_center_format_numpy(bboxes_corners)
    if is_tf_tensor(bboxes_corners):
        return _corners_to_center_format_tf(bboxes_corners)
    raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
+
+
+# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
+# Copyright (c) 2018, Alexander Kirillov
+# All rights reserved.
def rgb_to_id(color):
    """
    Converts an RGB color (sequence of 3 values, or an H x W x 3 array) to a unique integer ID:
    id = R + 256 * G + 256**2 * B.
    """
    if isinstance(color, np.ndarray) and color.ndim == 3:
        if color.dtype == np.uint8:
            # Widen before multiplying so the 256-scaling does not overflow uint8.
            color = color.astype(np.int32)
        r, g, b = color[:, :, 0], color[:, :, 1], color[:, :, 2]
        return r + 256 * g + 256 * 256 * b
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
+
+
def id_to_rgb(id_map):
    """
    Converts a unique integer ID (or array of IDs) back to RGB colors — the inverse of `rgb_to_id`.
    Arrays come back as uint8 with a trailing channel axis of 3; scalars come back as a 3-element list.
    """
    if isinstance(id_map, np.ndarray):
        remainder = id_map.copy()
        rgb_map = np.zeros((*id_map.shape, 3), dtype=np.uint8)
        for channel in range(3):
            rgb_map[..., channel] = remainder % 256
            remainder //= 256
        return rgb_map
    remainder = id_map
    color = []
    for _ in range(3):
        color.append(remainder % 256)
        remainder //= 256
    return color
+
+
class PaddingMode(ExplicitEnum):
    """
    Enum class for the different padding modes to use when padding images.
    """

    # Pad with a constant value (see `constant_values` in `pad`); maps to np.pad mode "constant".
    CONSTANT = "constant"
    # Pad with the reflection of the image mirrored on the first/last values; np.pad mode "reflect".
    REFLECT = "reflect"
    # Pad by replicating the edge value; maps to np.pad mode "edge".
    REPLICATE = "replicate"
    # Pad with the reflection mirrored along the edge of the array; np.pad mode "symmetric".
    SYMMETRIC = "symmetric"
+
+
def pad(
    image: np.ndarray,
    padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: Union[float, Iterable[float]] = 0.0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads the `image` with the specified (height, width) `padding` and `mode`.

    Args:
        image (`np.ndarray`):
            The image to pad.
        padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of three formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
            - `"constant"`: pads with a constant value.
            - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
              vector along each axis.
            - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
            - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.

    Raises:
        ValueError: If `padding`/`constant_values` have an unsupported format, or `mode` is invalid.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            values = ((values[0], values[0]), (values[0], values[0]))
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], (int, float)):
            # Accept float pairs too (e.g. `constant_values=(0.0, 1.0)`) — previously only ints
            # were recognized here, so float pairs raised "Unsupported format".
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            values = values
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add 0 for channel dimension
        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))

        # Add a no-op (before, after) pair for the batch dimension if there is one. A bare `0`
        # here (the previous behavior) produces a ragged tuple that np.pad cannot broadcast,
        # so 4D inputs would fail.
        values = ((0, 0), *values) if image.ndim == 4 else values
        return values

    padding = _expand_for_data_format(padding)

    if mode == PaddingMode.CONSTANT:
        constant_values = _expand_for_data_format(constant_values)
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    elif mode == PaddingMode.REPLICATE:
        image = np.pad(image, padding, mode="edge")
    elif mode == PaddingMode.SYMMETRIC:
        image = np.pad(image, padding, mode="symmetric")
    else:
        raise ValueError(f"Invalid padding mode: {mode}")

    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
+
+
+# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
+def convert_to_rgb(image: ImageInput) -> ImageInput:
+ """
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
+ as is.
+ Args:
+ image (Image):
+ The image to convert.
+ """
+ requires_backends(convert_to_rgb, ["vision"])
+
+ if not isinstance(image, PIL.Image.Image):
+ return image
+
+ if image.mode == "RGB":
+ return image
+
+ image = image.convert("RGB")
+ return image
+
+
def flip_channel_order(
    image: np.ndarray,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Flips the channel order of the image.

    If the image is in RGB format, it will be converted to BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
            - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
            - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Raises:
        ValueError: If the channel dimension format is not recognized.
    """
    input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format

    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]
    elif input_data_format == ChannelDimension.FIRST:
        # Flip the channel axis (third from last). The previous `image[::-1, ...]` flipped axis 0,
        # which is the *batch* axis for 4D (B, C, H, W) inputs; this form handles both 3D and 4D.
        image = image[..., ::-1, :, :]
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    if data_format is not None:
        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
    return image
+
+
+def _cast_tensor_to_float(x):
+ if x.is_floating_point():
+ return x
+ return x.float()
+
+
class FusedRescaleNormalize:
    """
    Rescale and normalize the input image in one step.

    The rescale factor is folded into the mean and std, so a single `F.normalize` call performs
    both the rescale and the normalization.
    """

    def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
        inverse_factor = 1.0 / rescale_factor
        self.mean = torch.tensor(mean) * inverse_factor
        self.std = torch.tensor(std) * inverse_factor
        self.inplace = inplace

    def __call__(self, image: "torch.Tensor"):
        # Integer inputs are promoted to float32 before normalizing.
        image = _cast_tensor_to_float(image)
        return F.normalize(image, self.mean, self.std, inplace=self.inplace)
+
+
class Rescale:
    """
    Rescale the input image by rescale factor: image *= rescale_factor.
    """

    def __init__(self, rescale_factor: float = 1.0):
        self.rescale_factor = rescale_factor

    def __call__(self, image: "torch.Tensor"):
        # Non-inplace multiply; returns a new tensor.
        return image * self.rescale_factor
+
+
class NumpyToTensor:
    """
    Convert a numpy array to a PyTorch tensor.

    Incoming numpy images are assumed to be in HWC layout (the torchvision convention) and are
    returned as contiguous CHW tensors.
    """

    def __call__(self, image: np.ndarray):
        # HWC -> CHW, matching torchvision's to_tensor layout convention.
        # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
        chw = image.transpose(2, 0, 1)
        return torch.from_numpy(chw).contiguous()
diff --git a/.venv/lib/python3.11/site-packages/transformers/image_utils.py b/.venv/lib/python3.11/site-packages/transformers/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51199d9f3698fc6212b5f8b3c90144fbf147ad41
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/image_utils.py
@@ -0,0 +1,871 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import os
+from io import BytesIO
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import requests
+from packaging import version
+
+from .utils import (
+ ExplicitEnum,
+ TensorType,
+ is_jax_tensor,
+ is_numpy_array,
+ is_tf_tensor,
+ is_torch_available,
+ is_torch_tensor,
+ is_torchvision_available,
+ is_vision_available,
+ logging,
+ requires_backends,
+ to_numpy,
+)
+from .utils.constants import ( # noqa: F401
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+)
+
+
+if is_vision_available():
+ import PIL.Image
+ import PIL.ImageOps
+
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PILImageResampling = PIL.Image.Resampling
+ else:
+ PILImageResampling = PIL.Image
+
+ if is_torchvision_available():
+ from torchvision.transforms import InterpolationMode
+
+ pil_torch_interpolation_mapping = {
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
+ PILImageResampling.BOX: InterpolationMode.BOX,
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
+ }
+
+
+if TYPE_CHECKING:
+ if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+ImageInput = Union[
+ "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
+] # noqa
+
+
VideoInput = Union[
    List["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    List["np.ndarray"],
    List["torch.Tensor"],
    List[List["PIL.Image.Image"]],
    List[List["np.ndarray"]],  # fixed typo: was "np.ndarrray", an unresolvable forward reference
    List[List["torch.Tensor"]],
]  # noqa
+
+
class ChannelDimension(ExplicitEnum):
    """Supported channel-dimension layouts for image arrays."""

    # (num_channels, height, width) — channels before the spatial axes.
    FIRST = "channels_first"
    # (height, width, num_channels) — channels after the spatial axes.
    LAST = "channels_last"
+
class AnnotationFormat(ExplicitEnum):
    """Supported annotation formats for object-detection / segmentation inputs."""

    # COCO detection annotations: dict with "image_id" and a list of "annotations".
    COCO_DETECTION = "coco_detection"
    # COCO panoptic annotations: dict with "image_id", "segments_info" and "file_name".
    COCO_PANOPTIC = "coco_panoptic"
+
+
class AnnotionFormat(ExplicitEnum):
    """
    Misspelled variant of `AnnotationFormat` with identical values — presumably kept as a
    deprecated alias for backward compatibility (NOTE(review): confirm before removing).
    """

    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
+
+
+AnnotationType = Dict[str, Union[int, str, List[Dict]]]
+
+
def is_pil_image(img):
    """True if `img` is a `PIL.Image.Image` (requires the vision backend to be available)."""
    if not is_vision_available():
        return False
    return isinstance(img, PIL.Image.Image)
+
+
class ImageType(ExplicitEnum):
    """Framework/library an image object originates from (see `get_image_type`)."""

    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
+
+
def get_image_type(image):
    """Return the `ImageType` describing what framework `image` comes from, raising for unknown types."""
    type_checks = (
        (is_pil_image, ImageType.PIL),
        (is_torch_tensor, ImageType.TORCH),
        (is_numpy_array, ImageType.NUMPY),
        (is_tf_tensor, ImageType.TENSORFLOW),
        (is_jax_tensor, ImageType.JAX),
    )
    # Check in the same order as before: PIL first, then torch, numpy, tf, jax.
    for predicate, image_type in type_checks:
        if predicate(image):
            return image_type
    raise ValueError(f"Unrecognised image type {type(image)}")
+
+
def is_valid_image(img):
    """True if `img` is any supported single-image type (PIL, numpy, torch, tensorflow or jax)."""
    checks = (is_pil_image, is_numpy_array, is_torch_tensor, is_tf_tensor, is_jax_tensor)
    return any(check(img) for check in checks)
+
+
def valid_images(imgs):
    """
    Recursively validate a single image, a batched tensor, or an arbitrarily nested list/tuple of
    images. An empty list/tuple is considered valid.
    """
    if isinstance(imgs, (list, tuple)):
        return all(valid_images(img) for img in imgs)
    return is_valid_image(imgs)
+
+
def is_batched(img):
    """
    True if `img` is a non-empty list/tuple whose first element is a valid image, i.e. a batch.
    """
    if isinstance(img, (list, tuple)):
        # Guard against empty sequences: `img[0]` previously raised IndexError on `[]`.
        return len(img) > 0 and is_valid_image(img[0])
    return False
+
+
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].

    uint8 images are by definition unscaled; floating images count as scaled when all values lie
    in [0, 1].
    """
    if image.dtype == np.uint8:
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    lo = np.min(image)
    hi = np.max(image)
    return lo >= 0 and hi <= 1
+
+
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
    If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image of images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.
    """
    # Already a list/tuple of images: return untouched.
    if is_batched(images):
        return images

    # PIL images are never batched
    if isinstance(images, PIL.Image.Image):
        return [images]

    if not is_valid_image(images):
        raise ValueError(
            "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
            f"jax.ndarray, but got {type(images)}."
        )

    if images.ndim == expected_ndims + 1:
        # Batched tensor/array: split along the leading (batch) axis.
        return list(images)
    if images.ndim == expected_ndims:
        # Single image: wrap in a list.
        return [images]
    raise ValueError(
        f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
        f" {images.ndim} dimensions."
    )
+
+
def to_numpy_array(img) -> np.ndarray:
    """Convert any supported image type (PIL, numpy, torch, tf, jax) to a numpy array."""
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    # PIL images go through np.array; everything else through the generic tensor converter.
    if is_pil_image(img):
        return np.array(img)
    return to_numpy(img)
+
+
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of. Must be 3D (single image) or 4D (batched).
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The channel counts that are considered valid.

    Returns:
        The channel dimension of the image.

    Raises:
        ValueError: If the image is not 3D/4D or no axis matches a valid channel count.
    """
    valid_channels = (1, 3) if num_channels is None else num_channels
    if isinstance(valid_channels, int):
        valid_channels = (valid_channels,)

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    first_matches = image.shape[first_dim] in valid_channels
    last_matches = image.shape[last_dim] in valid_channels

    if first_matches and last_matches:
        # Both axes could be the channel axis; fall back to channels-first, with a warning.
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
        )
        return ChannelDimension.FIRST
    if first_matches:
        return ChannelDimension.FIRST
    if last_matches:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
+
+
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the axis index of the channel dimension of `image`.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.
    """
    fmt = input_data_format if input_data_format is not None else infer_channel_dimension_format(image)
    if fmt == ChannelDimension.FIRST:
        # Works for both (C, H, W) and (B, C, H, W): channels are always third from the end.
        return image.ndim - 3
    if fmt == ChannelDimension.LAST:
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {fmt}")
+
+
def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
            (Annotation fixed to `Optional[...]` — the default has always been `None`.)

    Returns:
        A tuple of the image's height and width.

    Raises:
        ValueError: If `channel_dim` is not a recognised `ChannelDimension`.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    if channel_dim == ChannelDimension.FIRST:
        # (..., C, H, W)
        return image.shape[-2], image.shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        # (..., H, W, C)
        return image.shape[-3], image.shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
+
+
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    True if `annotation` looks like a COCO-detection annotation: a dict with an `image_id` key and
    an `annotations` list/tuple whose entries (if any) are dicts.
    """
    if not isinstance(annotation, dict):
        return False
    if "image_id" not in annotation or "annotations" not in annotation:
        return False
    annotations = annotation["annotations"]
    if not isinstance(annotations, (list, tuple)):
        return False
    # An image can legitimately have zero annotations; only the first entry's type is checked.
    return len(annotations) == 0 or isinstance(annotations[0], dict)
+
+
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    True if `annotation` looks like a COCO-panoptic annotation: a dict with `image_id` and
    `file_name` keys and a `segments_info` list/tuple whose entries (if any) are dicts.
    """
    if not isinstance(annotation, dict):
        return False
    if "image_id" not in annotation or "segments_info" not in annotation or "file_name" not in annotation:
        return False
    segments = annotation["segments_info"]
    if not isinstance(segments, (list, tuple)):
        return False
    # An image can legitimately have zero segments; only the first entry's type is checked.
    return len(segments) == 0 or isinstance(segments[0], dict)
+
+
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """True if every annotation in the iterable is a valid COCO-detection annotation."""
    for annotation in annotations:
        if not is_valid_annotation_coco_detection(annotation):
            return False
    return True
+
+
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """True if every annotation in the iterable is a valid COCO-panoptic annotation."""
    for annotation in annotations:
        if not is_valid_annotation_coco_panoptic(annotation):
            return False
    return True
+
+
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string may be an `http://`/`https://` URL, a local
            file path, or a base64-encoded string (optionally prefixed with `data:image/...`).
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image (EXIF-transposed and converted to RGB).

    Raises:
        ValueError: If a string input is neither a URL, an existing file, nor valid base64.
        TypeError: If `image` is neither a string nor a PIL image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            # Strip a data-URI prefix, if present, before base64-decoding.
            if image.startswith("data:image/"):
                image = image.split(",")[1]

            # Try to load as base64
            try:
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                )
    elif isinstance(image, PIL.Image.Image):
        image = image
    else:
        raise TypeError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    # Apply the EXIF orientation so pixel data matches how the image is displayed, then force RGB.
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
+
+
def load_images(
    images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
    """Loads images, handling different levels of nesting.

    Args:
        images: A single image, a list of images, or a list of lists of images to load.
        timeout: Timeout for loading images.

    Returns:
        A single image, a list of images, a list of lists of images.
    """
    # Single image: delegate straight to `load_image`.
    if not isinstance(images, (list, tuple)):
        return load_image(images, timeout=timeout)
    # Nested: list of lists of images.
    if len(images) and isinstance(images[0], (list, tuple)):
        return [[load_image(image, timeout=timeout) for image in group] for group in images]
    # Flat list of images.
    return [load_image(image, timeout=timeout) for image in images]
+
+
def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method,
    raising `ValueError` on the first incompatibility found.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.
    """
    # Each entry: (incompatibility detected, error message). Checked in a fixed order so the
    # first failing combination determines the raised error, as before.
    checks = (
        (
            do_rescale and rescale_factor is None,
            "`rescale_factor` must be specified if `do_rescale` is `True`.",
        ),
        (
            # Here, size_divisor might be passed as the value of size
            do_pad and size_divisibility is None,
            "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`.",
        ),
        (
            do_normalize and (image_mean is None or image_std is None),
            "`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.",
        ),
        (
            do_center_crop and crop_size is None,
            "`crop_size` must be specified if `do_center_crop` is `True`.",
        ),
        (
            do_resize and (size is None or resample is None),
            "`size` and `resample` must be specified if `do_resize` is `True`.",
        ),
    )
    for failed, message in checks:
        if failed:
            raise ValueError(message)
+
+
def validate_fast_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
):
    """
    Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.

    In addition to the shared checks, fast processors only support `return_tensors="pt"` and
    channels-first output.
    """
    # NOTE(review): do_pad/size_divisibility and do_center_crop/crop_size are accepted but NOT
    # forwarded to validate_preprocess_arguments below — confirm whether fast processors
    # validate these elsewhere.
    validate_preprocess_arguments(
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_resize=do_resize,
        size=size,
        resample=resample,
    )
    # Extra checks for ImageProcessorFast
    if return_tensors != "pt":
        raise ValueError("Only returning PyTorch tensors is currently supported.")

    if data_format != ChannelDimension.FIRST:
        raise ValueError("Only channel first data format is currently supported.")
+
+
+# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
    """
    Mixin that contains utilities for preparing image features.

    All helpers accept `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` inputs (validated via
    `_ensure_format_supported`) unless a method's docstring states otherwise.
    """

    def _ensure_format_supported(self, image):
        # Reject anything that is not a PIL image, a NumPy array or a torch tensor.
        if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
            raise ValueError(
                f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
                "`torch.Tensor` are."
            )

    def to_pil_image(self, image, rescale=None):
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        self._ensure_format_supported(image)

        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale default to the array being of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel has been moved to first dim, we put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            # PIL expects uint8 pixel data.
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        # Already a PIL image: returned unchanged.
        return image

    def convert_rgb(self, image):
        """
        Converts `PIL.Image.Image` to RGB format.

        Args:
            image (`PIL.Image.Image`):
                The image to convert.
        """
        self._ensure_format_supported(image)
        # Arrays/tensors are passed through untouched; only PIL images are converted.
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
        """
        Rescale a numpy image by scale amount
        """
        self._ensure_format_supported(image)
        return image * scale

    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            # PIL -> uint8 HWC array; the integer dtype triggers the rescale default below.
            image = np.array(image)

        if is_torch_tensor(image):
            image = image.numpy()

        # Default: rescale integer images to [0, 1], leave float images as-is.
        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        # HWC -> CHW when requested (only meaningful for 3D images).
        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image

    def expand_dims(self, image):
        """
        Expands 2-dimensional `image` to 3 dimensions.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand.
        """
        self._ensure_format_supported(image)

        # Do nothing if PIL image
        if isinstance(image, PIL.Image.Image):
            return image

        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            image = np.expand_dims(image, axis=0)
        return image

    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
        if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If the input image is a PIL image, it automatically gets rescaled. If it's another
        # type it may need rescaling.
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Coerce mean/std to the same backend (NumPy or torch) as the image so the
        # arithmetic below broadcasts cleanly.
        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                if isinstance(mean, np.ndarray):
                    mean = torch.from_numpy(mean)
                else:
                    mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                if isinstance(std, np.ndarray):
                    std = torch.from_numpy(std)
                else:
                    std = torch.tensor(std)

        # Channel-first (C, H, W) images need mean/std reshaped to (C, 1, 1) to
        # broadcast per channel; channel-last images broadcast directly.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std

    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
                this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to use for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`,`size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # specified size only for the smallest edge
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                # Short edge already matches: nothing to do.
                if short == requested_new_short:
                    return image

                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    # Shrink again so the long edge fits within max_size, preserving aspect ratio.
                    if new_long > max_size:
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                # PIL expects (width, height) ordering.
                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)

    def center_crop(self, image, size):
        """
        Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
        size given, it will be padded (so the returned result has the size asked).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to which crop the image.

        Returns:
            new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
            height, width).
        """
        self._ensure_format_supported(image)

        if not isinstance(size, tuple):
            size = (size, size)

        # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
        if is_torch_tensor(image) or isinstance(image, np.ndarray):
            if image.ndim == 2:
                # Add a leading channel dim so the crop logic below is uniform.
                image = self.expand_dims(image)
            image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
        else:
            image_shape = (image.size[1], image.size[0])

        top = (image_shape[0] - size[0]) // 2
        bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
        left = (image_shape[1] - size[1]) // 2
        right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

        # For PIL Images we have a method to crop directly.
        if isinstance(image, PIL.Image.Image):
            return image.crop((left, top, right, bottom))

        # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
        channel_first = True if image.shape[0] in [1, 3] else False

        # Transpose (height, width, n_channels) format images
        if not channel_first:
            if isinstance(image, np.ndarray):
                image = image.transpose(2, 0, 1)
            if is_torch_tensor(image):
                image = image.permute(2, 0, 1)

        # Check if cropped area is within image boundaries
        if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
            return image[..., top:bottom, left:right]

        # Otherwise, we may need to pad if the image is too small. Oh joy...
        # Zero-pad to a canvas large enough for both the image and the requested crop.
        new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
        if isinstance(image, np.ndarray):
            new_image = np.zeros_like(image, shape=new_shape)
        elif is_torch_tensor(image):
            new_image = image.new_zeros(new_shape)

        # Paste the original image in the center of the padded canvas.
        top_pad = (new_shape[-2] - image_shape[0]) // 2
        bottom_pad = top_pad + image_shape[0]
        left_pad = (new_shape[-1] - image_shape[1]) // 2
        right_pad = left_pad + image_shape[1]
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

        # Shift the crop window into the padded canvas coordinates.
        top += top_pad
        bottom += top_pad
        left += left_pad
        right += left_pad

        new_image = new_image[
            ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
        ]

        return new_image

    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
        `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
                be first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        # NOTE(review): the negative step below reverses the (assumed channel-first) channel
        # axis; NumPy supports this, but torch tensors do not support negative strides —
        # confirm whether torch inputs are expected to reach this line.
        return image[::-1, :, :]

    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
        counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
                rotating.

        Returns:
            image: A rotated `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PIL.Image.NEAREST

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        # Remaining keyword arguments mirror PIL.Image.Image.rotate and are forwarded verbatim.
        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )
+
+
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: Tuple[AnnotationFormat, ...],
    annotations: List[Dict],
) -> None:
    """
    Checks that `annotation_format` is supported and that `annotations` match that format.

    Args:
        annotation_format: The declared format of the annotations.
        supported_annotation_formats: Formats accepted by the calling processor.
        annotations: The annotations to validate (a dict per image, or a list of dicts).

    Raises:
        ValueError: If the format is unsupported or the annotations do not match it.
    """
    if annotation_format not in supported_annotation_formats:
        # Bug fix: the message previously interpolated the *builtin* `format`
        # (rendering as "<built-in function format>") instead of the offending value.
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
+
+
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    """Log a warning for any captured kwargs the processor does not recognize."""
    # Anything captured via **kwargs that is not a documented processor key is
    # silently ignored downstream, so surface it to the user here.
    unrecognized = set(captured_kwargs) - set(valid_processor_keys)
    if unrecognized:
        # TODO raise a warning here instead of simply logging?
        logger.warning(f"Unused or unrecognized kwargs: {', '.join(unrecognized)}.")
diff --git a/.venv/lib/python3.11/site-packages/transformers/keras_callbacks.py b/.venv/lib/python3.11/site-packages/transformers/keras_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e832729a1eeb482d1193753cc2c07ad1f16c2e
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/keras_callbacks.py
@@ -0,0 +1,413 @@
+import logging
+import os
+from pathlib import Path
+from time import sleep
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import tensorflow as tf
+from huggingface_hub import Repository, create_repo
+from packaging.version import parse
+
+from . import IntervalStrategy, PreTrainedTokenizerBase
+from .modelcard import TrainingSummary
+from .modeling_tf_utils import keras
+
+
+logger = logging.getLogger(__name__)
+
+
class KerasMetricCallback(keras.callbacks.Callback):
    """
    Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
    compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
    operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
    `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
    metrics and return a dict mapping metric names to metric values.

    We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
    this example skips some post-processing for readability and simplicity, and should probably not be used as-is!

    ```py
    from datasets import load_metric

    rouge_metric = load_metric("rouge")


    def rouge_fn(predictions, labels):
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
        return {key: value.mid.fmeasure * 100 for key, value in result.items()}
    ```

    The above function will return a dict containing values which will be logged like any other Keras metric:

    ```
    {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781
    ```

    Args:
        metric_fn (`Callable`):
            Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
            These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
            metric names to numerical values.
        eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
            Validation data to be used to generate predictions for the `metric_fn`.
        output_cols (`List[str], *optional*):
            A list of columns to be retained from the model output as the predictions. Defaults to all.
        label_cols ('`List[str]`, *optional*'):
            A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
            supplied.
        batch_size (`int`, *optional*):
            Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
        predict_with_generate (`bool`, *optional*, defaults to `False`):
            Whether we should use `model.generate()` to get outputs for the model.
        use_xla_generation (`bool`, *optional*, defaults to `False`):
            If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
            generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
            generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
            argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
            save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`.
        generate_kwargs (`dict`, *optional*):
            Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
            is `False`.

    """

    def __init__(
        self,
        metric_fn: Callable,
        eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
        output_cols: Optional[List[str]] = None,
        label_cols: Optional[List[str]] = None,
        batch_size: Optional[int] = None,
        predict_with_generate: bool = False,
        use_xla_generation: bool = False,
        generate_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        self.metric_fn = metric_fn
        self.batch_size = batch_size
        if not isinstance(eval_dataset, tf.data.Dataset):
            if batch_size is None:
                raise ValueError(
                    "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
                    "the batch_size argument must be set."
                )
            # Wrap a tf.data.Dataset around it
            eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
        self.eval_dataset = eval_dataset
        self.predict_with_generate = predict_with_generate
        self.output_cols = output_cols

        # This next block attempts to parse out which elements of the dataset should be appended to the labels list
        # that is passed to the metric_fn
        if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
            input_spec, label_spec = eval_dataset.element_spec
        else:
            input_spec = eval_dataset.element_spec
            label_spec = None
        if label_cols is not None:
            for label in label_cols:
                if label not in input_spec:
                    raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
            self.label_cols = label_cols
            self.use_keras_label = False
        elif label_spec is not None:
            # If the dataset inputs are split into a 2-tuple of inputs and labels,
            # assume the second element is the labels
            self.label_cols = None
            self.use_keras_label = True
        elif "labels" in input_spec:
            self.label_cols = ["labels"]
            self.use_keras_label = False
            # Consistency fix: use the module-level `logger` instead of the root logger.
            logger.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
        elif "start_positions" in input_spec and "end_positions" in input_spec:
            self.label_cols = ["start_positions", "end_positions"]
            self.use_keras_label = False
            logger.warning(
                "No label_cols specified for KerasMetricCallback, assuming you want the "
                "start_positions and end_positions keys."
            )
        else:
            raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
        if parse(tf.__version__) < parse("2.7"):
            logger.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")

        self.use_xla_generation = use_xla_generation
        self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs

        # Lazily-built XLA-compiled generate function (see on_epoch_end).
        self.generation_function = None

    @staticmethod
    def _concatenate_batches(batches, padding_index=-100):
        """Concatenate per-batch arrays along axis 0, right-padding to the longest sequence."""
        # If all batches are unidimensional or same length, do a simple concatenation
        if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
            return np.concatenate(batches, axis=0)

        # Welp, they're not the same length. Let's do some padding
        max_len = max([batch.shape[1] for batch in batches])
        num_samples = sum([batch.shape[0] for batch in batches])
        output = np.full_like(
            batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
        )
        # i keeps track of which part of the concatenated array we're writing the next batch to
        i = 0
        for batch in batches:
            output[i : i + len(batch), : batch.shape[1]] = batch
            i += len(batch)
        return output

    def _postprocess_predictions_or_labels(self, inputs):
        """Merge a list of per-batch outputs (dicts/lists/arrays/tensors) into whole-dataset arrays."""
        if isinstance(inputs[0], dict):
            outputs = {}
            for key in inputs[0].keys():
                outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
            # If it's a dict with only one key, just return the array
            if len(outputs) == 1:
                outputs = list(outputs.values())[0]
        elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
            outputs = []
            for input_list in zip(*inputs):
                outputs.append(self._concatenate_batches(input_list))
            if len(outputs) == 1:
                outputs = outputs[0]  # If it's a list with only one element, just return the array
        elif isinstance(inputs[0], np.ndarray):
            outputs = self._concatenate_batches(inputs)
        elif isinstance(inputs[0], tf.Tensor):
            outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
        else:
            raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
        return outputs

    def on_epoch_end(self, epoch, logs=None):
        """Run prediction/generation over `eval_dataset`, compute `metric_fn`, and write results into `logs`."""
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

        main_input_name = None
        if self.predict_with_generate:
            # This dense conditional recognizes the case where we have an encoder-decoder model, but
            # avoids getting tangled up when we just have a model with a layer called 'encoder'
            if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
                main_input_name = self.model.encoder.main_input_name
            else:
                main_input_name = getattr(self.model, "main_input_name", "input_ids")

            if self.use_xla_generation and self.generation_function is None:

                def generation_function(inputs, attention_mask):
                    return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)

                # jit_compile=True triggers XLA; recompiles per unique input shape.
                self.generation_function = tf.function(generation_function, jit_compile=True)

        prediction_list = []
        label_list = []

        # The whole predict/generate loop is handled inside this method
        for batch in self.eval_dataset:
            if isinstance(batch, tuple):
                batch, labels = batch
            else:
                labels = None
            if self.predict_with_generate:
                if isinstance(batch, dict):
                    generation_inputs = batch[main_input_name]
                    attention_mask = batch.get("attention_mask", None)
                else:
                    generation_inputs = batch
                    attention_mask = None
                if self.use_xla_generation:
                    predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
                else:
                    predictions = self.model.generate(
                        generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
                    )
            else:
                predictions = self.model.predict_on_batch(batch)
            if isinstance(predictions, dict):
                # This converts any dict-subclass to a regular dict
                # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
                predictions = dict(predictions)
                if self.output_cols is not None:
                    predictions = {key: predictions[key] for key in self.output_cols}
                else:
                    predictions = {
                        key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
                    }
            prediction_list.append(predictions)
            if not self.use_keras_label:
                labels = {key: batch[key].numpy() for key in self.label_cols}
            elif isinstance(labels, dict):
                labels = {key: array.numpy() for key, array in labels.items()}
            elif isinstance(labels, list) or isinstance(labels, tuple):
                labels = [array.numpy() for array in labels]
            elif isinstance(labels, tf.Tensor):
                labels = labels.numpy()
            else:
                raise TypeError(f"Confused by labels of type {type(labels)}")
            label_list.append(labels)

        all_preds = self._postprocess_predictions_or_labels(prediction_list)
        all_labels = self._postprocess_predictions_or_labels(label_list)

        metric_output = self.metric_fn((all_preds, all_labels))
        if not isinstance(metric_output, dict):
            raise TypeError(
                f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
            )
        # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
        # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
        # new keys in there, which will then get read by the History callback and treated like any other metric value.
        # I promise that I have it in writing from Chollet that this is okay.
        logs.update(metric_output)
+
+
class PushToHubCallback(keras.callbacks.Callback):
    """
    Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
    be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
    as with the `from_pretrained` method.

    ```py
    from transformers.keras_callbacks import PushToHubCallback

    push_to_hub_callback = PushToHubCallback(
        output_dir="./model_save",
        tokenizer=tokenizer,
        hub_model_id="gpt5-7xlarge",
    )

    model.fit(train_dataset, callbacks=[push_to_hub_callback])
    ```

    Args:
        output_dir (`str`):
            The output directory where the model predictions and checkpoints will be written and synced with the
            repository on the Hub.
        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
            The checkpoint save strategy to adopt during training. Possible values are:

            - `"no"`: Save is done at the end of training.
            - `"epoch"`: Save is done at the end of each epoch.
            - `"steps"`: Save is done every `save_steps`
        save_steps (`int`, *optional*):
            The number of steps between saves when using the "steps" `save_strategy`.
        tokenizer (`PreTrainedTokenizerBase`, *optional*):
            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
        hub_model_id (`str`, *optional*):
            The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
            which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
            for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
            `"organization_name/model"`.

            Will default to the name of `output_dir`.
        hub_token (`str`, *optional*):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            `huggingface-cli login`.
        checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
            resumed. Only usable when `save_strategy` is `"epoch"`.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
        checkpoint: bool = False,
        **model_card_args,
    ):
        super().__init__()
        if checkpoint and save_strategy != "epoch":
            raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps
        output_dir = Path(output_dir)

        # Create repo and retrieve repo_id
        if hub_model_id is None:
            hub_model_id = output_dir.absolute().name
        self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id

        self.output_dir = output_dir
        self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)

        self.tokenizer = tokenizer
        # Handle of the most recent non-blocking push, so we can avoid overlapping uploads.
        self.last_job = None
        self.checkpoint = checkpoint
        self.training_history = None
        self.model_card_args = model_card_args

    def on_train_begin(self, logs=None):
        # Although we can access model.history, we have no guarantees that the History callback will fire before this
        # one, so we keep track of it here too
        self.training_history = []

    def on_train_batch_end(self, batch, logs=None):
        """Save and push asynchronously every `save_steps` batches when using the 'steps' strategy."""
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress steps {batch}", blocking=False
            )

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch's logs and, when using the 'epoch' strategy, save and push asynchronously."""
        logs = logs.copy()  # Don't accidentally write things that Keras will read later
        if "epoch" not in logs:
            logs["epoch"] = epoch
        self.training_history.append(logs)
        if self.save_strategy == IntervalStrategy.EPOCH:
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            if self.checkpoint:
                checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
                self.model._save_checkpoint(checkpoint_dir, epoch)
            train_summary = TrainingSummary.from_keras(
                model=self.model,
                model_name=self.hub_model_id,
                keras_history=self.training_history,
                **self.model_card_args,
            )
            model_card = train_summary.to_model_card()
            with (self.output_dir / "README.md").open("w") as f:
                f.write(model_card)
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )

    def on_train_end(self, logs=None):
        # Makes sure the latest version of the model is uploaded
        if self.last_job is not None and not self.last_job.is_done:
            # Consistency fix: use the module-level `logger` instead of the root logger.
            logger.info("Pushing the last epoch to the Hub, this may take a while...")
            while not self.last_job.is_done:
                sleep(1)
        else:
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)
            train_summary = TrainingSummary.from_keras(
                model=self.model,
                model_name=self.hub_model_id,
                keras_history=self.training_history,
                **self.model_card_args,
            )
            model_card = train_summary.to_model_card()
            with (self.output_dir / "README.md").open("w") as f:
                f.write(model_card)
            self.repo.push_to_hub(commit_message="End of training", blocking=True)
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/__init__.py b/.venv/lib/python3.11/site-packages/transformers/kernels/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh b/.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c0db0c88c9db2c09d7f601937ea0f6ac480913bf
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024; // default threads per block for 1-D launches
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads; // ceil-division: blocks needed to cover N items
+}
+
+
+template <typename scalar_t> // bilinearly sample one (head m, channel c) value at fractional (h, w); out-of-range corners contribute 0
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h); // top-left corner of the 2x2 interpolation neighborhood
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low; // fractional offsets inside the cell
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels; // layout: [height][width][head][channel]
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; // bilinear weights (NW, NE, SW, SE)
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template <typename scalar_t> // backward of bilinear sampling: atomically accumulates grad_value; writes loc/attn grads directly (caller guarantees exclusive output slots)
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0; // d(val)/dh and d(val)/dw accumulated corner by corner
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value; // chain rule through the [0,1] -> pixel rescale
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t> // same as ms_deform_attn_col2im_bilinear, but loc/attn grads are also accumulated with atomicAdd (global-memory reduction path)
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t> // forward kernel: one thread per output element (batch, query, head, channel)
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index; // decompose flat index into (b, q, head, channel)
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5; // normalized [0,1] loc -> texel-centered pixel coords
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template <typename scalar_t, unsigned int blockSize> // backward: compile-time blockSize sizes the static shared arrays; thread 0 serially reduces the block's partials
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n) // NOTE(review): __syncthreads() below assumes every thread of a block iterates this loop in lockstep
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid) // NOTE(review): inner tid shadows the outer tid
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, unsigned int blockSize> // backward: tree reduction in shared memory; requires blockSize to be a power of two
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1) // standard halving tree reduction
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t> // backward: dynamic shared memory (3 * blockDim.x * sizeof(scalar_t)); thread 0 serially reduces
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid) // NOTE(review): inner tid shadows the outer tid
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t> // backward: dynamic shared memory, tree reduction that also handles non-power-of-two blockDim.x via the spre tail-merge
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre) // fold the odd leftover element when spre is odd
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t> // like shm_reduce_v2, but final outputs use atomicAdd: several blocks may write the same sampling point (channels > blockDim.x)
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre) // fold the odd leftover element when spre is odd
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t> // backward fallback: no shared memory, every output accumulated via global atomicAdd (slowest but size-agnostic)
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t> // host-side launcher for the forward im2col kernel on the given stream
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError(); // catches launch-configuration errors; execution errors surface at the next sync
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h b/.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..119b1fa317d1e5fcfb61a4837e560e9248db05f3
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deformable_detr/ms_deform_attn.h
@@ -0,0 +1,61 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..388a73d22d4c9b561e2a887b50a1897b8cf2def9
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,40 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+
+#include
+#include
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..7eac8c8bcd1bf529bb9c13d54d2d4215c9e4c89f
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,32 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8ea1d7fabe2684dbb85f00fae2c47b469687cb2c
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,156 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include
+#include
+#include
+#include
+
+#pragma once
+#include
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..34f8ae9cb77bbaa8cb4dd25e0cb86632db9ad05d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
@@ -0,0 +1,1467 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
// Forward im2col kernel for multi-scale deformable attention.
// One thread produces one element of data_col; total work is
// n = batch_size * num_query * num_heads * channels.  Launched 1-D; the
// grid-stride CUDA_KERNEL_LOOP makes the kernel correct for any grid size.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat output index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;  // sampling locs are (x, y) pairs
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    // Accumulate attention-weighted bilinear samples over all levels and points.
    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Normalized [0, 1] location -> continuous pixel coordinates
        // (the -0.5 aligns with align_corners=False sampling).
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Skip samples that fall completely outside the feature map.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}
+
// Backward (col2im) kernel, specialized for a compile-time block size that
// equals `channels` (small powers of two <= 32).  Each thread handles one
// channel of one (batch, query, head) sample; per-point gradients w.r.t. the
// sampling location and attention weight are staged in statically sized
// shared memory and reduced serially by thread 0.  grad_value is updated
// with atomics inside ms_deform_attn_col2im_bilinear.
// blockSize must be a template parameter: it sizes the __shared__ arrays.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before the conditional write.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (threadIdx.x << 1), cache_grad_attn_weight + threadIdx.x);
        }

        // Barrier is outside the divergent branch: all threads reach it.
        __syncthreads();
        if (tid == 0)
        {
          // Serial reduction over the block's per-channel partial gradients.
          scalar_t _grad_w = cache_grad_sampling_loc[0], _grad_h = cache_grad_sampling_loc[1], _grad_a = cache_grad_attn_weight[0];
          int sid = 2;
          for (unsigned int t = 1; t < blockSize; ++t)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[t];
            sid += 2;
          }

          // One block owns this (query, head, level, point): plain stores suffice.
          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
+
+
// Backward (col2im) kernel, specialized for a compile-time block size that
// equals `channels` (powers of two, 64..1024).  Identical data flow to the
// _v1 variant, but the shared-memory partials are combined with a parallel
// tree reduction (valid because blockSize is a power of two).
// blockSize must be a template parameter: it sizes the __shared__ arrays.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before the conditional write.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (threadIdx.x << 1), cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();

        // Power-of-two tree reduction; barrier sits outside the `tid < s`
        // branch so every thread participates in each round.
        for (unsigned int s = blockSize / 2; s > 0; s >>= 1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // One block owns this (query, head, level, point): plain stores suffice.
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
+
+
// Backward (col2im) kernel for arbitrary channels < 64 (runtime blockDim).
// Same data flow as the blocksize-aware _v1 variant, but the shared-memory
// staging buffers come from dynamic shared memory: the launcher must pass
// blockDim.x * 3 * sizeof(scalar_t) bytes (2 scalars for the sampling-location
// gradient + 1 for the attention-weight gradient, per thread).
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before the conditional write.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (threadIdx.x << 1), cache_grad_attn_weight + threadIdx.x);
        }

        // Barrier is outside the divergent branch: all threads reach it.
        __syncthreads();
        if (tid == 0)
        {
          // Serial reduction over the block's per-channel partial gradients.
          scalar_t _grad_w = cache_grad_sampling_loc[0], _grad_h = cache_grad_sampling_loc[1], _grad_a = cache_grad_attn_weight[0];
          int sid = 2;
          for (unsigned int t = 1; t < blockDim.x; ++t)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[t];
            sid += 2;
          }

          // One block owns this (query, head, level, point): plain stores suffice.
          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
+
// Backward (col2im) kernel for arbitrary channels in (64, 1024] (runtime
// blockDim).  Dynamic shared memory as in _shm_reduce_v1 (launcher passes
// blockDim.x * 3 * sizeof(scalar_t) bytes); the partials are combined with a
// tree reduction whose extra `spre` step folds in the leftover element when
// the active width is odd, so non-power-of-two block sizes reduce correctly.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before the conditional write.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (threadIdx.x << 1), cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();

        // Tree reduction tolerant of non-power-of-two widths: `spre` tracks
        // the previous active width so the odd tail element is folded in.
        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // One block owns this (query, head, level, point): plain stores suffice.
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
+
// Backward (col2im) kernel for channels > 1024 split across several blocks
// per (batch, query, head) sample.  Identical reduction to _shm_reduce_v2,
// but because multiple blocks contribute to the same output slots, the final
// per-block results are accumulated into global memory with atomicAdd
// instead of plain stores.  Dynamic shared memory: the launcher must pass
// blockDim.x * 3 * sizeof(scalar_t) bytes.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's shared-memory slots before the conditional write.
        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight + threadIdx.x) = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (threadIdx.x << 1), cache_grad_attn_weight + threadIdx.x);
        }

        __syncthreads();

        // Tree reduction tolerant of non-power-of-two widths (see _shm_reduce_v2).
        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          // Several blocks share these output slots -> must accumulate atomically.
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
+
+
// Backward (col2im) fallback kernel that accumulates every gradient directly
// in global memory (no shared-memory staging).  Used when `channels` does not
// fit any shared-memory strategy; all accumulation happens via atomics inside
// ms_deform_attn_col2im_bilinear_gm, so any launch configuration is safe.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose the flat index into (batch, query, head, channel).
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          // Writes grad_value, grad_sampling_loc, and grad_attn_weight
          // with atomicAdd internally.
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}
+
+
// Host-side launcher for the forward im2col kernel.  All pointers are device
// pointers; work is enqueued on `stream` (no synchronization here).  A launch
// error is reported via cudaGetLastError immediately after the launch.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  // One thread per output element.
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
          0, stream>>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  // Catch launch-configuration errors; execution errors surface at the next sync.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}
+
// Host-side launcher for the backward (col2im) pass.  Picks a reduction
// strategy by `channels`:
//   - channels <= 1024 and a power of two -> blocksize-aware kernels with a
//     compile-time block size equal to `channels` (v1 serial reduce for <= 32,
//     v2 tree reduce for >= 64);
//   - other channels < 1024 -> dynamic-shared-memory kernels (v1 for < 64,
//     v2 otherwise), passing num_threads * 3 * sizeof(scalar_t) bytes of
//     dynamic shared memory (2 loc-grad scalars + 1 weight-grad per thread);
//   - channels > 1024 -> multi-block shared reduction when divisible by 1024,
//     else the global-memory atomic fallback.
// All pointers are device pointers; work is enqueued on `stream`.
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                              const scalar_t* grad_col,
                              const scalar_t* data_value,
                              const int64_t * data_spatial_shapes,
                              const int64_t * data_level_start_index,
                              const scalar_t * data_sampling_loc,
                              const scalar_t * data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* grad_value,
                              scalar_t* grad_sampling_loc,
                              scalar_t* grad_attn_weight)
{
  // Block size tracks `channels` so one block covers one sample's channels.
  const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      // channels is a multiple of 1024: several 1024-thread blocks cooperate
      // per sample, so the kernel accumulates its outputs atomically.
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              num_threads * 3 * sizeof(scalar_t), stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
    }
    else
    {
      // Irregular large channel count: fall back to global-memory atomics.
      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
              0, stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
    }
  }
  else{
    switch(channels)
    {
      // Power-of-two channel counts get a compile-time block size, letting
      // the kernels use statically sized shared memory.
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                0, stream>>>(
                      num_kernels,
                      grad_col,
                      data_value,
                      data_spatial_shapes,
                      data_level_start_index,
                      data_sampling_loc,
                      data_attn_weight,
                      batch_size,
                      spatial_size,
                      num_heads,
                      channels,
                      num_levels,
                      num_query,
                      num_point,
                      grad_value,
                      grad_sampling_loc,
                      grad_attn_weight);
        break;
      default:
        // Non-power-of-two channel counts: runtime block size with dynamic
        // shared memory (3 scalars per thread, matching the kernels' layout).
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                  num_threads * 3 * sizeof(scalar_t), stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
              <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
                  num_threads * 3 * sizeof(scalar_t), stream>>>(
                        num_kernels,
                        grad_col,
                        data_value,
                        data_spatial_shapes,
                        data_level_start_index,
                        data_sampling_loc,
                        data_attn_weight,
                        batch_size,
                        spatial_size,
                        num_heads,
                        channels,
                        num_levels,
                        num_query,
                        num_point,
                        grad_value,
                        grad_sampling_loc,
                        grad_attn_weight);
        }
    }
  }
  // Catch launch-configuration errors; execution errors surface at the next sync.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }

}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..fbcf4543e66bb1162f42ce2ae57e1bac92243cb4
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,29 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c0db0c88c9db2c09d7f601937ea0f6ac480913bf
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+// Grid-stride loop: thread i starts at its global index and strides by the
+// total number of launched threads, so any grid size covers all n items.
+#define CUDA_KERNEL_LOOP(i, n)                          \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
+      i < (n);                                          \
+      i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+// Number of blocks needed to cover N items at num_threads per block (ceil-div).
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+  return (N + num_threads - 1) / num_threads;
+}
+
+
+// Bilinear interpolation of bottom_data (layout: height x width x nheads x
+// channels, innermost contiguous) at fractional position (h, w) for head m,
+// channel c. Corners falling outside the feature map contribute zero.
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+                                                   const int &height, const int &width, const int &nheads, const int &channels,
+                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  // Fractional offsets (lh, lw) and their complements (hh, hw) are the
+  // 1-D interpolation weights.
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  // Fetch the four neighboring values, zero outside the map bounds.
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+  }
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+
+// Backward of the bilinear sample: accumulates the value gradient into
+// grad_value with atomicAdd, and writes the (non-atomic) gradients w.r.t. the
+// sampling location (x then y) and the attention weight through
+// grad_sampling_loc / grad_attn_weight — callers pass per-thread scratch
+// (e.g. shared memory) and reduce afterwards.
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+                                               const int &height, const int &width, const int &nheads, const int &channels,
+                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                               const scalar_t &top_grad,
+                                               const scalar_t &attn_weight,
+                                               scalar_t* &grad_value,
+                                               scalar_t* grad_sampling_loc,
+                                               scalar_t* grad_attn_weight)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const scalar_t top_grad_value = top_grad * attn_weight;
+  scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+  // For each in-bounds corner: route gradient to the value tensor and
+  // accumulate the derivative of the bilinear weights w.r.t. (h, w).
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value+ptr1, w1*top_grad_value);
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value+ptr2, w2*top_grad_value);
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value+ptr3, w3*top_grad_value);
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value+ptr4, w4*top_grad_value);
+  }
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  *grad_attn_weight = top_grad * val;
+  // Scale by width/height: sampling locations are in normalized coordinates.
+  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+// Same as ms_deform_attn_col2im_bilinear, but grad_sampling_loc and
+// grad_attn_weight point directly into global memory, so the final writes use
+// atomicAdd instead of plain stores (no shared-memory reduction needed).
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+                                                  const int &height, const int &width, const int &nheads, const int &channels,
+                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                                  const scalar_t &top_grad,
+                                                  const scalar_t &attn_weight,
+                                                  scalar_t* &grad_value,
+                                                  scalar_t* grad_sampling_loc,
+                                                  scalar_t* grad_attn_weight)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const scalar_t top_grad_value = top_grad * attn_weight;
+  scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value+ptr1, w1*top_grad_value);
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value+ptr2, w2*top_grad_value);
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value+ptr3, w3*top_grad_value);
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value+ptr4, w4*top_grad_value);
+  }
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  // Direct global-memory accumulation of the location/weight gradients.
+  atomicAdd(grad_attn_weight, top_grad * val);
+  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+// Forward kernel: one work item per output element, indexed as
+// (batch, query, head, channel) flattened into `index`. Each item accumulates
+// bilinear samples weighted by the attention weights over all levels/points.
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    // Decompose the flat index into (b_col, q_col, m_col, c_col).
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    scalar_t *data_col_ptr = data_col + index;
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;  // locations store (x, y) pairs
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    scalar_t col = 0;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        // Map normalized locations to pixel space (align-corners=False style).
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+        }
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+      }
+    }
+    *data_col_ptr = col;
+  }
+}
+
+// Backward kernel specialized for channels == blockSize (compile-time block
+// size). Per-thread gradients for the sampling location and attention weight
+// are staged in static shared memory, then serially summed by thread 0 and
+// stored (non-atomically: one block owns each (query, head) slot).
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ scalar_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        // Zero the shared slots unconditionally so out-of-range samples
+        // contribute nothing to the reduction.
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0)
+        {
+          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+          int sid=2;
+          // NOTE: the loop variable deliberately shadows the outer `tid`;
+          // it walks all other threads' shared-memory slots.
+          for (unsigned int tid = 1; tid < blockSize; ++tid)
+          {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+// Variant of the blocksize-aware backward kernel that replaces the serial
+// reduction with a tree reduction over shared memory (blockSize must be a
+// power of two for the halving loop to cover every slot).
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ scalar_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+
+        // Standard pairwise tree reduction; barrier sits outside the
+        // divergent branch so all threads reach it.
+        for (unsigned int s=blockSize/2; s>0; s>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+// Backward kernel using dynamic shared memory (runtime blockDim.x) with a
+// serial reduction by thread 0. The launcher must provide
+// 3 * blockDim.x * sizeof(scalar_t) bytes of dynamic shared memory:
+// 2 * blockDim.x for sampling-location pairs plus blockDim.x for weights.
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0)
+        {
+          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+          int sid=2;
+          // NOTE: loop variable deliberately shadows the outer `tid`.
+          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+          {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+// Backward kernel with dynamic shared memory and a tree reduction that also
+// handles non-power-of-two block sizes: the extra `spre` term folds the
+// straggler element in each halving step. Needs 3 * blockDim.x scalar_t of
+// dynamic shared memory from the launcher.
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+            // Odd `spre`: fold the leftover element at index tid + 2s.
+            if (tid + (s << 1) < spre)
+            {
+              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+// Like shm_reduce_v2, but for channels > blockDim.x: several blocks cooperate
+// on the same (query, head) gradient slot, so thread 0 commits the block's
+// partial sums with atomicAdd instead of plain stores.
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre)
+            {
+              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          // Multiple blocks may target the same slot: accumulate atomically.
+          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+// Fallback backward kernel: no shared-memory staging at all; every thread
+// accumulates its sampling-location / attention-weight gradients straight
+// into global memory via the _gm helper's atomicAdds. Used when channels is
+// too large / irregular for the shared-memory variants.
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index,
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size,
+                                                const int spatial_size,
+                                                const int num_heads,
+                                                const int channels,
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear_gm(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr,
+            grad_sampling_loc, grad_attn_weight);
+        }
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+// Host launcher for the forward (im2col) kernel: one thread per output
+// element, launched on the supplied stream; reports launch errors via
+// cudaGetLastError (the launch itself is asynchronous).
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+                              const scalar_t* data_value,
+                              const int64_t* data_spatial_shapes,
+                              const int64_t* data_level_start_index,
+                              const scalar_t* data_sampling_loc,
+                              const scalar_t* data_attn_weight,
+                              const int batch_size,
+                              const int spatial_size,
+                              const int num_heads,
+                              const int channels,
+                              const int num_levels,
+                              const int num_query,
+                              const int num_point,
+                              scalar_t* data_col)
+{
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+  const int num_threads = CUDA_NUM_THREADS;
+  ms_deformable_im2col_gpu_kernel<scalar_t>
+      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+          0, stream>>>(
+      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+
+}
+
+template
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..119b1fa317d1e5fcfb61a4837e560e9248db05f3
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/ms_deform_attn.h
@@ -0,0 +1,61 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6ce3875568b9ba8d660c90acc805077cca98f891
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/deta/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
\ No newline at end of file
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc b/.venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c9d2926cf4ef21dd0249a6847d1111f8763f43c
Binary files /dev/null and b/.venv/lib/python3.11/site-packages/transformers/kernels/falcon_mamba/__pycache__/selective_scan_with_ln_interface.cpython-311.pyc differ
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu b/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..571d5a8a8307e95aac689eb3c9333d1ad350c7de
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda.cu
@@ -0,0 +1,187 @@
+#include
+#include
+
+#define MIN_VALUE (-1e38)
+
+template
+__global__ void kernel_forward(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ F aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+}
+
+template
+__global__ void kernel_forward_with_state(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset_s = _b * C * 3 + _c * 3;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ F *__restrict__ const y = _y + _offset;
+ F *__restrict__ const s = _s + _offset_s;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ F aa = s[0], bb = s[1], pp = s[2];
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ s[0] = aa;
+ s[1] = bb;
+ s[2] = pp;
+}
+
+template
+__global__ void kernel_backward(
+ const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
+ const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
+ const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
+ F *__restrict__ const _gv
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ F u = _u[_c];
+ F w = _w[_c];
+ const F *__restrict__ const k = _k + _offset;
+ const F *__restrict__ const v = _v + _offset;
+ const F *__restrict__ const y = _y + _offset;
+ const F *__restrict__ const gy = _gy + _offset;
+ F *__restrict__ const gk = _gk + _offset;
+ F *__restrict__ const gv = _gv + _offset;
+
+ F q[Tmax], r[Tmax];
+
+ F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+ const F yy = y[ii];
+
+ F ww = u + kk;
+ F p = max(pp, ww);
+ F e1 = exp(pp - p);
+ F e2 = exp(ww - p);
+ const F qq = gy[ii] / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = gu;
+
+ aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const F kk = k[ii];
+ const F vv = v[ii];
+ const F yy = y[ii];
+ const F qq = q[i];
+ const F rr = r[i];
+
+ F e1 = qq * exp(rr);
+ F e2 = exp(kk + pp);
+ gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+ gv[ii] = e1 + e2 * aa;
+
+ const F ww = w + pp;
+ const F www = rr - u - kk;
+ const F p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward<<>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_with_state<<>>(B, T, C, w, u, k, v, y, s);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu b/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu
new file mode 100644
index 0000000000000000000000000000000000000000..042cb4aba1db98be5916aea1de86a7fed0b6510d
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_cuda_bf16.cu
@@ -0,0 +1,186 @@
+#include
+#include
+#include "ATen/ATen.h"
+#define MIN_VALUE (-1e38)
+typedef at::BFloat16 bf16;
+
+__global__ void kernel_forward_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ bf16 *__restrict__ const y = _y + _offset;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ float aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+}
+
+__global__ void kernel_forward_with_state_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
+ float *__restrict__ const _s
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset_s = _b * C * 3 + _c * 3;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ bf16 *__restrict__ const y = _y + _offset;
+ float *__restrict__ const s = _s + _offset_s;
+
+ // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+ float aa = s[0], bb = s[1], pp = s[2];
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ y[ii] = bf16(e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ s[0] = aa;
+ s[1] = bb;
+ s[2] = pp;
+}
+
+__global__ void kernel_backward_bf16(
+ const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
+ const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
+ const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
+ bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int _b = idx / C;
+ const int _c = idx % C;
+ const int _offset = _b * T * C + _c;
+
+ float u = float(_u[_c]);
+ float w = _w[_c];
+ const bf16 *__restrict__ const k = _k + _offset;
+ const bf16 *__restrict__ const v = _v + _offset;
+ const bf16 *__restrict__ const y = _y + _offset;
+ const bf16 *__restrict__ const gy = _gy + _offset;
+ bf16 *__restrict__ const gk = _gk + _offset;
+ bf16 *__restrict__ const gv = _gv + _offset;
+
+ float q[Tmax], r[Tmax];
+
+ float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+ for (int i = 0; i < T; i++) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+
+ float ww = u + kk;
+ float p = max(pp, ww);
+ float e1 = exp(pp - p);
+ float e2 = exp(ww - p);
+ const float qq = float(gy[ii]) / (e1 * bb + e2);
+ gw += (ga - gb * yy) * e1 * qq;
+ gu += (vv - yy) * e2 * qq;
+ q[i] = qq;
+ r[i] = ww - p;
+
+ ww = w + pp;
+ p = max(ww, kk);
+ e1 = exp(ww - p);
+ e2 = exp(kk - p);
+ ga = e1 * (aa + ga);
+ gb = e1 * (bb + gb);
+ aa = e1 * aa + e2 * vv;
+ bb = e1 * bb + e2;
+ pp = p;
+ }
+ const int _offsetBC = _b * C + _c;
+ _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
+ _gu[_offsetBC] = bf16(gu);
+
+ aa = 0, bb = 0, pp = MIN_VALUE;
+ for (int i = T - 1; i >= 0; i--) {
+ const int ii = i * C;
+ const float kk = float(k[ii]);
+ const float vv = float(v[ii]);
+ const float yy = float(y[ii]);
+ const float qq = q[i];
+ const float rr = r[i];
+
+ float e1 = qq * exp(rr);
+ float e2 = exp(kk + pp);
+ gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
+ gv[ii] = bf16(e1 + e2 * aa);
+
+ const float ww = w + pp;
+ const float www = rr - u - kk;
+ const float p = max(ww, www);
+ e1 = exp(ww - p);
+ e2 = qq * exp(www - p);
+ aa = e1 * aa + e2;
+ bb = e1 * bb - e2 * yy;
+ pp = p;
+ }
+}
+
+void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_bf16<<>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_forward_with_state_bf16<<>>(B, T, C, w, u, k, v, y, s);
+}
+
+void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
+ dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+ assert(B * C % threadsPerBlock.x == 0);
+ dim3 numBlocks(B * C / threadsPerBlock.x);
+ kernel_backward_bf16<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp b/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..55e7280665927b523a88021d5111daf28a63c905
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/transformers/kernels/rwkv/wkv_op.cpp
@@ -0,0 +1,66 @@
+#include
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
+void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
+void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
+
+void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr());
+}
+void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+ const int B = k.size(0);
+ const int T = k.size(1);
+ const int C = k.size(2);
+ cuda_forward_bf16(B, T, C, w.data_ptr(), u.data_ptr