JinghuiLuAstronaut commited on
Commit
038d1cf
·
verified ·
1 Parent(s): 30b5140

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. LTA_openwebtext_dualt/logs/debug_2k_stream1024_fc_mask1_4gpu/debug_2k_stream1024_fc_mask1_4gpu_now_20260517_125945.log +147 -0
  2. LTA_openwebtext_dualt/logs/infer/lta_owt_lm1bclassic_fullvocab_bert_c1024_len1024_elfLdim_d1280_l32_h16_ff5120_lr3e-4_gbs512_2node8gpu_1m_save10k_t-20260522071024-s2ss5_latest_step0030000_shard01_gpu1_b16.log +18 -0
  3. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/activations.py +369 -0
  4. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cache_utils.py +1623 -0
  5. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/configuration_utils.py +1365 -0
  6. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/file_utils.py +105 -0
  7. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/fusion_mapping.py +270 -0
  8. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_backends.py +689 -0
  9. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_utils.py +688 -0
  10. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_utils.py +1069 -0
  11. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/masking_utils.py +1514 -0
  12. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py +503 -0
  13. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/blt/configuration_blt.py +286 -0
  14. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py +27 -0
  15. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py +830 -0
  16. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/__init__.py +29 -0
  17. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/image_processing_pil_vitmatte.py +159 -0
  18. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/optimization.py +1342 -0
  19. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/tokenization_python.py +1420 -0
  20. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/video_utils.py +893 -0
LTA_openwebtext_dualt/logs/debug_2k_stream1024_fc_mask1_4gpu/debug_2k_stream1024_fc_mask1_4gpu_now_20260517_125945.log ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ NCCL version 2.25.1+cuda12.8
2
+ {
3
+ "device": "cuda:0",
4
+ "rank": 0,
5
+ "world_size": 4,
6
+ "samples": "tokenized_hf:13425484:pad=0",
7
+ "vocab_size": 2048,
8
+ "tokenizer_vocab_size": 2048,
9
+ "save_dir": "runs/debug_2k_stream1024_fc_mask1_4gpu_now_20260517_125945",
10
+ "batch_size": 32,
11
+ "grad_accum": 2,
12
+ "effective_batch_size": 256,
13
+ "global_batch_size": 256,
14
+ "lr_schedule": "cosine",
15
+ "optimizer": "adamw",
16
+ "epochs": 0.0,
17
+ "steps_per_epoch": 52443,
18
+ "total_steps": 1,
19
+ "warmup_steps": 26222,
20
+ "warmup_epochs": 0.5,
21
+ "min_lr": 6e-05,
22
+ "weight_decay": 0.1,
23
+ "output_weight_decay": -1.0,
24
+ "adamw_param_groups": "nanogpt",
25
+ "adam_beta1": 0.9,
26
+ "adam_beta2": 0.95,
27
+ "adam_eps": 1e-08,
28
+ "muon_impl": "legacy",
29
+ "muon_momentum": 0.95,
30
+ "muon_ns_steps": 5,
31
+ "muon_update_scale": 1.0,
32
+ "muon_nesterov": false,
33
+ "muon_width_scale": false,
34
+ "muon_grouping": "",
35
+ "muon_param_count": 0,
36
+ "muon_adam_param_count": 0,
37
+ "muon_param_names": [],
38
+ "muon_adam_param_names": [],
39
+ "muon_effective_nesterov": false,
40
+ "muon_effective_width_scale": false,
41
+ "muon_effective_weight_decay": 0.1,
42
+ "muon_adam_fallback_nesterov": false,
43
+ "muon_adam_fallback_weight_decay": 0.1,
44
+ "ema_decay": 0.0,
45
+ "ema_start_step": 0,
46
+ "model_type": "ddit",
47
+ "ddit_mlp_type": "gelu",
48
+ "elf_num_time_tokens": 4,
49
+ "elf_num_model_mode_tokens": 0,
50
+ "qk_norm": true,
51
+ "output_bias": false,
52
+ "output_init_std": -1.0,
53
+ "norm_type": "rmsnorm",
54
+ "target_loss": "hard_ce",
55
+ "linear_soft_target_power": 1.0,
56
+ "linear_soft_target_min_conf": 0.0,
57
+ "linear_soft_target_max_conf": 1.0,
58
+ "t_sampling_mode": "logit_normal",
59
+ "t_sampling_power": 1.0,
60
+ "t_sampling_eps": 0.0001,
61
+ "t_sampling_logit_mean": -1.5,
62
+ "t_sampling_logit_std": 0.8,
63
+ "dual_t": true,
64
+ "corrupt_t_mode": "same",
65
+ "corrupt_min_t": 0.0,
66
+ "corrupt_max_t": 1.0,
67
+ "prefix_block_prob": 0.0,
68
+ "prefix_block_len": 128,
69
+ "mask_ratio_floor_schedule": "none",
70
+ "dirichlet_endpoint_mode": "categorical_dual_t",
71
+ "dirichlet_semantic_t_mode": "same",
72
+ "dirichlet_semantic_t_value": 0.0,
73
+ "dirichlet_semantic_t_curve": "linear",
74
+ "dirichlet_semantic_t_power": 1.0,
75
+ "endpoint_sequence_random_prob_alpha": 0.0,
76
+ "categorical_wrong_from_full_vocab": true,
77
+ "categorical_wrong_from_batch_valid_tokens": false,
78
+ "categorical_wrong_basin_token_ids": "",
79
+ "categorical_wrong_basin_prob": 0.0,
80
+ "categorical_wrong_unigram_prob": 0.0,
81
+ "categorical_wrong_uniform_prob": 0.0,
82
+ "categorical_wrong_corpus_unigram_path": "",
83
+ "categorical_wrong_corpus_unigram_alpha": 1.0,
84
+ "categorical_wrong_basin_shared_prob": 0.0,
85
+ "categorical_wrong_unigram_shared_prob": 0.0,
86
+ "mask_mixture_original_prob": 0.0,
87
+ "mask_mixture_lowk_prob": 0.0,
88
+ "mask_mixture_lowcorrupt_prob": 0.0,
89
+ "mask_mixture_block_prob": 0.0,
90
+ "mask_mixture_all_prob": 0.0,
91
+ "mask_mixture_lowk_clean_tokens": "1,2,4,8,16,32,64",
92
+ "mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
93
+ "mask_mixture_block_tokens": "64,128",
94
+ "simplex_bridge_sampler": "dirichlet",
95
+ "logistic_normal_sigma_min": 0.18,
96
+ "logistic_normal_sigma_max": 2.2,
97
+ "logistic_normal_tau_min": 0.65,
98
+ "logistic_normal_tau_max": 1.15,
99
+ "torch_compile": false,
100
+ "compile_mode": "max-autotune",
101
+ "state_format": "prob",
102
+ "meanflow_weight": 0.0,
103
+ "rollout_train_prob": 0.0,
104
+ "rollout_train_steps": 1,
105
+ "rollout_train_infer_steps": 64,
106
+ "rollout_train_temp": 1.45,
107
+ "rollout_train_max_gamma": 1.0,
108
+ "rollout_train_corrupt_only": true,
109
+ "rollout_train_samplewise": false,
110
+ "rollout_train_compute_always": false,
111
+ "bridge_noise_init": "logistic_normal",
112
+ "noise_sigma": -1.0,
113
+ "allow_tf32": true,
114
+ "activation_checkpointing": false,
115
+ "activation_checkpoint_interval": 1,
116
+ "activation_checkpoint_scope": "block",
117
+ "ddp_static_graph": false,
118
+ "ddp_gradient_as_bucket_view": true,
119
+ "blocking_data_transfer": false,
120
+ "dataloader_prefetch_factor": 2,
121
+ "full_train_stats": false,
122
+ "tokenized_hf": true,
123
+ "tokenized_pad_token": "pad",
124
+ "elf_conditional_hf": false,
125
+ "record_pad_truncate": false,
126
+ "record_add_eos": false,
127
+ "record_add_special_tokens": false,
128
+ "record_pad_token": "pad",
129
+ "record_shuffle_buffer": 10000,
130
+ "wrap": false,
131
+ "wrap_mode": "stream",
132
+ "wrap_record_buffer_size": 200,
133
+ "owt_cached_chunks": false,
134
+ "owt_chunk_cache_dir": "",
135
+ "owt_chunk_cache_rebuild": false,
136
+ "owt_chunk_cache_write_batch": 4096,
137
+ "owt_exact_repeat_per_chunk": 0,
138
+ "online_chunk_shuffle": false,
139
+ "online_chunk_shuffle_buffer": 10000,
140
+ "openwebtext_split": "all",
141
+ "detokenizer": "auto",
142
+ "resolved_detokenizer": null,
143
+ "num_workers": 2,
144
+ "latest_every": 1000,
145
+ "resume_path": ""
146
+ }
147
+ step=1 epoch=1/1 epoch_step=1/52443 micro_steps=2 elapsed=2.6s lr=4.576310e-08 loss=7.6246 loss_recon=7.6246 loss_meanflow=0.0000 mean_model_t=0.2163 mean_corrupt_t=0.2163 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.0000 corrupt_frac=1.0000 acc_corrupt=0.0000 loss_corrupt=7.6246 wrong_frac=0.7833 init_acc_corrupt=0.1268 acc_corrupt_t_0p0_0p2=0.0000 corrupt_frac_t_0p0_0p2=0.5469 acc_corrupt_t_0p2_0p4=0.0000 corrupt_frac_t_0p2_0p4=0.3750 acc_corrupt_t_0p4_0p6=0.0000 corrupt_frac_t_0p4_0p6=0.0625 acc_corrupt_t_0p6_0p8=0.0000 corrupt_frac_t_0p6_0p8=0.0312 out_w_norm=0.0000 out_g_norm=0.2113 loss_all=7.6246 init_gold_top10=0.1954 init_gold_top100=0.2858
LTA_openwebtext_dualt/logs/infer/lta_owt_lm1bclassic_fullvocab_bert_c1024_len1024_elfLdim_d1280_l32_h16_ff5120_lr3e-4_gbs512_2node8gpu_1m_save10k_t-20260522071024-s2ss5_latest_step0030000_shard01_gpu1_b16.log ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [ckpt] runs/lta_owt_lm1bclassic_fullvocab_bert_c1024_len1024_elfLdim_d1280_l32_h16_ff5120_lr3e-4_gbs512_2node8gpu_1m_save10k_t-20260522071024-s2ss5/latest.pt step=30000
2
+ [decode] steps128_c1024_t1p45 generated 16/256
3
+ [decode] steps128_c1024_t1p45 generated 32/256
4
+ [decode] steps128_c1024_t1p45 generated 48/256
5
+ [decode] steps128_c1024_t1p45 generated 64/256
6
+ [decode] steps128_c1024_t1p45 generated 80/256
7
+ [decode] steps128_c1024_t1p45 generated 96/256
8
+ [decode] steps128_c1024_t1p45 generated 112/256
9
+ [decode] steps128_c1024_t1p45 generated 128/256
10
+ [decode] steps128_c1024_t1p45 generated 144/256
11
+ [decode] steps128_c1024_t1p45 generated 160/256
12
+ [decode] steps128_c1024_t1p45 generated 176/256
13
+ [decode] steps128_c1024_t1p45 generated 192/256
14
+ [decode] steps128_c1024_t1p45 generated 208/256
15
+ [decode] steps128_c1024_t1p45 generated 224/256
16
+ [decode] steps128_c1024_t1p45 generated 240/256
17
+ [decode] steps128_c1024_t1p45 generated 256/256
18
+ [summary] {"name": "steps128_c1024_t1p45", "step": 30000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 15.432598193356394, "stripped_genppl": 15.314571366003578, "sample_entropy": 3.169319289650098, "distinct_1": 0.005100250244140625, "distinct_2": 0.10675937805474096, "top_token_mass": 0.2738838195800781, "raw_kept": 256, "stripped_kept": 256}
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/activations.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+ from collections import OrderedDict
18
+
19
+ import torch
20
+ from torch import Tensor, nn
21
+
22
+ from .integrations.hub_kernels import use_kernel_forward_from_hub
23
+ from .utils import logging
24
+ from .utils.import_utils import is_torchdynamo_compiling
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ @use_kernel_forward_from_hub("GeluTanh")
31
+ class GELUTanh(nn.Module):
32
+ """
33
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
34
+ https://huggingface.co/papers/1606.08415.
35
+
36
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
37
+ match due to rounding errors.
38
+ """
39
+
40
+ def __init__(self, use_gelu_tanh_python: bool = False):
41
+ super().__init__()
42
+ if use_gelu_tanh_python:
43
+ self.act = self._gelu_tanh_python
44
+ else:
45
+ self.act = functools.partial(nn.functional.gelu, approximate="tanh")
46
+
47
+ def _gelu_tanh_python(self, input: Tensor) -> Tensor:
48
+ return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
49
+
50
+ def forward(self, input: Tensor) -> Tensor:
51
+ return self.act(input)
52
+
53
+
54
+ # Added for compatibility with autoawq which is archived now and imports PytorchGELUTanh from activations.py
55
+ PytorchGELUTanh = GELUTanh
56
+
57
+
58
+ @use_kernel_forward_from_hub("NewGELU")
59
+ class NewGELUActivation(nn.Module):
60
+ """
61
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
62
+ the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
63
+ """
64
+
65
+ def forward(self, input: Tensor) -> Tensor:
66
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
67
+
68
+
69
+ @use_kernel_forward_from_hub("GeLU")
70
+ class GELUActivation(nn.Module):
71
+ """
72
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
73
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
74
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
75
+ Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
76
+ """
77
+
78
+ def __init__(self, use_gelu_python: bool = False):
79
+ super().__init__()
80
+ if use_gelu_python:
81
+ self.act = self._gelu_python
82
+ else:
83
+ self.act = nn.functional.gelu
84
+
85
+ def _gelu_python(self, input: Tensor) -> Tensor:
86
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
87
+
88
+ def forward(self, input: Tensor) -> Tensor:
89
+ return self.act(input)
90
+
91
+
92
+ @use_kernel_forward_from_hub("SiLU")
93
+ class SiLUActivation(nn.Module):
94
+ """
95
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
96
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
97
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
98
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
99
+ later.
100
+ """
101
+
102
+ def forward(self, input: Tensor) -> Tensor:
103
+ return nn.functional.silu(input)
104
+
105
+
106
+ @use_kernel_forward_from_hub("FastGELU")
107
+ class FastGELUActivation(nn.Module):
108
+ """
109
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
110
+ """
111
+
112
+ def forward(self, input: Tensor) -> Tensor:
113
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
114
+
115
+
116
+ @use_kernel_forward_from_hub("QuickGELU")
117
+ class QuickGELUActivation(nn.Module):
118
+ """
119
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
120
+ """
121
+
122
+ def forward(self, input: Tensor) -> Tensor:
123
+ return input * torch.sigmoid(1.702 * input)
124
+
125
+
126
+ class ClippedGELUActivation(nn.Module):
127
+ """
128
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
129
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
130
+ https://huggingface.co/papers/2004.09602.
131
+
132
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
133
+ initially created.
134
+
135
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
136
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
137
+ """
138
+
139
+ def __init__(self, min: float, max: float):
140
+ if min > max:
141
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
142
+
143
+ super().__init__()
144
+ self.min = min
145
+ self.max = max
146
+
147
+ def forward(self, x: Tensor) -> Tensor:
148
+ return torch.clip(gelu(x), self.min, self.max)
149
+
150
+
151
+ class AccurateGELUActivation(nn.Module):
152
+ """
153
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
154
+ https://github.com/hendrycks/GELUs
155
+
156
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
157
+ """
158
+
159
+ def __init__(self):
160
+ super().__init__()
161
+ self.precomputed_constant = math.sqrt(2 / math.pi)
162
+
163
+ def forward(self, input: Tensor) -> Tensor:
164
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
165
+
166
+
167
+ class MishActivation(nn.Module):
168
+ """
169
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also
170
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
171
+ """
172
+
173
+ def __init__(self):
174
+ super().__init__()
175
+ self.act = nn.functional.mish
176
+
177
+ def _mish_python(self, input: Tensor) -> Tensor:
178
+ return input * torch.tanh(nn.functional.softplus(input))
179
+
180
+ def forward(self, input: Tensor) -> Tensor:
181
+ return self.act(input)
182
+
183
+
184
+ class LinearActivation(nn.Module):
185
+ """
186
+ Applies the linear activation function, i.e. forwarding input directly to output.
187
+ """
188
+
189
+ def forward(self, input: Tensor) -> Tensor:
190
+ return input
191
+
192
+
193
+ class LaplaceActivation(nn.Module):
194
+ """
195
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
196
+ https://huggingface.co/papers/2209.10655
197
+
198
+ Inspired by squared relu, but with bounded range and gradient for better stability
199
+ """
200
+
201
+ def forward(self, input, mu=0.707107, sigma=0.282095):
202
+ input = (input - mu).div(sigma * math.sqrt(2.0))
203
+ return 0.5 * (1.0 + torch.erf(input))
204
+
205
+
206
+ class ReLUSquaredActivation(nn.Module):
207
+ """
208
+ Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668
209
+ """
210
+
211
+ def forward(self, input):
212
+ relu_applied = nn.functional.relu(input)
213
+ squared = torch.square(relu_applied)
214
+ return squared
215
+
216
+
217
+ class SqrtSoftplusActivation(nn.Module):
218
+ """sqrt(softplus(x)) — the router scoring function used by DeepSeek V4."""
219
+
220
+ def forward(self, input):
221
+ return nn.functional.softplus(input).sqrt()
222
+
223
+
224
+ class ClassInstantier(OrderedDict):
225
+ def __getitem__(self, key):
226
+ content = super().__getitem__(key)
227
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
228
+ return cls(**kwargs)
229
+
230
+
231
+ class XIELUActivation(nn.Module):
232
+ """
233
+ Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
234
+
235
+ If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA
236
+ Otherwise, we emit a single warning and use xIELU Python
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ alpha_p_init=0.8,
242
+ alpha_n_init=0.8,
243
+ beta=0.5,
244
+ eps=-1e-6,
245
+ dtype=torch.bfloat16,
246
+ with_vector_loads=False,
247
+ ):
248
+ super().__init__()
249
+ self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0))
250
+ self.alpha_n = nn.Parameter(
251
+ torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0)
252
+ )
253
+ self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
254
+ self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
255
+ self.with_vector_loads = with_vector_loads
256
+ # Temporary until xIELU CUDA fully implemented
257
+ self._beta_scalar = float(beta)
258
+ self._eps_scalar = float(eps)
259
+
260
+ self._xielu_cuda_obj = None
261
+ try:
262
+ import xielu.ops # noqa: F401
263
+
264
+ self._xielu_cuda_obj = torch.classes.xielu.XIELU()
265
+ msg = "Using experimental xIELU CUDA."
266
+ try:
267
+ from torch.compiler import allow_in_graph
268
+
269
+ self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
270
+ msg += " Enabled torch._dynamo for xIELU CUDA."
271
+ except Exception as err:
272
+ msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
273
+ self._xielu_cuda_fn = self._xielu_cuda
274
+ logger.warning_once(msg)
275
+ except Exception as err:
276
+ logger.warning_once(
277
+ f"CUDA-fused xIELU not available ({err}) – falling back to a Python version.\n"
278
+ "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`"
279
+ )
280
+
281
+ def _xielu_python(self, x: Tensor) -> Tensor:
282
+ alpha_p = nn.functional.softplus(self.alpha_p)
283
+ alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
284
+ return torch.where(
285
+ x > 0,
286
+ alpha_p * x * x + self.beta * x,
287
+ (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
288
+ )
289
+
290
+ def _xielu_cuda(self, x: Tensor) -> Tensor:
291
+ """Firewall function to prevent torch.compile from seeing .item() calls"""
292
+ original_shape = x.shape
293
+ # CUDA kernel expects 3D tensors, reshape if needed
294
+ while x.dim() < 3:
295
+ x = x.unsqueeze(0)
296
+ if x.dim() > 3:
297
+ x = x.view(-1, 1, x.size(-1))
298
+ if original_shape != x.shape:
299
+ logger.warning_once(
300
+ "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
301
+ original_shape,
302
+ x.shape,
303
+ )
304
+ result = self._xielu_cuda_obj.forward(
305
+ x,
306
+ self.alpha_p.to(x.dtype),
307
+ self.alpha_n.to(x.dtype),
308
+ # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
309
+ self._beta_scalar,
310
+ self._eps_scalar,
311
+ self.with_vector_loads,
312
+ )
313
+ return result.view(original_shape)
314
+
315
+ def forward(self, input: Tensor) -> Tensor:
316
+ if self._xielu_cuda_obj is not None and input.is_cuda:
317
+ if not is_torchdynamo_compiling():
318
+ return self._xielu_cuda_fn(input)
319
+ else:
320
+ logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
321
+ return self._xielu_python(input)
322
+
323
+
324
+ ACT2CLS = {
325
+ "gelu": GELUActivation,
326
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
327
+ "gelu_fast": FastGELUActivation,
328
+ "gelu_new": NewGELUActivation,
329
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
330
+ "gelu_pytorch_tanh": GELUTanh,
331
+ "gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
332
+ "gelu_accurate": AccurateGELUActivation,
333
+ "hardswish": nn.Hardswish,
334
+ "laplace": LaplaceActivation,
335
+ "leaky_relu": nn.LeakyReLU,
336
+ "linear": LinearActivation,
337
+ "mish": MishActivation,
338
+ "quick_gelu": QuickGELUActivation,
339
+ "relu": nn.ReLU,
340
+ "relu2": ReLUSquaredActivation,
341
+ "relu6": nn.ReLU6,
342
+ "sigmoid": nn.Sigmoid,
343
+ "silu": SiLUActivation,
344
+ "sqrtsoftplus": SqrtSoftplusActivation,
345
+ "swish": nn.SiLU,
346
+ "tanh": nn.Tanh,
347
+ "prelu": nn.PReLU,
348
+ "xielu": XIELUActivation,
349
+ }
350
+ ACT2FN = ClassInstantier(ACT2CLS)
351
+
352
+
353
+ def get_activation(activation_string):
354
+ if activation_string in ACT2FN:
355
+ return ACT2FN[activation_string]
356
+ else:
357
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
358
+
359
+
360
+ # For backwards compatibility with: from activations import gelu_python
361
+ gelu_python = get_activation("gelu_python")
362
+ gelu_new = get_activation("gelu_new")
363
+ gelu = get_activation("gelu")
364
+ gelu_fast = get_activation("gelu_fast")
365
+ gelu_pytorch_tanh = get_activation("gelu_pytorch_tanh")
366
+ quick_gelu = get_activation("quick_gelu")
367
+ silu = get_activation("silu")
368
+ mish = get_activation("mish")
369
+ linear_act = get_activation("linear")
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cache_utils.py ADDED
@@ -0,0 +1,1623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Iterable
3
+
4
+ import torch
5
+
6
+ from .configuration_utils import PreTrainedConfig
7
+ from .utils import (
8
+ is_hqq_available,
9
+ is_optimum_quanto_available,
10
+ is_quanto_greater,
11
+ is_torch_greater_or_equal,
12
+ is_torchdynamo_compiling,
13
+ logging,
14
+ )
15
+
16
+
17
+ if is_hqq_available():
18
+ from hqq.core.quantize import Quantizer as HQQQuantizer
19
+
20
+ _is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ # Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for
27
+ # that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided
28
+ # so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
29
+ # cache-layer subclass and stop needing a model-specific ``Cache`` subclass.
30
+ #
31
+ # A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via
32
+ # ``CacheLayerMixin.__init_subclass__``. Each registered class must accept a
33
+ # ``PreTrainedConfig`` (the decoder text config) as the only positional argument.
34
+ LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}
35
+
36
+
37
+ class CacheLayerMixin(ABC):
38
+ """Base, abstract class for a single layer's cache."""
39
+
40
+ is_compileable = False
41
+ # Subclasses can set ``layer_type`` to auto-register themselves in
42
+ # ``LAYER_TYPE_CACHE_MAPPING`` at import time (used by ``DynamicCache`` to dispatch
43
+ # per-layer cache classes from ``config.layer_types``).
44
+ layer_type: str | None = None
45
+
46
+ def __init_subclass__(cls, **kwargs):
47
+ super().__init_subclass__(**kwargs)
48
+ layer_type = cls.__dict__.get("layer_type", None)
49
+ if layer_type is not None:
50
+ LAYER_TYPE_CACHE_MAPPING[layer_type] = cls
51
+
52
+ def __init__(self):
53
+ self.keys: torch.Tensor | None = None
54
+ self.values: torch.Tensor | None = None
55
+ self.is_initialized = False
56
+
57
+ def __repr__(self):
58
+ return f"{self.__class__.__name__}"
59
+
60
+ @abstractmethod
61
+ def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...
62
+
63
+ @abstractmethod
64
+ def update(
65
+ self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
66
+ ) -> tuple[torch.Tensor, torch.Tensor]: ...
67
+
68
+ @abstractmethod
69
+ def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...
70
+
71
+ @abstractmethod
72
+ def get_seq_length(self) -> int: ...
73
+
74
+ @abstractmethod
75
+ def get_max_cache_shape(self) -> int: ...
76
+
77
+ def offload(self):
78
+ """Offload this layer's data to CPU device."""
79
+ if self.is_initialized:
80
+ self.keys = self.keys.to("cpu", non_blocking=True)
81
+ self.values = self.values.to("cpu", non_blocking=True)
82
+
83
+ def prefetch(self):
84
+ """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
85
+ if self.is_initialized and self.keys.device != self.device:
86
+ self.keys = self.keys.to(self.device, non_blocking=True)
87
+ self.values = self.values.to(self.device, non_blocking=True)
88
+
89
+ def reset(self) -> None:
90
+ """Resets the cache values while preserving the objects"""
91
+ if self.is_initialized:
92
+ self.keys.zero_()
93
+ self.values.zero_()
94
+ # This attribute is set on several Layers
95
+ if hasattr(self, "cumulative_length"):
96
+ # It can either be an int for dynamic layers, or a tensor for static layers
97
+ if isinstance(self.cumulative_length, int):
98
+ self.cumulative_length = 0
99
+ else:
100
+ self.cumulative_length.zero_()
101
+
102
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
103
+ """Reorders this layer's cache for beam search."""
104
+ if self.get_seq_length() > 0:
105
+ self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
106
+ self.values = self.values.index_select(0, beam_idx.to(self.values.device))
107
+
108
+
109
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    is_sliding = False

    def __init__(self, config: PreTrainedConfig | None = None):
        # `config` is accepted so that the layer can be built like the other registry-constructed layers,
        # but it is not used here
        super().__init__()

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Record the dtype/device of the first states and start from empty key/value tensors."""
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Append the new states to the cache, and return the full cached key and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        # The cache grows along the second-to-last dimension
        self.keys = torch.cat([self.keys, key_states], dim=-2)
        self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        kv_offset = 0
        kv_length = self.get_seq_length() + query_length
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states (0 when nothing has been cached yet)."""
        if not self.is_initialized or self.keys.numel() == 0:
            return 0
        return self.keys.shape[-2]

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `max_length` tokens. Asking to remove more tokens than are cached leaves an empty cache.
        """
        if max_length < 0:
            # Clamp at 0: a still-negative value would be used as a slice bound below (counting from the end),
            # which would silently keep tokens instead of removing all of them
            max_length = max(self.get_seq_length() - abs(max_length), 0)

        if self.get_seq_length() <= max_length:
            return

        self.keys = self.keys[..., :max_length, :]
        self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() > 0:
            self.keys = self.keys.repeat_interleave(repeats, dim=0)
            self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() > 0:
            self.keys = self.keys[indices, ...]
            self.values = self.values[indices, ...]
188
+
189
+
190
class DynamicSlidingWindowLayer(DynamicLayer):
    """
    A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
    """

    is_sliding = True

    def __init__(self, config: PreTrainedConfig | None = None, sliding_window: int | None = None):
        super().__init__()
        # Accept either a config (registry-style construction via LAYER_TYPE_CACHE_MAPPING)
        # or a raw ``sliding_window`` int (legacy callers).
        if sliding_window is None:
            if config is None:
                raise ValueError("Either `config` or `sliding_window` must be provided.")
            sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
            if sliding_window is None:
                # Fail here with a clear message instead of letting `torch.tensor(None)` fail below
                raise ValueError(
                    "`config` must define `sliding_window` or `attention_chunk_size` when `sliding_window` "
                    "is not provided explicitly."
                )
        self.sliding_window = sliding_window
        # Total number of tokens seen so far (can be larger than what is physically kept in the cache)
        self.cumulative_length = 0
        self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long)

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Initialize the empty cache, and move the window-size tensor to the device of the states."""
        super().lazy_initialization(key_states, value_states)
        self._sliding_window_tensor = self._sliding_window_tensor.to(self.device)

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states (previous cache content followed by
            the new states).
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        self.cumulative_length += key_states.shape[-2]

        # Compute the full states
        full_key_states = torch.cat([self.keys, key_states], dim=-2)
        full_value_states = torch.cat([self.values, value_states], dim=-2)
        # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that). An explicit
        # start index is used because a `-self.sliding_window + 1 :` slice degenerates to `0:` (keep everything)
        # when the window is 1.
        start = max(full_key_states.shape[-2] - (self.sliding_window - 1), 0)
        self.keys = full_key_states[:, :, start:, :]
        self.values = full_value_states[:, :, start:, :]

        # Return the full states
        return full_key_states, full_value_states

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        is_full = self.cumulative_length >= self.sliding_window

        kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
        if is_full:
            kv_length = self.sliding_window - 1 + query_length
        else:
            kv_length = self.cumulative_length + query_length

        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the number of tokens seen so far (not the number of tokens physically kept in the cache)."""
        return self.cumulative_length

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.sliding_window

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens.

        Raises:
            ValueError: If the layer already saw at least `sliding_window` tokens.
        """
        if self.get_seq_length() >= self.sliding_window:
            raise ValueError(
                "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its "
                "sliding window (otherwise some states are lost)"
            )
        if not self.is_initialized:
            # Nothing was cached yet (`self.keys` is still `None`), so there is nothing to crop
            return
        super().crop(max_length)
        self.cumulative_length = self.keys.shape[-2]
275
+
276
+
277
class StaticLayer(CacheLayerMixin):
    """
    A static cache layer whose key and value states live in preallocated tensors of shape
    `[batch_size, num_heads, max_cache_len, head_dim]`. The backing tensors are allocated lazily on first use and
    are then only mutated in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
    """

    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len
        # This must be a tensor (and not a python int): it is updated and used to build the positions, and a
        # tensor avoids recompiling when its value changes
        self.cumulative_length = torch.tensor([0], dtype=int)

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """
        Allocate the zero-filled key and value tensors from the properties of the first states received.

        Dtype, device, batch size, number of heads and head dims are all read from `key_states`/`value_states` at
        runtime, so nothing has to be moved or cast during later `update` calls (which could also break the
        static dynamo addresses).

        One can instead call `early_initialization(...)` on the Cache directly to run this function ahead of time
        (required for `torch.export` for example). For `compile`, the prefill is not compiled internally, so this
        function is guaranteed to have run before compilation. Compiling the prefill as well (e.g. calling
        `model.compile(...)` before `generate` with a static cache) is still supported in general, but without
        guarantees depending on the compilation options (e.g. cuda graphs, i.e. `mode="reduce-overhead"`, is known
        to fail), and prefill should not be compiled anyway for performances.
        """
        self.dtype, self.device = key_states.dtype, key_states.device
        self.max_batch_size, self.num_heads = key_states.shape[:2]
        self.v_head_dim = value_states.shape[-1]
        self.k_head_dim = key_states.shape[-1]

        leading_dims = (self.max_batch_size, self.num_heads, self.max_cache_len)
        self.keys = torch.zeros((*leading_dims, self.k_head_dim), dtype=self.dtype, device=self.device)
        self.values = torch.zeros((*leading_dims, self.v_head_dim), dtype=self.dtype, device=self.device)
        self.cumulative_length = self.cumulative_length.to(self.device)
        # `mark_static_address` tags the tensors as having a fixed data pointer, which prevents compiled graph
        # breaks or cudagraph skips caused by the in-place cache mutations (cudagraphs cannot be used without it).
        # It is not supported while tracing the graph, hence the check. As prefill should never be compiled, the
        # marking still happens in practice (except when users explicitly compile prefill, which should be avoided).
        if not is_torchdynamo_compiling():
            for static_tensor in (self.keys, self.values, self.cumulative_length):
                torch._dynamo.mark_static_address(static_tensor)

        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Write the new states into the preallocated cache in-place, and return the whole cache tensors.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        # Positions at which the new tokens are written; they must be computed BEFORE the length is incremented
        num_new_tokens = key_states.shape[-2]
        write_positions = torch.arange(num_new_tokens, device=self.device) + self.cumulative_length
        # In-place increment, as the static address of the tensor has to be kept
        self.cumulative_length.add_(num_new_tokens)

        try:
            self.keys.index_copy_(2, write_positions, key_states)
            self.values.index_copy_(2, write_positions, value_states)
        except NotImplementedError:
            # Fallback for devices like MPS where index_copy_ might not be supported.
            self.keys[:, :, write_positions] = key_states
            self.values[:, :, write_positions] = value_states

        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        # The reported length is always the full static size, with no offset
        return self.max_cache_len, 0

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states (0 before initialization)."""
        if not self.is_initialized:
            return 0
        return self.cumulative_length

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.max_cache_len
385
+
386
+
387
class StaticSlidingWindowLayer(StaticLayer):
    """
    A static cache layer that stores the key and value states as static tensors of shape
    `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
    tensors, and then mutates them in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
        sliding_window (`int`):
            The size of the sliding window.
    """

    is_sliding = True

    def __init__(self, max_cache_len: int, sliding_window: int):
        # The static tensors never need to be larger than the window, so the parent's `max_cache_len` is capped.
        # From here on, `self.max_cache_len` is this effective (capped) size.
        effective_max_cache_len = min(sliding_window, max_cache_len)
        super().__init__(max_cache_len=effective_max_cache_len)
        # Here, to avoid data-dependent control flows, we also need to use a python int to keep track of the cumulative length
        self.cumulative_length_int = 0

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Three situations are handled, based on the number of tokens seen before this call:
        the cache is already full, the cache becomes full with this update, or the new states still fit.
        In the last case (and when already full with a single new token) the static `self.keys`/`self.values`
        are returned; otherwise the concatenated full states are returned.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        kv_length = key_states.shape[-2]
        # Length BEFORE this update: all the branching below is based on it
        current_length = self.cumulative_length_int
        is_full = current_length >= self.max_cache_len
        # Update it now that we saved the value above
        self.cumulative_length_int += kv_length

        if is_full:
            # In general, we should use a much simpler `cat` here as well, independently of the states size. However,
            # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
            if key_states.shape[-2] == 1:
                # Roll all values to the left by 1 position
                new_keys = self.keys.roll(-1, dims=-2)
                new_values = self.values.roll(-1, dims=-2)
                # Overwrite the last position with new states
                # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
                index = torch.tensor([-1], dtype=int, device=self.device)
                new_keys[:, :, index] = key_states
                new_values[:, :, index] = value_states

                # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
                self.keys.copy_(new_keys)
                self.values.copy_(new_values)

                # Very important to return the `self` tensors here, as they have the static dynamo address
                return self.keys, self.values
            # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
            else:
                # The oldest cached position is dropped before appending the new states
                full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
                full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
        # Not yet full, but becoming full on this update
        elif current_length + kv_length > self.max_cache_len:
            # Fast prefill path, no need to cat() in this case, as the cache is currently empty
            if current_length == 0:
                full_key_states = key_states
                full_value_states = value_states
            else:
                # Only the first `current_length` positions of the static tensors hold real states
                full_key_states = torch.cat((self.keys[:, :, :current_length, :], key_states), dim=-2)
                full_value_states = torch.cat((self.values[:, :, :current_length, :], value_states), dim=-2)
        else:
            # Note: very important to use the tensor version of the cumulative length here, as otherwise cudagraphs
            # (triggered by mode="reduced_overhead") will lead to random crashes, as the int would be overwritten
            cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
            try:
                self.keys.index_copy_(2, cache_position, key_states)
                self.values.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # Same update through plain indexing when `index_copy_` is not implemented
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states

            # Update the tensor version of the length in-place (we don't need to update it if we are already outside
            # of this branch, as we don't need the tensor anymore)
            self.cumulative_length.add_(kv_length)

            # Very important to return the `self` tensors here, as they have the static dynamo address
            return self.keys, self.values

        # We only cache the last `sliding_window` tokens
        self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
        self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
        # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
        return full_key_states, full_value_states

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        # `self.max_cache_len` is the effective window size (capped in `__init__`)
        sliding_window = self.max_cache_len
        is_full = self.cumulative_length_int >= self.max_cache_len

        kv_offset = max(self.cumulative_length_int - sliding_window + 1, 0)
        # The cache is already full
        if is_full:
            kv_length = sliding_window + query_length - 1
        # Not yet full, but becoming full on this update
        elif self.cumulative_length_int + query_length > sliding_window:
            kv_length = self.cumulative_length_int + query_length
        # Here the Cache is still smaller than the local size, but we return the local size as it's static
        else:
            kv_length = sliding_window

        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        # Python-int counter of all tokens passed to `update` (it is not capped at the window size)
        return self.cumulative_length_int

    def reset(self):
        """Reset the layer through the parent `reset`, and also clear the python-int length counter."""
        super().reset()
        self.cumulative_length_int = 0
512
+
513
+
514
class QuantizedLayer(DynamicLayer):
    """
    A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
    Quantizing the cached keys and values lets the model generate longer sequences without allocating too much
    memory for the cache.

    Two storages are used: an original-precision one (`self.keys`/`self.values`) and a quantized one. The
    `residual_length` bounds the original-precision storage: once it would be reached, everything is re-quantized
    into the quantized storage and the original-precision storage is emptied. The quantization is done per-channel
    with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.

    Subclasses provide the backend through `_quantize` and `_dequantize`.
    """

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__()
        self.cumulative_length = 0
        self.residual_length = residual_length
        self.q_group_size = q_group_size
        self.axis_value = axis_value
        self.axis_key = axis_key
        self.nbits = nbits

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        self.cumulative_length += key_states.shape[-2]

        # First call: lazily initialize, quantize the incoming states directly, and hand them back unchanged
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
            self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
            return key_states, value_states

        # Returned states = dequantized storage + original-precision residual + new states
        restored_keys = self._dequantize(self._quantized_keys)
        restored_values = self._dequantize(self._quantized_values)
        keys_to_return = torch.cat([restored_keys, self.keys, key_states], dim=-2)
        values_to_return = torch.cat([restored_values, self.values, value_states], dim=-2)

        residual_is_full = self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length
        if not residual_is_full:
            # Still room in the residual: simply append in original precision
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        else:
            # Re-quantize everything and empty the original-precision residual
            self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
            self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)

        return keys_to_return, values_to_return

    @abstractmethod
    def _quantize(self, tensor, axis): ...

    @abstractmethod
    def _dequantize(self, q_tensor): ...

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states (total number of tokens passed to `update`)."""
        return self.cumulative_length
588
+
589
+
590
class QuantoQuantizedLayer(QuantizedLayer):
    """Quantized cache layer using `optimum-quanto` as quantization backend."""

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        """
        Raises:
            ImportError: If `optimum-quanto` is missing, or fails the `is_quanto_greater("0.2.5")` version check.
            ValueError: If `nbits` is not 2 or 4, or if `axis_key`/`axis_value` is not 0 or -1.
        """
        super().__init__(
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )

        # Availability and version are checked first, then quanto is imported locally (a module-level import
        # would be circular because of optimum/quanto/models/transformers_models.py)
        if not is_optimum_quanto_available():
            raise ImportError(
                "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
                "backend. Please install it via with `pip install optimum-quanto`"
            )
        if not is_quanto_greater("0.2.5", accept_dev=True):
            raise ImportError(
                "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedLayer`. "
            )
        from optimum.quanto import MaxOptimizer, qint2, qint4

        supported_nbits = (2, 4)
        if self.nbits not in supported_nbits:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

        supported_axes = (0, -1)
        if self.axis_key not in supported_axes:
            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

        if self.axis_value not in supported_axes:
            raise ValueError(
                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
            )

        if self.nbits == 4:
            self.qtype = qint4
        else:
            self.qtype = qint2
        # Hardcoded, as it's the only optimizer for per-channel quantization
        self.optimizer = MaxOptimizer()

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with the scale and zero-point computed by `self.optimizer`."""
        from optimum.quanto import quantize_weight

        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        return quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)

    def _dequantize(self, qtensor):
        """Return the dequantized version of `qtensor`."""
        return qtensor.dequantize()
643
+
644
+
645
class HQQQuantizedLayer(QuantizedLayer):
    """Quantized cache layer using `HQQ` as quantization backend."""

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        """
        Raises:
            ImportError: If `hqq` is not installed.
            ValueError: If `nbits` is not one of 1, 2, 3, 4, 8, or if `axis_key`/`axis_value` is not 0 or 1.
        """
        super().__init__(
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )

        if not is_hqq_available():
            raise ImportError(
                "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                "Please install it via with `pip install hqq`"
            )

        supported_nbits = (1, 2, 3, 4, 8)
        if self.nbits not in supported_nbits:
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        supported_axes = (0, 1)
        if self.axis_key not in supported_axes:
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

        if self.axis_value not in supported_axes:
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis`, returning the quantized tensor together with its metadata dict."""
        # Device and compute dtype are taken from the original-precision cache tensor
        target_device, compute_dtype = self.keys.device, self.keys.dtype
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=axis,
            device=target_device,
            compute_dtype=compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = compute_dtype
        # Move to device and cast to dtype
        self.quantizer.cuda(qtensor, meta=meta, device=target_device)
        for meta_key in ("scale", "zero"):
            meta[meta_key] = meta[meta_key].to(qtensor.device)
        return qtensor, meta

    def _dequantize(self, qtensor):
        """Dequantize a `(quantized tensor, metadata)` pair as produced by `_quantize`."""
        quant_tensor, meta = qtensor
        return self.quantizer.dequantize(quant_tensor, meta)
700
+
701
+
702
+ class LinearAttentionCacheLayerMixin(ABC):
703
+ """Base, abstract class for a linear attention single layer's cache."""
704
+
705
+ # All shapes are static by essence in a LinearAttention layer, so it is compileable
706
+ is_compileable = True
707
+
708
+ def __init__(self):
709
+ self.conv_states: torch.Tensor | None = None
710
+ self.recurrent_states: torch.Tensor | None = None
711
+ self.is_conv_states_initialized = False
712
+ self.is_recurrent_states_initialized = False
713
+ self.has_previous_state = False
714
+
715
+ def __repr__(self):
716
+ return f"{self.__class__.__name__}"
717
+
718
+ @abstractmethod
719
+ def lazy_initialization(
720
+ self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
721
+ ) -> None: ...
722
+
723
+ @abstractmethod
724
+ def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ...
725
+
726
+ @abstractmethod
727
+ def update_recurrent_state(self, recurrent_states: torch.Tensor) -> torch.Tensor: ...
728
+
729
+ def offload(self):
730
+ """Offload this layer's data to CPU device."""
731
+ if self.is_conv_states_initialized:
732
+ self.conv_states = self.conv_states.to("cpu", non_blocking=True)
733
+ if self.is_recurrent_states_initialized:
734
+ self.recurrent_states = self.recurrent_states.to("cpu", non_blocking=True)
735
+
736
+ def prefetch(self):
737
+ """In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
738
+ if self.is_conv_states_initialized and self.conv_states.device != self.device:
739
+ self.conv_states = self.conv_states.to(self.device, non_blocking=True)
740
+ if self.is_recurrent_states_initialized and self.recurrent_states.device != self.device:
741
+ self.recurrent_states = self.recurrent_states.to(self.device, non_blocking=True)
742
+
743
+ def reset(self) -> None:
744
+ """Resets the cache values while preserving the objects"""
745
+ if self.is_conv_states_initialized:
746
+ self.conv_states.zero_()
747
+ if self.is_recurrent_states_initialized:
748
+ self.recurrent_states.zero_()
749
+ self.has_previous_state = False
750
+
751
+ def reorder_cache(self, beam_idx: torch.LongTensor):
752
+ """Reorders the cache for beam search, given the selected beam indices."""
753
+ if self.is_conv_states_initialized:
754
+ self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device))
755
+ # recurrent_states can stay empty sometimes, see e.g. lfm2 which only uses the conv_states
756
+ if self.is_recurrent_states_initialized:
757
+ self.recurrent_states = self.recurrent_states.index_select(0, beam_idx.to(self.device))
758
+
759
+ def crop(self, max_length: int):
760
+ # We don't crop the linear attention cache, so simply do nothing here
761
+ pass
762
+
763
+
764
+ class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
765
+ def __init__(self, config: PreTrainedConfig | None = None):
766
+ super().__init__()
767
+
768
+ def lazy_initialization(
769
+ self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
770
+ ) -> None:
771
+ # Here, we will lazy init both states separately, each in their own update function
772
+ if conv_states is not None:
773
+ self.dtype, self.device = conv_states.dtype, conv_states.device
774
+ # Even if prefill is larfer/shorter than the conv_size, the tensor is always either padded or truncated
775
+ self.max_batch_size, self.conv_kernel_size = conv_states.shape[0], conv_states.shape[-1]
776
+ # The shape is always static, so we init as such
777
+ self.conv_states = torch.zeros_like(conv_states, dtype=self.dtype, device=self.device)
778
+ # Mark as static address to be able to use cudagraphs
779
+ if not is_torchdynamo_compiling():
780
+ torch._dynamo.mark_static_address(self.conv_states)
781
+ self.is_conv_states_initialized = True
782
+ if recurrent_states is not None:
783
+ # The shape is always static, so we init as such
784
+ self.recurrent_states = torch.zeros_like(recurrent_states, dtype=self.dtype, device=self.device)
785
+ # Mark as static address to be able to use cudagraphs
786
+ if not is_torchdynamo_compiling():
787
+ torch._dynamo.mark_static_address(self.recurrent_states)
788
+ self.is_recurrent_states_initialized = True
789
+
790
def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Write the new conv states into the cached buffer in-place and return that buffer.

    Args:
        conv_states (`torch.Tensor`): The new conv states to cache.

    Returns:
        `torch.Tensor`: The cached conv states after the update (always the same tensor object).
    """
    if not self.is_conv_states_initialized:
        self.lazy_initialization(conv_states=conv_states)

    # First call (prefill): the incoming states simply become the cache content.
    # Technically, this is not logically correct if the prefill is smaller than `conv_kernel_size`, because the
    # first decoding step will `roll` anyway, even though it should `roll` ONLY if the cache is already full.
    # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill, so it's
    # mostly fine for now.
    if not self.has_previous_state:
        # `copy_` (not assignment) keeps the static address needed for cudagraphs
        self.conv_states.copy_(conv_states)
        self.has_previous_state = True
        return self.conv_states

    incoming = conv_states.shape[-1]
    if incoming >= self.conv_kernel_size:
        # Enough new tokens to fill the whole window: keep only the most recent ones
        refreshed = conv_states[..., -self.conv_kernel_size :]
    else:
        # Shift the window left by `incoming` positions and write the new tokens at the end
        refreshed = self.conv_states.roll(shifts=-incoming, dims=-1)
        refreshed[:, :, -incoming:] = conv_states
    # `copy_` (not assignment) keeps the static address needed for cudagraphs
    self.conv_states.copy_(refreshed)

    return self.conv_states
822
+
823
def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Overwrite the cached recurrent states in-place and return the cached buffer.

    Args:
        recurrent_states (`torch.Tensor`): The new recurrent states to cache.

    Returns:
        `torch.Tensor`: The cached recurrent states after the update (always the same tensor object).
    """
    needs_allocation = not self.is_recurrent_states_initialized
    if needs_allocation:
        self.lazy_initialization(recurrent_states=recurrent_states)
    # `copy_` (not assignment) keeps the static address needed for cudagraphs
    cached = self.recurrent_states
    cached.copy_(recurrent_states)
    return self.recurrent_states
838
+
839
+
840
class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
    """
    Cache layer combining a `LinearAttentionLayer` and a `DynamicLayer`: each operation is forwarded to the
    relevant parent (or to both parents for `reset` and `reorder_cache`).
    """

    # The dynamic Attention part makes it non-compileable
    is_compileable = False

    def __init__(self, config: PreTrainedConfig | None = None):
        # Each parent initializer is invoked explicitly; `config` is not forwarded to either of them
        DynamicLayer.__init__(self)
        LinearAttentionLayer.__init__(self)

    def lazy_initialization(self, *args, **kwargs) -> None:
        """Dispatch to the parent initializer matching the call convention; any other convention is a no-op."""
        positional, named = len(args), len(kwargs)
        if positional == 2 and named == 0:
            # Attention cache: `update` calls `lazy_initialization` with 2 positional args
            DynamicLayer.lazy_initialization(self, *args)
        elif positional == 0 and named == 1:
            # LinearAttention cache: `update_conv_state` / `update_recurrent_state` always pass 1 single kwarg
            # (as it needs to know if it's for the conv or the recurrent states)
            LinearAttentionLayer.lazy_initialization(self, **kwargs)

    def reset(self) -> None:
        """Reset the linear-attention part first, then the attention part."""
        LinearAttentionLayer.reset(self)
        DynamicLayer.reset(self)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        LinearAttentionLayer.reorder_cache(self, beam_idx)
        DynamicLayer.reorder_cache(self, beam_idx)
865
+
866
+
867
# Register the standard layer types up-front. Some classes serve several layer types (for instance
# ``DynamicSlidingWindowLayer`` backs both ``"sliding_attention"`` and ``"chunked_attention"``), so they need
# explicit entries here instead of relying on the auto-registration done by
# ``CacheLayerMixin.__init_subclass__``.
LAYER_TYPE_CACHE_MAPPING.update(
    {
        "full_attention": DynamicLayer,
        # Cache-wise, sliding and chunked attention behave the same; only the mask is different.
        "sliding_attention": DynamicSlidingWindowLayer,
        "chunked_attention": DynamicSlidingWindowLayer,
        # Linear-attention-shaped layers (mamba / conv / pure linear-attention / moe placeholders) don't grow
        # per-token KV; they're tracked just so position bookkeeping stays consistent.
        "mamba": LinearAttentionLayer,
        "conv": LinearAttentionLayer,
        "linear_attention": LinearAttentionLayer,
        "moe": LinearAttentionLayer,
        # Hybrid layers (e.g. zamba / zamba2) hold a linear-attention state AND a dynamic-attention state.
        "hybrid": LinearAttentionAndFullAttentionLayer,
    }
)
888
+
889
+
890
class Cache:
    """
    A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
    the Cache of each layer.

    Args:
        layers (`list[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*):
            A list of pre-created `CacheLayerMixin` or `LinearAttentionCacheLayerMixin`. If omitted (`None`), then
            `layer_class_to_replicate` will be used.
        layer_class_to_replicate (`type[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*):
            Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
            and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
            list of layers.
        offloading (`bool`, *optional*, defaults to `False`):
            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
        offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).

    Raises:
        ValueError: If both or neither of `layers` and `layer_class_to_replicate` are provided.
    """

    def __init__(
        self,
        layers: list[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
        layer_class_to_replicate: type[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
        offloading: bool = False,
        offload_only_non_sliding: bool = True,
    ):
        if layers is not None and layer_class_to_replicate is not None:
            raise ValueError(
                "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
                "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
                "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
            )
        if layers is None and layer_class_to_replicate is None:
            raise ValueError(
                "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
            )
        self.layers = layers if layers is not None else []
        self.layer_class_to_replicate = layer_class_to_replicate
        self.offloading = offloading
        if self.offloading:
            # These two attributes only exist when offloading is enabled
            self.only_non_sliding = offload_only_non_sliding
            self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()

    def __repr__(self):
        return f"{self.__class__.__name__}(layers={self.layers})"

    def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
        """
        Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
        which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
        Note that we use a non-default stream for this, to avoid blocking.
        """
        if only_non_sliding:
            # Try to find next non-sliding, starting at `layer_idx`
            try:
                layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
            # In this case, we need to circle back to the beginning
            except ValueError:
                layer_idx = self.is_sliding.index(False)
        else:
            layer_idx = layer_idx if layer_idx < len(self.layers) else 0

        # Prefetch
        with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
            self.layers[layer_idx].prefetch()

    def offload(self, layer_idx: int, only_non_sliding: bool = True):
        """
        Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
        non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
        computation in the layer's `update` methods are finished.
        """
        if not (only_non_sliding and self.is_sliding[layer_idx]):
            self.layers[layer_idx].offload()

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.

        Return:
            A tuple containing the updated key and value states.
        """
        # In this case, the `layers` were not provided, and we must append as much as `layer_idx`
        if self.layer_class_to_replicate is not None:
            while len(self.layers) <= layer_idx:
                self.layers.append(self.layer_class_to_replicate())

        if self.offloading:
            # Wait for the stream to finish if needed, and start prefetching the next layer
            torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
            self.prefetch(layer_idx + 1, self.only_non_sliding)

        keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)

        if self.offloading:
            self.offload(layer_idx, self.only_non_sliding)

        return keys, values

    def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
        """
        Updates the cache with the new `conv_states` for the layer `layer_idx`.

        Parameters:
            conv_states (`torch.Tensor`):
                The new conv states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.

        Return:
            `torch.Tensor`: The updated conv states.

        Raises:
            ValueError: If the layer at `layer_idx` is not a `LinearAttentionCacheLayerMixin`.
        """
        # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
        # out of the box
        if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
            raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
        conv_states = self.layers[layer_idx].update_conv_state(conv_states, **kwargs)
        return conv_states

    def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
        """
        Updates the cache with the new `recurrent_states` for the layer `layer_idx`.

        Parameters:
            recurrent_states (`torch.Tensor`):
                The new recurrent states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.

        Return:
            `torch.Tensor`: The updated recurrent states.

        Raises:
            ValueError: If the layer at `layer_idx` is not a `LinearAttentionCacheLayerMixin`.
        """
        # NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
        # out of the box
        if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
            raise ValueError("Cannot call `update_recurrent_state` on a non-LinearAttention layer!")
        recurrent_states = self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs)
        return recurrent_states

    def early_initialization(
        self,
        batch_size: int,
        num_heads: int | list[int],
        head_dim: int | list[int],
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
        This is useful for our `export` recipes, as `export` needs everything in advance.

        Parameters:
            batch_size (`int`):
                Batch size used for every layer.
            num_heads (`int` or `list[int]`):
                Number of heads, either shared by all layers or given per layer.
            head_dim (`int` or `list[int]`):
                Head dimension, either shared by all layers or given per layer.
            dtype (`torch.dtype`):
                The dtype to initialize the layers with.
            device (`torch.device`):
                The device to initialize the layers on.

        Raises:
            ValueError: If `num_heads` or `head_dim` is a list whose length differs from the number of layers.
        """
        # To allow different num_heads and head_dim depending on layers, we accept lists
        if isinstance(num_heads, int):
            num_heads = [num_heads] * len(self)
        if isinstance(head_dim, int):
            head_dim = [head_dim] * len(self)

        if len(num_heads) != len(self.layers):
            raise ValueError(
                f"`num_heads` was provided as a list of length {len(num_heads)}, but the Cache currently has {len(self.layers)} layers"
            )
        if len(head_dim) != len(self.layers):
            raise ValueError(
                f"`head_dim` was provided as a list of length {len(head_dim)}, but the Cache currently has {len(self.layers)} layers"
            )

        for layer, layer_num_heads, layer_head_dim in zip(self.layers, num_heads, head_dim):
            # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
            # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
            # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
            fake_kv_tensor = torch.zeros((batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device)
            # Init the layer
            layer.lazy_initialization(fake_kv_tensor, fake_kv_tensor)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """
        Returns the sequence length of the cache for the given layer (0 if the layer does not exist yet).

        Raises:
            ValueError: If a non-default `layer_idx` points to a LinearAttention layer, or if the Cache contains
                no Attention layer at all.
        """
        if layer_idx >= len(self.layers):
            return 0

        # For alternating attention/linear attention caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx
        if not isinstance(self.layers[layer_idx], CacheLayerMixin):
            # If this is called with non-default arg, raise
            if layer_idx != 0:
                raise ValueError(
                    f"You called `get_seq_length` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
                    "does not track sequence length."
                )
            try:
                # Use the first attention layer
                layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin))
            except StopIteration:
                # `from None`: the `StopIteration` is an implementation detail, not useful in the traceback
                raise ValueError(
                    "`get_seq_length` can only be called on Attention layers, and the current Cache seem to only contain "
                    "LinearAttention layers."
                ) from None

        return self.layers[layer_idx].get_seq_length()

    def has_previous_state(self, layer_idx: int | None = None) -> bool:
        """
        Returns whether the LinearAttention layer at index `layer_idx` has previous state or not. If `layer_idx`
        is `None`, the last LinearAttention layer of the Cache is used.

        Raises:
            ValueError: If `layer_idx` points to an Attention layer, or if `layer_idx` is `None` and the Cache
                contains no LinearAttention layer.
        """
        if layer_idx is not None and layer_idx >= len(self.layers):
            return False

        # In this case, use last LinearAttention layer
        if layer_idx is None:
            try:
                layer_idx = next(
                    idx
                    for idx in range(len(self) - 1, -1, -1)
                    if isinstance(self.layers[idx], LinearAttentionCacheLayerMixin)
                )
            except StopIteration:
                # `from None`: the `StopIteration` is an implementation detail, not useful in the traceback
                raise ValueError(
                    "`has_previous_state` can only be called on LinearAttention layers, and the current Cache seem to "
                    "only contain Attention layers."
                ) from None
        elif not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
            raise ValueError(
                f"You called `has_previous_state` on layer index {layer_idx}, but this layer is an Attention layer, which "
                "does not support calling it."
            )

        return self.layers[layer_idx].has_previous_state

    def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
        """
        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
        the given layer at `layer_idx`.
        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.

        Raises:
            ValueError: If a non-zero `layer_idx` points to a LinearAttention layer, or if the Cache contains
                no Attention layer at all.
        """
        # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
        # simply the query_length
        if layer_idx >= len(self.layers):
            return query_length, 0

        # For alternating attention/linear attention caches, `get_mask_sizes` needs to use attention layer idx when called with default layer_idx
        if not isinstance(self.layers[layer_idx], CacheLayerMixin):
            # If this is called with non-default arg, raise
            if layer_idx != 0:
                raise ValueError(
                    f"You called `get_mask_sizes` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
                    "does not track sequence length."
                )
            try:
                # Use the first attention layer
                layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin))
            except StopIteration:
                # `from None`: the `StopIteration` is an implementation detail, not useful in the traceback
                raise ValueError(
                    "`get_mask_sizes` can only be called on Attention layers, and the current Cache seem to only contain "
                    "LinearAttention layers."
                ) from None

        return self.layers[layer_idx].get_mask_sizes(query_length)

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
        # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
        # as DynamicLayer does
        if layer_idx >= len(self.layers):
            return -1
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self):
        """Recursively reset all layers tensors"""
        for layer in self.layers:
            layer.reset()

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache for beam search"""
        for layer in self.layers:
            layer.reorder_cache(beam_idx)

    def crop(self, max_length: int):
        """Crop the cache to the given length"""
        for layer in self.layers:
            layer.crop(max_length)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat and interleave the cache"""
        for layer in self.layers:
            layer.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Select indices from the cache"""
        for layer in self.layers:
            layer.batch_select_indices(indices)

    @property
    def max_batch_size(self) -> int:
        """Return the maximum batch size of the cache"""
        values = [layer.max_batch_size for layer in self.layers]
        if len(set(values)) > 1:
            raise ValueError(f"Max batch size is not consistent across layers: {values}")
        return values[0]

    @property
    def max_cache_len(self) -> int:
        """Return the maximum cache length of the cache"""
        values = [layer.max_cache_len for layer in self.layers]
        return max(values)

    @property
    def is_compileable(self) -> bool:
        """Return whether the cache is compilable"""
        # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
        if len(self.layers) == 0:
            return False
        return all(layer.is_compileable for layer in self.layers)

    @property
    def is_initialized(self) -> bool:
        """Return whether the cache data is initialized"""
        return len(self.layers) > 0 and all(layer.is_initialized for layer in self.layers)

    @property
    def is_sliding(self) -> list[bool]:
        """Return whether the layers of the cache are sliding window"""
        return [getattr(layer, "is_sliding", False) for layer in self.layers]

    def __len__(self):
        """
        This value corresponds to the number of layers in the model.
        """
        # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
        # forward through all the layers
        return len(self.layers)
1227
+
1228
+
1229
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
    in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
    If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
    memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        ddp_cache_data (`Iterable[tuple[torch.Tensor | None, ...]]`, *optional*):
            It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
            `map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
            for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
            Each item may carry a third element, an optional sliding window tensor (this is also what iterating
            over a `DynamicCache` yields).
            Note: it needs to be the 1st arg as well to work correctly
        config (`PreTrainedConfig`, *optional*):
            The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
            or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
            `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
        offloading (`bool`, *optional*, defaults to `False`):
            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
        offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = DynamicCache(config=model.config)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    ```
    """

    def __init__(
        self,
        ddp_cache_data: Iterable[tuple[torch.Tensor | None, ...]] | None = None,
        config: PreTrainedConfig | None = None,
        offloading: bool = False,
        offload_only_non_sliding: bool = False,
    ):
        layers = []
        # If a config is passed, use it to infer the layer types and initialize accordingly
        if config is not None:
            decoder_config = config.get_text_config(decoder=True)
            sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
                decoder_config, "attention_chunk_size", None
            )
            layer_types = getattr(decoder_config, "layer_types", None)
            if layer_types is None:
                # No explicit layer types: every layer is sliding if a window/chunk size is set, full otherwise
                default_type = "sliding_attention" if sliding_window is not None else "full_attention"
                layer_types = [default_type for _ in range(decoder_config.num_hidden_layers)]
            # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n).
            # Only slice for a non-zero count: `layer_types[:-0]` would be an empty list and drop every layer.
            num_kv_shared_layers = getattr(decoder_config, "num_kv_shared_layers", None)
            if num_kv_shared_layers:
                layer_types = layer_types[:-num_kv_shared_layers]

            for layer_type in layer_types:
                # Unknown layer types fall back to a plain `DynamicLayer`
                cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
                layers.append(cache_cls(decoder_config))

        # In this case, use the passed data to already fill in the Cache
        if ddp_cache_data is not None:
            # Init all the layers with the data
            for layer_idx, kv_and_optional_sliding in enumerate(ddp_cache_data):
                # If the config was not passed above, initialize a new cache layer for each entry of the ddp_data
                if config is None:
                    # kv_and_optional_sliding contains at least two elements: the key and value states. It can also
                    # contain a third element, which is an optional sliding window tensor.
                    sliding_window_tensor = kv_and_optional_sliding[2] if len(kv_and_optional_sliding) == 3 else None
                    # If there is a sliding window tensor, use it to initialize the layer
                    if sliding_window_tensor is not None:
                        # Since the same layer is dispatched across replicas, sliding_window is the same for all
                        sliding_window = sliding_window_tensor[0].item()
                        layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
                    else:
                        layers.append(DynamicLayer())
                # Update the layer with the data
                _, _ = layers[layer_idx].update(kv_and_optional_sliding[0], kv_and_optional_sliding[1])

        # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
        if len(layers) == 0:
            super().__init__(
                layer_class_to_replicate=DynamicLayer,
                offloading=offloading,
                offload_only_non_sliding=offload_only_non_sliding,
            )
        else:
            super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)

    def __iter__(self):
        """Yield `(keys, values, sliding_window_tensor_or_None)` for each layer, in layer order."""
        for layer in self.layers:
            yield layer.keys, layer.values, getattr(layer, "_sliding_window_tensor", None)
1334
+
1335
+
1336
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
    for potential hybrid cache structure, and initialize each layer accordingly.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        config (`PreTrainedConfig`):
            The config of the model for which this Cache will be used. It will be used to check for sliding
            or hybrid layer structure, and initialize each layer accordingly.
        max_cache_len (`int`):
            The maximum number of tokens that this Cache should hold.
        offloading (`bool`, *optional*, defaults to `False`):
            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
        offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
    >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    StaticCache()
    ```
    """

    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
    def __init__(
        self,
        config: PreTrainedConfig,
        max_cache_len: int,
        offloading: bool = False,
        offload_only_non_sliding: bool = True,
        **kwargs,
    ):
        config = config.get_text_config(decoder=True)
        layer_types = getattr(config, "layer_types", None)
        # If `layer_types` is not explicitly provided, infer if the model is fully sliding
        if layer_types is None:
            if getattr(config, "sliding_window", None) is not None:
                layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
            elif getattr(config, "attention_chunk_size", None) is not None:
                layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
            else:
                layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
        # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n).
        # Only slice for a non-zero count: `layer_types[:-0]` would be an empty list and drop every layer.
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", None)
        if num_kv_shared_layers:
            layer_types = layer_types[:-num_kv_shared_layers]

        # Every registered layer type backed by a sliding-window class, except "chunked_attention" which reads
        # its window size from `config.attention_chunk_size` and is handled on its own below
        sliding_layer_types = {
            name
            for name, cls in LAYER_TYPE_CACHE_MAPPING.items()
            if isinstance(cls, type) and issubclass(cls, DynamicSlidingWindowLayer) and name != "chunked_attention"
        }
        layers = []
        for layer_type in layer_types:
            if layer_type == "chunked_attention":
                # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                # states they should return - only the mask changes to make them different at the end!
                layer = StaticSlidingWindowLayer(
                    max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
                )
            elif layer_type in sliding_layer_types:
                layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
            # LinearAttention layers are static by essence - `"moe"` gets the same placeholder layer class it is
            # registered with in `LAYER_TYPE_CACHE_MAPPING`
            elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
                layer = LinearAttentionLayer()
            else:
                # Anything else is treated as a regular full-attention static layer
                layer = StaticLayer(max_cache_len=max_cache_len)
            layers.append(layer)

        super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
1421
+
1422
+
1423
class QuantizedCache(Cache):
    """
    A quantizer cache similar to what is described in the
    [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for keys and values
    by applying quantization.
    The cache has two types of storage, one for original precision and one for the
    quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
    length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
    The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
    described in the paper.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        backend (`str`):
            The quantization backend to use. One of `("quanto", "hqq")`.
        config (`PreTrainedConfig`):
            The config of the model for which this Cache will be used.
        nbits (`int`, *optional*, defaults to 4):
            The number of bits for quantization.
        axis_key (`int`, *optional*, defaults to 0):
            The axis on which to quantize the keys.
        axis_value (`int`, *optional*, defaults to 0):
            The axis on which to quantize the values.
        q_group_size (`int`, *optional*, defaults to 64):
            Quantization is done per-channel according to a set `q_group_size` for both keys and values.
        residual_length (`int`, *optional*, defaults to 128):
            Maximum capacity for the original precision cache
    """

    def __init__(
        self,
        backend: str,
        config: PreTrainedConfig,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        # Reject unsupported backends before touching the config
        if backend not in ("quanto", "hqq"):
            raise ValueError(f"Unknown quantization backend `{backend}`")
        layer_class = QuantoQuantizedLayer if backend == "quanto" else HQQQuantizedLayer

        # One quantized layer per hidden layer of the text decoder, all sharing the same quantization settings
        text_config = config.get_text_config(decoder=True)
        layers = []
        for _ in range(text_config.num_hidden_layers):
            layers.append(layer_class(nbits, axis_key, axis_value, q_group_size, residual_length))
        super().__init__(layers=layers)
1477
+
1478
+
1479
class EncoderDecoderCache(Cache):
    """
    Cache for encoder-decoder models: it pairs a self-attention cache with a cross-attention cache and forwards
    the common cache operations to them.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        caches (`Iterable`):
            Usually two `Cache` objects, the first one for self-attention and the second one for cross-attention.
            Can optionally be a single iterable of per-layer tuples (`tuple[tuple[torch.Tensor]]`), usually used
            for compatibility with torch dp and ddp. Each per-layer tuple must hold 4 or 6 entries: the first
            half goes to the self-attention cache and the second half to the cross-attention cache.

    Example:

    ```python
    >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

    >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
    >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

    >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

    >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
    >>> self_attention_cache = DynamicCache(config=model.config)
    >>> cross_attention_cache = DynamicCache(config=model.config)
    >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    EncoderDecoderCache()
    ```
    """

    def __init__(self, *caches) -> None:
        num_arguments = len(caches)
        if num_arguments == 2:
            # Regular path: a self-attention cache followed by a cross-attention cache.
            if not (isinstance(caches[0], Cache) and isinstance(caches[1], Cache)):
                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
            self.self_attention_cache, self.cross_attention_cache = caches
        elif num_arguments == 1:
            # dp/ddp path: a single iterable of per-layer tuples that has to be split in two halves.
            self_attention_cache_data = []
            cross_attention_cache_data = []
            for combined_cache_data in caches[0]:
                # 6 entries: (k, v, sliding) for self-attention followed by the same triple for cross-attention.
                # 4 entries: old DDP-style init without the sliding window tensor, i.e. two (k, v) pairs.
                if len(combined_cache_data) not in (4, 6):
                    raise ValueError(f"Expected {len(combined_cache_data) = } to be 4 or 6.\n{combined_cache_data = }")
                split_at = len(combined_cache_data) // 2
                self_attention_cache_data.append(combined_cache_data[:split_at])
                cross_attention_cache_data.append(combined_cache_data[split_at:])
            self.self_attention_cache = DynamicCache(self_attention_cache_data)
            self.cross_attention_cache = DynamicCache(cross_attention_cache_data)
        else:
            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")

        # A cross-attention layer counts as "updated" as soon as it already holds some cached states.
        self.is_updated = {
            layer_idx: bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
            for layer_idx in range(len(self.cross_attention_cache))
        }

    def __iter__(self):
        """Returns tuples of style (self_attn_k, self_attn_v, self_attn_sliding, cross_attn_k, cross_attn_v, cross_attn_sliding)"""
        for self_layer, cross_layer in zip(self.self_attention_cache, self.cross_attention_cache):
            yield self_layer + cross_layer

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return (
            f"{class_name}(self_attention_cache={self.self_attention_cache}, "
            f"cross_attention_cache={self.cross_attention_cache})"
        )

    def __len__(self):
        """
        Length of the self-attention cache. Supports the backwards-compatible `len(past_key_values)` usage, where
        the value corresponds to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Returns the sequence length reported by the self-attention cache for `layer_idx` (defaults to 0)."""
        return self.self_attention_cache.get_seq_length(layer_idx)

    def reset(self):
        """Resets both underlying caches and marks every cross-attention layer as not updated."""
        self.self_attention_cache.reset()
        self.cross_attention_cache.reset()
        # Flags are cleared in place so that the `is_updated` dict object itself is kept.
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders both underlying caches for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        """Raises a `TypeError` naming `method` unless both underlying caches are `DynamicCache` instances."""
        both_dynamic = isinstance(self.self_attention_cache, DynamicCache) and isinstance(
            self.cross_attention_cache, DynamicCache
        )
        if both_dynamic:
            return
        raise TypeError(
            f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
            f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
        )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """
        Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).

        Only the self-attention cache is cropped here; the cross-attention cache is left as is.
        """
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat both caches `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of both caches. Used in contrastive search (on the Hub)."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length (i.e. max capacity) reported by the self-attention cache."""
        return self.self_attention_cache.get_max_cache_shape()

    def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
        """Returns the mask sizes computed by the self-attention cache for `query_length` and `layer_idx`."""
        return self.self_attention_cache.get_mask_sizes(query_length, layer_idx)

    @property
    def is_sliding(self):
        # Forwarded from the self-attention cache.
        return self.self_attention_cache.is_sliding

    @property
    def is_compileable(self) -> bool:
        # Forwarded from the self-attention cache.
        return self.self_attention_cache.is_compileable
1620
+
1621
+
1622
+ # Deprecated alias: SlidingWindowCache was removed in transformers v5. StaticCache is the replacement.
1623
+ SlidingWindowCache = StaticCache
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/configuration_utils.py ADDED
@@ -0,0 +1,1365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Configuration base class and utilities."""
16
+
17
+ import copy
18
+ import json
19
+ import math
20
+ import os
21
+ from collections.abc import Sequence
22
+ from dataclasses import MISSING, dataclass, fields
23
+ from functools import wraps
24
+ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, Union
25
+
26
+ from huggingface_hub import create_repo
27
+ from huggingface_hub.dataclasses import strict
28
+ from packaging import version
29
+ from typing_extensions import dataclass_transform
30
+
31
+ from . import __version__
32
+ from .dynamic_module_utils import custom_object_save
33
+ from .generation.configuration_utils import GenerationConfig
34
+ from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
35
+ from .modeling_rope_utils import RotaryEmbeddingConfigMixin
36
+ from .utils import (
37
+ CONFIG_NAME,
38
+ PushToHubMixin,
39
+ cached_file,
40
+ copy_func,
41
+ extract_commit_hash,
42
+ is_torch_available,
43
+ logging,
44
+ )
45
+ from .utils.generic import is_timm_config_dict
46
+
47
+
48
+ if TYPE_CHECKING:
49
+ import torch
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ # type hinting: specifying the type of config class that inherits from PreTrainedConfig
56
+ SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
57
+
58
+ _FLOAT_TAG_KEY = "__float__"
59
+ _FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")}
60
+
61
+
62
+ ALLOWED_LAYER_TYPES = (
63
+ "full_attention",
64
+ "sliding_attention",
65
+ "chunked_attention",
66
+ "compressed_sparse_attention", # CSA, used in deepseek_v4
67
+ "heavily_compressed_attention", # HCA, used in deepseek_v4
68
+ "linear_attention", # used in minimax
69
+ "conv", # used in LFMv2
70
+ "mamba",
71
+ "attention",
72
+ "sparse",
73
+ "dense",
74
+ "hybrid", # for layers that have both mamba and attention in zamba and zamba2
75
+ "moe", # for nemotron_h, which uses either attention, mamba or moe
76
+ )
77
+
78
+
79
+ # copied from huggingface_hub.dataclasses.strict when `accept_kwargs=True`
80
def wrap_init_to_accept_kwargs(cls: dataclass):
    """
    Swap the dataclass-generated `__init__` of `cls` for a keyword-only one that tolerates unknown kwargs.

    The replacement assigns every dataclass field from the matching keyword argument, falling back to the
    field's default or default factory, and then forwards all remaining keyword arguments to
    `__post_init__`, which decides what to do with them (e.g. set them as attributes or run BC checks).

    Args:
        cls: The dataclass whose `__init__` gets replaced. It is modified in place.

    Returns:
        The same `cls` object, with its `__init__` replaced.

    Raises:
        ValueError: (at instantiation) if positional arguments are passed.
        TypeError: (at instantiation) if a field without default is not provided.
    """
    # Keep a handle on the dataclass-generated `__init__` so its metadata can be copied over by `wraps`
    generated_init = cls.__init__

    @wraps(generated_init)
    def __init__(self, *args, **kwargs: Any) -> None:
        # The generated `__init__` cannot simply be called: it ends with a bare `__post_init__()` call
        # (without kwargs), so the fields are assigned by hand instead
        if args:
            raise ValueError(
                f"{cls.__name__} accepts only keyword arguments, but found `{len(args)}` positional args."
            )

        declared_fields = fields(cls)  # type: ignore
        declared_names = {spec.name for spec in declared_fields}

        for spec in declared_fields:
            if spec.name in kwargs:
                setattr(self, spec.name, kwargs[spec.name])
                continue
            if spec.default is not MISSING:
                setattr(self, spec.name, spec.default)
                continue
            if spec.default_factory is not MISSING:
                setattr(self, spec.name, spec.default_factory())
                continue
            raise TypeError(f"Missing required field - '{spec.name}'")

        # Everything that is not a dataclass field is handed over to `__post_init__`
        additional_kwargs = {name: value for name, value in kwargs.items() if name not in declared_names}
        self.__post_init__(**additional_kwargs)

    cls.__init__ = __init__
    return cls
118
+
119
+
120
+ @dataclass_transform(kw_only_default=True)
121
+ @strict(accept_kwargs=True)
122
+ @dataclass(repr=False)
123
+ class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
124
+ # no-format
125
+ r"""
126
+ Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
127
+ methods for loading/downloading/saving configurations.
128
+
129
+ <Tip>
130
+
131
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
132
+ initialize a model does **not** load the model weights. It only affects the model's configuration.
133
+
134
+ </Tip>
135
+
136
+ Class attributes (overridden by derived classes):
137
+
138
+ - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
139
+ the correct object in [`~transformers.AutoConfig`].
140
+ - **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
141
+ Some configurations requires inputs to be defined at init and have no default values, usually these are composite configs,
142
+ (but not necessarily) such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be initialized from
143
+ two or more configs of type [`~transformers.PreTrainedConfig`].
144
+ - **keys_to_ignore_at_inference** (`list[str]`) -- A list of keys to ignore by default when looking at dictionary
145
+ outputs of the model during inference.
146
+ - **attribute_map** (`dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
147
+ naming of attributes.
148
+ - **base_model_tp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
149
+ parallel plan applied to the sub-module when `model.tensor_parallel` is called.
150
+ - **base_model_pp_plan** (`dict[str, tuple[list[str]]]`) -- A dict that maps child-modules of a base model to a
151
+ pipeline parallel plan that enables users to place the child-module on the appropriate device.
152
+
153
+ Common attributes (present in all subclasses):
154
+
155
+ - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
156
+ embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
157
+ - **hidden_size** (`int`) -- The hidden size of the model.
158
+ - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
159
+ model.
160
+ - **num_hidden_layers** (`int`) -- The number of blocks in the model.
161
+
162
+ <Tip warning={true}>
163
+
164
+ Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
165
+ some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
166
+ them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
167
+ information about the individual parameters.
168
+
169
+ </Tip>
170
+
171
+ Args:
172
+ name_or_path (`str`, *optional*, defaults to `""`):
173
+ Store the string that was passed to [`PreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path`
174
+ if the configuration was created with such a method.
175
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
176
+ Whether or not the model should return all hidden-states.
177
+ output_attentions (`bool`, *optional*, defaults to `False`):
178
+ Whether or not the model should return all attentions.
179
+ return_dict (`bool`, *optional*, defaults to `True`):
180
+ Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
181
+ is_encoder_decoder (`bool`, *optional*, defaults to `False`):
182
+ Whether the model is used as an encoder/decoder or not.
183
+ chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
184
+ The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
185
+ the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
186
+ sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
187
+ Forward Chunking work?](../glossary.html#feed-forward-chunking).
188
+
189
+ > Parameters for fine-tuning tasks
190
+
191
+ architectures (`list[str]`, *optional*):
192
+ Model architectures that can be used with the model pretrained weights.
193
+ id2label (`dict[int, str]`, *optional*):
194
+ A map from index (for instance prediction index, or target index) to label.
195
+ label2id (`dict[str, int]`, *optional*):
196
+ A map from label to index for the model.
197
+ num_labels (`int`, *optional*):
198
+ Number of labels to use in the last layer added to the model, typically for a classification task.
199
+ problem_type (`str`, *optional*):
200
+ Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
201
+ `"single_label_classification"` or `"multi_label_classification"`.
202
+
203
+ > PyTorch specific parameters
204
+
205
+ dtype (`str`, *optional*):
206
+ The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
207
+ (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
208
+ model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
209
+ `float16` weights.
210
+ """
211
+
212
+ # Class attributes that we don't want to save or have in `self.__dict__`
213
+ # They are not supposed to be set/changed by users. Each field is set when
214
+ # creating a model class
215
+ base_config_key: ClassVar[str] = ""
216
+ sub_configs: ClassVar[dict[str, type["PreTrainedConfig"]]] = {}
217
+ has_no_defaults_at_init: ClassVar[bool] = False
218
+ keys_to_ignore_at_inference: ClassVar[list[str]] = []
219
+ attribute_map: ClassVar[dict[str, str]] = {}
220
+ base_model_tp_plan: ClassVar[dict[str, Any] | None] = None
221
+ base_model_pp_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
222
+ base_model_ep_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
223
+ _auto_class: ClassVar[str | None] = None
224
+
225
+ # Attributes set internally when saving and used to infer model
226
+ # class for `Auto` mapping
227
+ model_type: ClassVar[str] = ""
228
+ transformers_version: str | None = None
229
+ architectures: list[str] | None = None
230
+
231
+ # Common attributes for all models
232
+ output_hidden_states: bool | None = False
233
+ return_dict: bool | None = True
234
+ dtype: Union[str, "torch.dtype"] | None = None
235
+ chunk_size_feed_forward: int = 0
236
+ is_encoder_decoder: bool = False
237
+
238
+ # Fine-tuning task arguments
239
+ id2label: dict[int, str] | dict[str, str] | None = None
240
+ label2id: dict[str, int] | dict[str, str] | None = None
241
+ problem_type: Literal["regression", "single_label_classification", "multi_label_classification"] | None = None
242
+
243
+ def __post_init__(self, **kwargs):
244
+ # BC for the `torch_dtype` argument instead of the simpler `dtype`
245
+ # Do not warn, as it would otherwise always be triggered since most configs on the hub have `torch_dtype`
246
+ if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
247
+ # If both are provided, keep `dtype`
248
+ self.dtype = self.dtype if self.dtype is not None else torch_dtype
249
+ if self.dtype is not None and isinstance(self.dtype, str) and is_torch_available():
250
+ # we will start using self.dtype in v5, but to be consistent with
251
+ # from_pretrained's dtype arg convert it to an actual torch.dtype object
252
+ import torch
253
+
254
+ self.dtype = getattr(torch, self.dtype)
255
+
256
+ # Keep the default value of `num_labels=2` in case users have saved a classifier with 2 labels
257
+ # Our configs prev wouldn't save `id2label` for 2 labels because it is the default. In all other
258
+ # cases we expect the config dict to have an `id2label` field if it's a clf model, or not otherwise
259
+ if self.id2label is None:
260
+ self.num_labels = kwargs.get("num_labels", 2)
261
+ else:
262
+ if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"):
263
+ logger.warning(
264
+ f"You passed `num_labels={kwargs.get('num_labels')}` which is incompatible to "
265
+ f"the `id2label` map of length `{len(self.id2label)}`."
266
+ )
267
+ # Keys are always strings in JSON so convert ids to int
268
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
269
+
270
+ if self.problem_type == "single_label_classification" and self.num_labels == 1:
271
+ raise ValueError(
272
+ '`problem_type="single_label_classification"` requires `num_labels > 1`. For binary '
273
+ 'classification use `num_labels=2`, or use `problem_type="regression"` for a '
274
+ "single-output regression head."
275
+ )
276
+
277
+ # BC for rotary embeddings. We will pop out legacy keys from kwargs and rename to new format
278
+ if hasattr(self, "rope_parameters"):
279
+ kwargs = self.convert_rope_params_to_dict(**kwargs)
280
+ elif kwargs.get("rope_scaling") and kwargs.get("rope_theta"):
281
+ logger.warning(
282
+ f"{self.__class__.__name__} got `key=rope_scaling` in kwargs but hasn't set it as attribute. "
283
+ "For RoPE standardization you need to set `self.rope_parameters` in model's config. "
284
+ )
285
+ kwargs = self.convert_rope_params_to_dict(**kwargs)
286
+
287
+ # Parameters for sequence generation saved in the config are popped instead of loading them.
288
+ for parameter_name in GenerationConfig._get_default_generation_params().keys():
289
+ kwargs.pop(parameter_name, None)
290
+
291
+ # Name or path to the pretrained checkpoint
292
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
293
+ self._commit_hash = kwargs.pop("_commit_hash", None)
294
+
295
+ # Attention/Experts implementation to use, if relevant (it sets it recursively on sub-configs)
296
+ self._output_attentions: bool | None = kwargs.pop("output_attentions", False)
297
+ self._attn_implementation: str | None = kwargs.pop("attn_implementation", None)
298
+ self._experts_implementation: str | None = kwargs.pop("experts_implementation", None)
299
+
300
+ # Additional attributes without default values
301
+ for key, value in kwargs.items():
302
+ # Check this to avoid deserializing problematic fields from hub configs - they should use the public field
303
+ if key not in ("_attn_implementation_internal", "_experts_implementation_internal"):
304
+ try:
305
+ setattr(self, key, value)
306
+ except AttributeError as err:
307
+ logger.error(f"Can't set {key} with value {value} for {self}")
308
+ raise err
309
+
310
+ def __init_subclass__(cls, *args, **kwargs):
311
+ super().__init_subclass__(*args, **kwargs)
312
+ cls_has_custom_init = "__init__" in cls.__dict__
313
+ # kw_only=True ensures fields without defaults in subclasses can follow
314
+ # parent fields that have defaults (Python dataclass ordering rule).
315
+ # Config fields are always passed as keyword arguments, so this is safe.
316
+ cls = dataclass(cls, repr=False, kw_only=True)
317
+
318
+ if not cls_has_custom_init:
319
+ # Wrap all subclasses to accept arbitrary kwargs for BC
320
+ # only if the subclass has no custom `__init__`. Most
321
+ # remote code has an init defined, but some model are not
322
+ # See https://huggingface.co/hmellor/Ilama-3.2-1B/blob/main/configuration_ilama.py
323
+ cls = wrap_init_to_accept_kwargs(cls)
324
+
325
+ @property
326
+ def name_or_path(self) -> str | None:
327
+ return getattr(self, "_name_or_path", None)
328
+
329
+ @name_or_path.setter
330
+ def name_or_path(self, value):
331
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
332
+
333
+ @property
334
+ def num_labels(self) -> int:
335
+ """
336
+ `int`: The number of labels for classification models.
337
+ """
338
+ return len(self.id2label) if self.id2label is not None else None
339
+
340
+ @num_labels.setter
341
+ def num_labels(self, num_labels: int):
342
+ # we do not store `num_labels` attribute in config, but instead
343
+ # compute it based on the length of the `id2label` map
344
+ if self.id2label is None or self.num_labels != num_labels:
345
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
346
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
347
+
348
+ @property
349
+ def output_attentions(self):
350
+ """
351
+ `bool`: Whether or not the model should returns all attentions.
352
+ """
353
+ return self._output_attentions
354
+
355
+ @output_attentions.setter
356
+ def output_attentions(self, value: bool):
357
+ # If we set `output_attentions` explicitly before the attn implementation, dispatch eager
358
+ if value and self._attn_implementation is None:
359
+ self._attn_implementation = "eager"
360
+ if value and self._attn_implementation != "eager":
361
+ raise ValueError(
362
+ "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
363
+ f"{self._attn_implementation}. Please set it to 'eager' instead."
364
+ )
365
+ self._output_attentions = value
366
+
367
+ @property
368
+ def _attn_implementation(self):
369
+ return self._attn_implementation_internal
370
+
371
+ @_attn_implementation.setter
372
+ def _attn_implementation(self, value: str | dict | None):
373
+ """We set it recursively on the sub-configs as well"""
374
+ # Set it for the current config
375
+ current_attn = getattr(self, "_attn_implementation", None)
376
+ attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
377
+ self._attn_implementation_internal = attn_implementation
378
+
379
+ # Set it recursively on the subconfigs
380
+ for subconfig_key in self.sub_configs:
381
+ subconfig = getattr(self, subconfig_key, None)
382
+ if subconfig is not None:
383
+ current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
384
+ sub_implementation = (
385
+ value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
386
+ )
387
+ subconfig._attn_implementation = sub_implementation
388
+
389
+ @property
390
+ def _experts_implementation(self):
391
+ return self._experts_implementation_internal
392
+
393
+ @_experts_implementation.setter
394
+ def _experts_implementation(self, value: str | dict | None):
395
+ """We set it recursively on the sub-configs as well"""
396
+ # Set it for the current config
397
+ current_moe = getattr(self, "_experts_implementation", None)
398
+ experts_implementation = value if not isinstance(value, dict) else value.get("", current_moe)
399
+ self._experts_implementation_internal = experts_implementation
400
+
401
+ # Set it recursively on the subconfigs
402
+ for subconfig_key in self.sub_configs:
403
+ subconfig = getattr(self, subconfig_key, None)
404
+ if subconfig is not None:
405
+ current_subconfig_moe = getattr(subconfig, "_experts_implementation", None)
406
+ sub_implementation = (
407
+ value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_moe)
408
+ )
409
+ subconfig._experts_implementation = sub_implementation
410
+
411
+ @property
412
+ def torch_dtype(self):
413
+ logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
414
+ return self.dtype
415
+
416
+ @property
417
+ def use_return_dict(self):
418
+ logger.warning_once("`use_return_dict` is deprecated! Use `return_dict` instead!")
419
+ return self.return_dict
420
+
421
+ @torch_dtype.setter
422
+ def torch_dtype(self, value):
423
+ logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
424
+ self.dtype = value
425
+
426
+ def __setattr__(self, key, value):
427
+ if key in super().__getattribute__("attribute_map"):
428
+ key = super().__getattribute__("attribute_map")[key]
429
+ super().__setattr__(key, value)
430
+
431
+ def __getattribute__(self, key):
432
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
433
+ key = super().__getattribute__("attribute_map")[key]
434
+ return super().__getattribute__(key)
435
+
436
+ def validate_output_attentions(self):
437
+ if self.output_attentions and self._attn_implementation not in ["eager", None]:
438
+ raise ValueError(
439
+ "The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
440
+ f"{self._attn_implementation}. Please set it to 'eager' instead."
441
+ )
442
+
443
def validate_architecture(self):
    """
    Part of `@strict`-powered validation. Validates the architecture of the config.

    The check only runs when `head_dim`, `num_heads` and `embed_dim` are all present on the config.

    Raises:
        ValueError: if `head_dim * num_heads` differs from `embed_dim`.
    """
    required_attrs = ("head_dim", "num_heads", "embed_dim")
    if not all(hasattr(self, attr) for attr in required_attrs):
        return

    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(
            f"The embed_dim ({self.embed_dim}) is not a multiple of the number of attention "
            f"heads ({self.num_heads})."
        )
455
+
456
def validate_token_ids(self):
    """
    Part of `@strict`-powered validation. Validates the contents of the special tokens.

    Every integer attribute of the decoder text config whose name ends in `_token_id` is expected to lie in
    `[0, vocab_size)`. Offending values are reported through a warning, not an exception. Nothing is checked
    when the text config has no `vocab_size` (or it is `None`).
    """
    text_config = self.get_text_config(decoder=True)
    vocab_size = getattr(text_config, "vocab_size", None)
    if vocab_size is None:
        return

    # Check all special tokens, e.g. pad_token_id, image_token_id, audio_token_id
    for attr_name in text_config:
        attr_value = getattr(text_config, attr_name)
        is_token_id = attr_name.endswith("_token_id") and isinstance(attr_value, int)
        if not is_token_id or 0 <= attr_value < vocab_size:
            continue
        # Can't be an exception until we can load configs that fail validation: several configs on the Hub
        # store invalid special tokens, e.g. `pad_token_id=-1`
        logger.warning_once(
            f"Model config: {attr_name} must be `None` or an integer within the vocabulary (between 0 "
            f"and {vocab_size - 1}), got {attr_value}. This may result in unexpected behavior."
        )
471
+
472
def validate_layer_type(self):
    """
    Check that `layer_types` is correctly defined.

    The check is skipped unless `layer_types` is set (not `None`) and the config has a `num_hidden_layers`
    attribute.

    Raises:
        ValueError: if an entry is not in `ALLOWED_LAYER_TYPES`, or if `num_hidden_layers` is not `None` and
            differs from the number of entries.
    """
    layer_types = getattr(self, "layer_types", None)
    if layer_types is None or not hasattr(self, "num_hidden_layers"):
        return

    if any(entry not in ALLOWED_LAYER_TYPES for entry in self.layer_types):
        raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}")

    expected_layers = self.num_hidden_layers
    if expected_layers is not None and expected_layers != len(self.layer_types):
        raise ValueError(
            f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types "
            f"({len(self.layer_types)})"
        )
483
+
484
@property
def rope_scaling(self):
    """Read alias: returns `self.rope_parameters`."""
    rope_parameters = self.rope_parameters
    return rope_parameters

@rope_scaling.setter
def rope_scaling(self, value):
    """Write alias: stores `value` in `self.rope_parameters`."""
    setattr(self, "rope_parameters", value)
491
+
492
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
    """
    Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
    [`~PreTrainedConfig.from_pretrained`] class method.

    Args:
        save_directory (`str` or `os.PathLike`):
            Directory where the configuration JSON file will be saved (will be created if it does not exist).
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
            repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
            namespace).
        kwargs (`dict[str, Any]`, *optional*):
            Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

    Raises:
        AssertionError: if `save_directory` points to an existing file.
        ValueError: if generation parameters are set on the config.
    """
    if os.path.isfile(save_directory):
        raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

    generation_parameters = self._get_generation_parameters()
    if len(generation_parameters) > 0:
        raise ValueError(
            "Some generation parameters are set in the model config. These should go into `model.generation_config` "
            f"as opposed to `model.config`. \nGeneration parameters found: {str(generation_parameters)}",
        )

    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        commit_message = kwargs.pop("commit_message", None)
        # `save_directory` may be an `os.PathLike` (e.g. `pathlib.Path`), which has no `.split`;
        # convert it to its string form before deriving the default repo name.
        default_repo_id = os.fspath(save_directory).split(os.path.sep)[-1]
        repo_id = kwargs.pop("repo_id", default_repo_id)
        repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
        files_timestamps = self._get_files_timestamps(save_directory)

    # This attribute is important to know on load, but should not be serialized on save.
    if "transformers_weights" in self:
        delattr(self, "transformers_weights")

    # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
    # loaded from the Hub.
    if self._auto_class is not None:
        custom_object_save(self, save_directory, config=self)

    # If we save using the predefined names, we can load using `from_pretrained`
    output_config_file = os.path.join(save_directory, CONFIG_NAME)

    # Strict validation at save-time: prevent bad patterns from propagating
    # Using `strict` decorator guarantees that `self.validate` exists, but not all
    # model config might have the decorator added
    if hasattr(self, "validate"):
        self.validate()
    self.to_json_file(output_config_file, use_diff=True)
    logger.info(f"Configuration saved in {output_config_file}")

    if push_to_hub:
        self._upload_modified_files(
            save_directory,
            repo_id,
            files_timestamps,
            commit_message=commit_message,
            token=kwargs.get("token"),
        )
553
+
554
@classmethod
def from_pretrained(
    cls: type[SpecificPreTrainedConfigType],
    pretrained_model_name_or_path: str | os.PathLike,
    cache_dir: str | os.PathLike | None = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    **kwargs,
) -> SpecificPreTrainedConfigType:
    r"""
    Instantiate a [`PreTrainedConfig`] (or a derived class) from a pretrained model configuration.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
            - a path to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if
            they exist.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Forwarded to the file resolution step together with the other download options.
        proxies (`dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
            the token generated when running `hf auth login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.

            <Tip>

            To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

            </Tip>

        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
            If `False`, then this function returns just the final configuration object.

            If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
            dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
            part of `kwargs` which has not been used to update `config` and is otherwise ignored.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        kwargs (`dict[str, Any]`, *optional*):
            The values in kwargs of any keys which are configuration attributes will be used to override the loaded
            values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
            by the `return_unused_kwargs` keyword parameter.

    Returns:
        [`PreTrainedConfig`]: The configuration object instantiated from this pretrained model.

    Examples:

    ```python
    # We can't instantiate directly the base class *PreTrainedConfig* so let's show the examples on a
    # derived class: BertConfig
    config = BertConfig.from_pretrained(
        "google-bert/bert-base-uncased"
    )  # Download configuration from huggingface.co and cache.
    config = BertConfig.from_pretrained(
        "./test/saved_model/"
    )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
    config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
    config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
    assert config.output_attentions == True
    config, unused_kwargs = BertConfig.from_pretrained(
        "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
    )
    assert config.output_attentions == True
    assert unused_kwargs == {"foo": False}
    ```"""
    kwargs["cache_dir"] = cache_dir
    kwargs["force_download"] = force_download
    kwargs["local_files_only"] = local_files_only
    kwargs["revision"] = revision
    # `token` is a named parameter, so it is not part of `**kwargs`; forward it explicitly or the
    # download step (which pops "token" from its kwargs) would never see it.
    if token is not None:
        kwargs["token"] = token

    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
    if cls.base_config_key and cls.base_config_key in config_dict:
        config_dict = config_dict[cls.base_config_key]

    if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
        # sometimes the config has no `base_config_key` if the config is used in several composite models
        # e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
        for v in config_dict.values():
            if isinstance(v, dict) and v.get("model_type") == cls.model_type:
                config_dict = v

        # raise warning only if we still can't see a match in `model_type`
        if config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type `{config_dict['model_type']}` to instantiate a model of type "
                f"`{cls.model_type}`. This may be expected if you are loading a checkpoint that shares a subset "
                f"of the architecture (e.g., loading a `sam2_video` checkpoint into `Sam2Model`), but is otherwise "
                f"not supported and can yield errors. Please verify that the checkpoint is compatible with the "
                f"model you are instantiating."
            )

    return cls.from_dict(config_dict, **kwargs)
664
+
665
+ @classmethod
666
+ def get_config_dict(
667
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
668
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
669
+ """
670
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
671
+ [`PreTrainedConfig`] using `from_dict`.
672
+
673
+ Parameters:
674
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
675
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
676
+
677
+ Returns:
678
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
679
+
680
+ """
681
+ original_kwargs = copy.deepcopy(kwargs)
682
+ # Get config dict associated with the base config file
683
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
684
+ if config_dict is None:
685
+ return {}, kwargs
686
+ if "_commit_hash" in config_dict:
687
+ original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
688
+
689
+ # That config file may point us toward another config file to use.
690
+ if "configuration_files" in config_dict:
691
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
692
+ config_dict, kwargs = cls._get_config_dict(
693
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
694
+ )
695
+
696
+ return config_dict, kwargs
697
+
698
@classmethod
def _get_config_dict(
    cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Resolve `pretrained_model_name_or_path` to a configuration file, load it and return its content.

    Download/resolution options (`cache_dir`, `force_download`, `proxies`, `token`, `local_files_only`,
    `revision`, `trust_remote_code`, `subfolder`, `_from_pipeline`, `_from_auto`, `_commit_hash`) are popped from
    `kwargs`; `gguf_file` is read but left in place.

    Returns:
        A `(config_dict, kwargs)` tuple, where `kwargs` holds the entries that were not consumed here.
        `config_dict` is `None` when `cached_file` could not resolve a file.

    Raises:
        OSError: if the file cannot be resolved or is not valid JSON.
    """
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    token = kwargs.pop("token", None)
    local_files_only = kwargs.pop("local_files_only", False)
    revision = kwargs.pop("revision", None)
    trust_remote_code = kwargs.pop("trust_remote_code", None)
    subfolder = kwargs.pop("subfolder", "")
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    commit_hash = kwargs.pop("_commit_hash", None)

    gguf_file = kwargs.get("gguf_file")

    if trust_remote_code is True:
        logger.warning(
            "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
            " ignored."
        )

    user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    is_local = os.path.isdir(pretrained_model_name_or_path)
    if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
        # Special case when pretrained_model_name_or_path is a local file
        resolved_config_file = pretrained_model_name_or_path
        # Keep `configuration_file` bound on this path too: the `model_type` override warning below uses it.
        configuration_file = pretrained_model_name_or_path
        is_local = True
    else:
        configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file

        try:
            # Load from local folder or from cache or download from model Hub and cache
            resolved_config_file = cached_file(
                pretrained_model_name_or_path,
                configuration_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=commit_hash,
            )
            if resolved_config_file is None:
                return None, kwargs
            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
        except OSError:
            # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
            # the original exception.
            raise
        except Exception as err:
            # For any other exception, we throw a generic error (chained, so the root cause stays visible).
            raise OSError(
                f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
                " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
                f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
                f" containing a {configuration_file} file"
            ) from err

    try:
        if gguf_file:
            config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
        else:
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)

        config_dict["_commit_hash"] = commit_hash
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise OSError(
            f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
        ) from err

    if is_local:
        logger.info(f"loading configuration file {resolved_config_file}")
    else:
        logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

    # timm models are not saved with the model_type in the config file
    if "model_type" not in config_dict and is_timm_config_dict(config_dict):
        config_dict["model_type"] = "timm_wrapper"

    # Some checkpoints may contain the wrong model_type in the config file.
    # Allow the user to override it but warn them that it might not work.
    # `.get` keeps the override working when the file carries no `model_type` at all.
    stored_model_type = config_dict.get("model_type")
    if "model_type" in kwargs and stored_model_type != kwargs["model_type"]:
        logger.warning(
            f"{configuration_file} has 'model_type={stored_model_type}' but you overrode "
            f"it with 'model_type={kwargs['model_type']}'. This may lead to unexpected behavior."
        )
        config_dict["model_type"] = kwargs["model_type"]

    return config_dict, kwargs
797
+
798
@classmethod
def from_dict(
    cls: type[SpecificPreTrainedConfigType], config_dict: dict[str, Any], **kwargs
) -> SpecificPreTrainedConfigType:
    """
    Instantiates a [`PreTrainedConfig`] from a Python dictionary of parameters.

    Args:
        config_dict (`dict[str, Any]`):
            Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
            retrieved from a pretrained checkpoint by leveraging the [`~PreTrainedConfig.get_config_dict`] method.
            Note that entries from `kwargs` listed in `valid_fields` below are written into this dictionary.
        kwargs (`dict[str, Any]`):
            Additional parameters from which to initialize the configuration object. `return_unused_kwargs`
            (default `False`) is consumed here and selects the return shape.

    Returns:
        [`PreTrainedConfig`]: The configuration object instantiated from those parameters, or a
        `(config, unused_kwargs)` tuple when `return_unused_kwargs` is truthy.
    """
    return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

    # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
    # A plain assignment is required: `setdefault` would be a no-op because the key is known to be present.
    if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
        kwargs["_commit_hash"] = config_dict["_commit_hash"]

    # To remove arg here are those passed along for our internal telemetry but we still need to remove them
    to_remove = ["_from_auto", "_from_pipeline"]
    valid_fields = [
        "num_labels",
        "attn_implementation",
        "experts_implementation",
        "output_attentions",
        "torch_dtype",
        "dtype",
        "name_or_path",
    ]
    for key, value in kwargs.items():
        if key in valid_fields:
            if key not in ["torch_dtype", "dtype"]:
                config_dict[key] = value
                to_remove.append(key)
            elif value != "auto":
                # dtype overrides are passed to the constructor but stay in `kwargs` for the second pass below
                config_dict[key] = value

    config = cls(**config_dict)

    for key, value in kwargs.items():
        if hasattr(config, key):
            current_attr = getattr(config, key)
            # To authorize passing a custom subconfig as kwarg in models that have nested configs.
            # We need to update only custom kwarg values instead and keep other attr in subconfig.
            if isinstance(current_attr, PreTrainedConfig) and isinstance(value, dict):
                current_attr_updated = current_attr.to_dict()
                current_attr_updated.update(value)
                value = current_attr.__class__(**current_attr_updated)
            setattr(config, key, value)
            to_remove.append(key)

    # Everything applied above (plus the telemetry keys) is not "unused" any more.
    for key in to_remove:
        kwargs.pop(key, None)

    logger.info(f"Model config {config}")
    if return_unused_kwargs:
        return config, kwargs
    else:
        return config
862
+
863
@classmethod
def from_json_file(
    cls: type[SpecificPreTrainedConfigType], json_file: str | os.PathLike
) -> SpecificPreTrainedConfigType:
    """
    Instantiates a [`PreTrainedConfig`] from the path to a JSON file of parameters.

    Args:
        json_file (`str` or `os.PathLike`):
            Path to the JSON file containing the parameters.

    Returns:
        [`PreTrainedConfig`]: The configuration object instantiated from that JSON file.

    """
    # The parsed entries are passed to the constructor as keyword arguments.
    parameters = cls._dict_from_json_file(json_file)
    instance = cls(**parameters)
    return instance
880
+
881
@classmethod
def _dict_from_json_file(cls, json_file: str | os.PathLike):
    """Read `json_file` as UTF-8 JSON and return its content after `_decode_special_floats` has been applied."""
    with open(json_file, encoding="utf-8") as reader:
        raw_config = json.load(reader)

    return cls._decode_special_floats(raw_config)
888
+
889
@classmethod
def _encode_special_floats(cls, obj: Any) -> Any:
    """
    Recursively replace floats that strict JSON cannot represent (`NaN`, `Infinity`, `-Infinity`).

    Python's JSON engine would write them as bare `NaN` / `Infinity` literals, which other JSON engines
    reject. Each such float becomes a single-key dict `{_FLOAT_TAG_KEY: <name>}`, where `<name>` is
    `"NaN"`, `"Infinity"` or `"-Infinity"`.

    Dicts are rebuilt with encoded values; lists and tuples both come back as lists. Anything else,
    including ordinary floats, is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: cls._encode_special_floats(item) for key, item in obj.items()}

    if isinstance(obj, (list, tuple)):
        return [cls._encode_special_floats(item) for item in obj]

    if not isinstance(obj, float):
        return obj

    if math.isnan(obj):
        return {_FLOAT_TAG_KEY: "NaN"}
    if math.isinf(obj):
        return {_FLOAT_TAG_KEY: "Infinity" if obj > 0 else "-Infinity"}
    return obj
913
+
914
@classmethod
def _decode_special_floats(cls, obj: Any) -> Any:
    """
    Recursively restore the floats that were replaced by tagged dicts at serialization time.

    A dict whose only key is `_FLOAT_TAG_KEY` and whose value is a string is treated as a tag: it is
    replaced by `_FLOAT_TAG_VALUES[tag]` when the tag is known, and returned untouched otherwise. Other
    dicts and lists are rebuilt with decoded members; any other object is returned unchanged.
    """
    if isinstance(obj, list):
        return [cls._decode_special_floats(item) for item in obj]

    if not isinstance(obj, dict):
        return obj

    is_tagged = len(obj) == 1 and _FLOAT_TAG_KEY in obj and isinstance(obj[_FLOAT_TAG_KEY], str)
    if not is_tagged:
        return {key: cls._decode_special_floats(item) for key, item in obj.items()}

    tag = obj[_FLOAT_TAG_KEY]
    return _FLOAT_TAG_VALUES[tag] if tag in _FLOAT_TAG_VALUES else obj
935
+
936
def __eq__(self, other):
    """Two configs are equal when `other` is a `PreTrainedConfig` and both instance `__dict__`s compare equal."""
    if not isinstance(other, PreTrainedConfig):
        return False
    return self.__dict__ == other.__dict__
938
+
939
def __repr__(self):
    """Return the class name followed by a space and the output of `to_json_string()`."""
    class_name = self.__class__.__name__
    return "{} {}".format(class_name, self.to_json_string())
941
+
942
def __iter__(self):
    """Iterate over the keys of the instance `__dict__` (this is also what backs `name in config`)."""
    for attr_name in self.__dict__:
        yield attr_name
944
+
945
def to_diff_dict(self) -> dict[str, Any]:
    """
    Removes all attributes from the configuration that correspond to the default config attributes for
    better readability, while always retaining the `config` attribute from the class. Serializes to a
    Python dictionary.

    Two baselines are used: a fresh `PreTrainedConfig()` and, unless `has_no_defaults_at_init` is set, a fresh
    instance of this config's own class. `transformers_version` and `vocab_file` are always kept,
    `_name_or_path` is always dropped, and `quantization_config` is always written when the attribute exists.

    Returns:
        dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
    """
    config_dict = self.to_dict()

    # Get the default config dict (from a fresh PreTrainedConfig instance)
    default_config_dict = PreTrainedConfig().to_dict()

    # get class specific config dict
    # (left empty when the class reports `has_no_defaults_at_init`, so it is never instantiated without arguments)
    class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}

    serializable_config_dict = {}

    # Only serialize values that differ from the default config,
    # except always keep the 'config' attribute.
    for key, value in config_dict.items():
        if (
            isinstance(getattr(self, key, None), PreTrainedConfig)
            and key in class_config_dict
            and isinstance(class_config_dict[key], dict)
        ):
            # For nested configs we need to clean the diff recursively
            diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
            if "model_type" in value:
                # Needs to be set even if it's not in the diff
                diff["model_type"] = value["model_type"]

            serializable_config_dict[key] = diff
        elif (
            # keep the entry when the base class does not know the key, when it is one of the two
            # always-kept keys, or when it differs from either baseline
            key not in default_config_dict
            or key == "transformers_version"
            or key == "vocab_file"
            or value != default_config_dict[key]
            or (key in default_config_dict and value != class_config_dict.get(key, value))
        ):
            serializable_config_dict[key] = value

    self._remove_keys_not_serialized(serializable_config_dict)

    # Key removed only in diff dict
    if "_name_or_path" in serializable_config_dict:
        del serializable_config_dict["_name_or_path"]

    # Written after the key clean-up above; non-dict, non-None values are converted with their own `to_dict()`.
    if hasattr(self, "quantization_config"):
        serializable_config_dict["quantization_config"] = (
            self.quantization_config.to_dict()
            if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
            else self.quantization_config
        )
    # Converts `dtype` entries (also in nested dicts) to strings so the result can be stored as JSON.
    self.dict_dtype_to_str(serializable_config_dict)

    return serializable_config_dict
1003
+
1004
def to_dict(self) -> dict[str, Any]:
    """
    Serializes this instance to a Python dictionary.

    The result is a deep copy of the instance `__dict__`, completed with `model_type` (when the class defines
    one) and `transformers_version`. Nested configs are serialized recursively and tuples are turned into lists.

    Returns:
        `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
    """
    serialized = copy.deepcopy(self.__dict__)
    config_cls = self.__class__
    if hasattr(config_cls, "model_type"):
        serialized["model_type"] = config_cls.model_type

    # Transformers version when serializing the model
    serialized["transformers_version"] = __version__

    # Pop "kwargs" since they are unpacked and set in the post init
    serialized.pop("kwargs", None)

    def tuples_to_lists(item):
        # Only tuples (and tuples nested directly inside tuples) are converted.
        if not isinstance(item, tuple):
            return item
        return [tuples_to_lists(member) for member in item]

    # Only existing keys are reassigned, so updating the dict while iterating over it is safe.
    for attr_name, attr_value in serialized.items():
        if isinstance(attr_value, PreTrainedConfig):
            # Deal with nested configs like CLIP; the version is recorded once, on the outer dict only
            nested = attr_value.to_dict()
            del nested["transformers_version"]
            serialized[attr_name] = nested
        elif isinstance(attr_value, tuple):
            # Some models have defaults as tuples because dataclass
            # doesn't allow mutables. Let's convert back to `list`
            serialized[attr_name] = tuples_to_lists(attr_value)

    self._remove_keys_not_serialized(serialized)

    if hasattr(self, "quantization_config"):
        quantization_config = self.quantization_config
        if quantization_config is not None and not isinstance(quantization_config, dict):
            quantization_config = quantization_config.to_dict()
        serialized["quantization_config"] = quantization_config
    self.dict_dtype_to_str(serialized)

    return serialized
1050
+
1051
def to_json_string(self, use_diff: bool = True) -> str:
    """
    Serializes this instance to a JSON string.

    Args:
        use_diff (`bool`, *optional*, defaults to `True`):
            If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
            is serialized to JSON string. Any other value (including truthy non-`True` values) selects the
            full dictionary.

    Returns:
        `str`: String containing all the attributes that make up this configuration instance in JSON format
        (2-space indent, sorted keys, trailing newline).
    """
    config_dict = self.to_diff_dict() if use_diff is True else self.to_dict()

    # Handle +/-Infinity and NaNs
    encoded = self._encode_special_floats(config_dict)

    document = json.dumps(encoded, indent=2, sort_keys=True)
    return document + "\n"
1072
+
1073
+ def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):
1074
+ """
1075
+ Save this instance to a JSON file.
1076
+
1077
+ Args:
1078
+ json_file_path (`str` or `os.PathLike`):
1079
+ Path to the JSON file in which this configuration instance's parameters will be saved.
1080
+ use_diff (`bool`, *optional*, defaults to `True`):
1081
+ If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
1082
+ is serialized to JSON file.
1083
+ """
1084
+ with open(json_file_path, "w", encoding="utf-8") as writer:
1085
+ writer.write(self.to_json_string(use_diff=use_diff))
1086
+
1087
def update(self, config_dict: dict[str, Any]):
    """
    Updates attributes of this class with attributes from `config_dict`.

    Args:
        config_dict (`dict[str, Any]`): Dictionary of attributes that should be updated for this class.
    """
    # Each entry goes through `setattr`, so any attribute hooks on the instance apply.
    for attr_name in config_dict:
        setattr(self, attr_name, config_dict[attr_name])
1096
+
1097
def update_from_string(self, update_str: str):
    """
    Updates attributes of this class with attributes from `update_str`.

    The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
    "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

    The keys to change have to already exist in the config object. Each value is converted to the type of the
    attribute it replaces. Only the first `=` of an item separates key and value, so values may contain `=`.

    Args:
        update_str (`str`): String with attributes that should be updated for this class.

    Raises:
        ValueError: if an item has no `=`, a key is unknown, or a boolean value is not recognised.
        TypeError: if the existing attribute is not an int, float, bool or string.
    """
    updates = {}
    for item in update_str.split(","):
        key, separator, raw_value = item.partition("=")
        if not separator:
            raise ValueError(f"can't parse '{item}': expected the form key=value")
        updates[key] = raw_value

    for k, v in updates.items():
        if not hasattr(self, k):
            raise ValueError(f"key {k} isn't in the original config dict")

        old_v = getattr(self, k)
        # `bool` must be tested before `int`: `bool` is a subclass of `int`.
        if isinstance(old_v, bool):
            if v.lower() in ["true", "1", "y", "yes"]:
                v = True
            elif v.lower() in ["false", "0", "n", "no"]:
                v = False
            else:
                raise ValueError(f"can't derive true or false from {v} (key {k})")
        elif isinstance(old_v, int):
            v = int(v)
        elif isinstance(old_v, float):
            v = float(v)
        elif not isinstance(old_v, str):
            raise TypeError(
                f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
            )

        setattr(self, k, v)
1134
+
1135
def dict_dtype_to_str(self, d: dict[str, Any]) -> None:
    """
    Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None,
    converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
    string, which can then be stored in the json format.

    The dictionary is modified in place. A `dtype` that is itself a dict has each of its values converted.
    """
    if d.get("dtype") is not None:
        if isinstance(d["dtype"], dict):
            d["dtype"] = {k: str(v).split(".")[-1] for k, v in d["dtype"].items()}
        # models like Emu3 can have "dtype" as token in config's vocabulary map,
        # so we also exclude int type here to avoid error in this special case.
        elif not isinstance(d["dtype"], (str, int)):
            # Take the last dotted component, like the dict branch above: `[1]` would raise IndexError
            # for a dtype whose string form contains no ".".
            d["dtype"] = str(d["dtype"]).split(".")[-1]
    for value in d.values():
        if isinstance(value, dict):
            self.dict_dtype_to_str(value)
1151
+
1152
def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
    """
    Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
    Runs recursive check on the dict, to remove from all sub configs.

    The dictionary is modified in place; `_output_attentions` is renamed to `output_attentions`, not dropped.
    """
    non_serialized_keys = (
        "_is_quantized",
        "_auto_class",
        "_commit_hash",
        "_attn_implementation_internal",
        "_experts_implementation_internal",
        "ignore_keys_at_rope_validation",
        "base_model_tp_plan",
        "base_model_pp_plan",
    )
    for unwanted in non_serialized_keys:
        d.pop(unwanted, None)

    if "_output_attentions" in d:
        d["output_attentions"] = d.pop("_output_attentions")

    # Apply the same clean-up to every nested dictionary (serialized sub-configs).
    nested_dicts = [item for item in d.values() if isinstance(item, dict)]
    for nested in nested_dicts:
        self._remove_keys_not_serialized(nested)
1176
+
1177
@classmethod
def register_for_auto_class(cls, auto_class="AutoConfig"):
    """
    Register this class with a given auto class. This should only be used for custom configurations as the ones in
    the library are already mapped with `AutoConfig`.

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
            The auto class to register this new configuration with.

    Raises:
        ValueError: if `transformers.models.auto` has no attribute with that name.
    """
    # Accept either the auto class itself or its name.
    auto_class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

    import transformers.models.auto as auto_module

    if not hasattr(auto_module, auto_class_name):
        raise ValueError(f"{auto_class_name} is not a valid auto class.")

    cls._auto_class = auto_class_name
1198
+
1199
def _get_generation_parameters(self) -> dict[str, Any]:
    """
    Checks if there are generation parameters in `PreTrainedConfig` instance. Note that
    we should not save generation params in PreTrainedConfig, and we will raise error
    if there are any.

    A parameter is reported when it is a default generation parameter name other than `use_cache`, is set to a
    non-`None` value on this instance, and is not part of the class's own default dictionary.

    Returns:
        `dict[str, Any]`: The offending parameter names mapped to their current values.
    """
    class_defaults = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
    candidate_names = GenerationConfig._get_default_generation_params().keys()
    return {
        name: getattr(self, name)
        for name in candidate_names
        if name != "use_cache"  # common key for most models
        and hasattr(self, name)
        and getattr(self, name) is not None
        and name not in class_defaults
    }
1214
+
1215
+ def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig":
1216
+ """
1217
+ Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
1218
+ `decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
1219
+ which is useful on models that have both text input and output modalities.
1220
+
1221
+ There are three possible outcomes of using this method:
1222
+ 1. On most models, it returns the original config instance itself.
1223
+ 2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
1224
+ of valid names.
1225
+ 3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
1226
+
1227
+ Args:
1228
+ decoder (`Optional[bool]`, *optional*):
1229
+ If set to `True`, then only search for decoder config names.
1230
+ encoder (`Optional[bool]`, *optional*):
1231
+ If set to `True`, then only search for encoder config names.
1232
+ """
1233
+ return_both = decoder == encoder # both unset or both set -> search all possible names
1234
+
1235
+ decoder_possible_text_config_names = ("decoder", "generator", "text_config")
1236
+ encoder_possible_text_config_names = ("text_encoder",)
1237
+ if return_both:
1238
+ possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
1239
+ elif decoder:
1240
+ possible_text_config_names = decoder_possible_text_config_names
1241
+ else:
1242
+ possible_text_config_names = encoder_possible_text_config_names
1243
+
1244
+ valid_text_config_names = []
1245
+ for text_config_name in possible_text_config_names:
1246
+ if hasattr(self, text_config_name):
1247
+ text_config = getattr(self, text_config_name, None)
1248
+ if text_config is not None:
1249
+ valid_text_config_names += [text_config_name]
1250
+
1251
+ if len(valid_text_config_names) > 1:
1252
+ raise ValueError(
1253
+ f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
1254
+ "case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
1255
+ "e.g. `text_config = config.sub_config_name`"
1256
+ )
1257
+ elif len(valid_text_config_names) == 1:
1258
+ config_to_return = getattr(self, valid_text_config_names[0])
1259
+ else:
1260
+ config_to_return = self
1261
+
1262
+ # handle legacy models with flat config structure, when we only want one of the configs
1263
+ if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
1264
+ config_to_return = copy.deepcopy(config_to_return)
1265
+ prefix_to_keep = "decoder" if decoder else "encoder"
1266
+ for key in config_to_return.to_dict():
1267
+ # NOTE: We can't discard keys because:
1268
+ # 1) we can't truly delete a cls attribte on a dataclass; 2) we can't set the value to `None` due to
1269
+ # strict validation. So we just keep it as is, since there are only a couple old models falling in this condition
1270
+ if key.startswith(prefix_to_keep):
1271
+ # [encoder/decoder]_layers -> num_hidden_layers
1272
+ if key == prefix_to_keep + "_layers":
1273
+ new_key = "num_hidden_layers"
1274
+ # [encoder/decoder]_attention_heads -> num_attention_heads
1275
+ elif key == prefix_to_keep + "_attention_heads":
1276
+ new_key = "num_attention_heads"
1277
+ # e.g. encoder_hidden_act -> hidden_act
1278
+ else:
1279
+ new_key = key[len(prefix_to_keep) + 1 :]
1280
+
1281
+ # Does the class map the new key into a different attribute name at read time? if so, let's write
1282
+ # into that attribute instead
1283
+ if new_key in config_to_return.attribute_map:
1284
+ new_key = config_to_return.attribute_map[new_key]
1285
+
1286
+ value = getattr(config_to_return, key)
1287
+ delattr(config_to_return, key)
1288
+ setattr(config_to_return, new_key, value)
1289
+
1290
+ return config_to_return
1291
+
1292
+
1293
def get_configuration_file(configuration_files: list[str]) -> str:
    """
    Get the configuration file to use for this version of transformers.

    Files named `config.<version>.json` are versioned configuration files. The one with the highest version
    that is not newer than the installed `transformers` version is selected. When there is none, the default
    `CONFIG_NAME` is returned.

    Versions are compared as parsed versions, not as strings, so that e.g. `4.10.0` is correctly treated as
    newer than `4.9.0`. File names whose version part cannot be parsed are ignored.

    Args:
        configuration_files (`list[str]`): The list of available configuration files.

    Returns:
        `str`: The configuration file to use.
    """
    transformers_version = version.parse(__version__)

    # (parsed version, raw version string, file name): tuples compare by parsed version first, and the raw
    # string keeps the choice deterministic when two names parse to the same version.
    candidates = []
    for file_name in configuration_files:
        if file_name == "config.json" or not (file_name.startswith("config.") and file_name.endswith(".json")):
            continue
        raw_version = file_name.removeprefix("config.").removesuffix(".json")
        try:
            parsed_version = version.parse(raw_version)
        except version.InvalidVersion:
            # Not a versioned configuration file after all (e.g. `config.backup.json`).
            continue
        if parsed_version <= transformers_version:
            candidates.append((parsed_version, raw_version, file_name))

    # Defaults to CONFIG_NAME when no versioned file is compatible with the installed version.
    if not candidates:
        return CONFIG_NAME
    return max(candidates)[2]
1321
+
1322
+
1323
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Helper function to recursively take the diff between two nested dictionaries. The resulting diff only
    contains the values from `dict_a` that are different from the reference values.

    Args:
        dict_a (`dict`):
            The dictionary whose non-default values should be kept.
        dict_b (`dict`):
            The default config dictionary. Keys missing from it are always kept in the diff.
        config_obj (*optional*):
            The config object `dict_a` was produced from. When given, values are compared against the
            defaults of a freshly instantiated `config_obj.__class__()`, and attributes that are themselves
            `PreTrainedConfig` instances are diffed recursively.

    Returns:
        `dict`: The entries of `dict_a` that differ from the reference. The reference value of a key is its
        class default when one exists, and `dict_b[key]` otherwise. This fallback also covers the
        `config_obj=None` case, which previously raised `KeyError` for every key present in `dict_b`.
    """
    default = config_obj.__class__().to_dict() if config_obj is not None else {}

    diff = {}
    for key, value in dict_a.items():
        if key not in dict_b:
            diff[key] = value
            continue

        reference = dict_b[key]
        # The cheap dict check comes first; the attribute is only looked up for nested candidates.
        if isinstance(reference, dict):
            obj_value = getattr(config_obj, str(key), None)
            if isinstance(obj_value, PreTrainedConfig):
                diff[key] = recursive_diff_dict(value, reference, config_obj=obj_value)
                continue

        if value != default.get(key, reference):
            diff[key] = value
    return diff
1340
+
1341
+
1342
# Rebind `push_to_hub` to the result of `copy_func` before specialising its docstring below
# (presumably so the docstring edit stays local to the config class -- confirm against `copy_func`).
PreTrainedConfig.push_to_hub = copy_func(PreTrainedConfig.push_to_hub)
if PreTrainedConfig.push_to_hub.__doc__ is not None:
    # Fill the `{object}`, `{object_class}` and `{object_files}` placeholders of the docstring.
    PreTrainedConfig.push_to_hub.__doc__ = PreTrainedConfig.push_to_hub.__doc__.format(
        **{"object": "config", "object_class": "AutoConfig", "object_files": "configuration file"}
    )


# Backward-compatibility alias: the class was previously exposed without the correct CamelCasing.
PretrainedConfig = PreTrainedConfig
1351
+
1352
+
1353
def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True):
    """
    Deprecated validation of a list of layer types; use `PreTrainedConfig.validate_layer_type` instead.

    A deprecation warning is logged on every call before the checks run.

    Args:
        layer_types (`list[str]`):
            The layer types to check. Every entry must be in `ALLOWED_LAYER_TYPES`.
        num_hidden_layers (`int`, *optional*):
            When given, it must be equal to `len(layer_types)`.
        attention (`bool`, *optional*, defaults to `True`):
            Not read by this function.

    Raises:
        ValueError: If an entry is not an allowed layer type, or if the number of entries does not match
            `num_hidden_layers`.
    """
    logger.warning(
        "`layer_type_validation` is deprecated and will be removed in v5.20. "
        "Use `PreTrainedConfig.validate_layer_type` instead"
    )

    if any(layer_type not in ALLOWED_LAYER_TYPES for layer_type in layer_types):
        raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")

    if num_hidden_layers is None:
        return

    if num_hidden_layers != len(layer_types):
        raise ValueError(
            f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
            f"({len(layer_types)})"
        )
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/file_utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ File utilities: utilities related to download and cache models
16
+
17
+ This module should not be updated anymore and is only left for backward compatibility.
18
+ """
19
+
20
+ from . import __version__
21
+
22
+ # Backward compatibility imports, to make sure all those objects can be found in file_utils
23
+ from .utils import (
24
+ CLOUDFRONT_DISTRIB_PREFIX,
25
+ CONFIG_NAME,
26
+ DUMMY_INPUTS,
27
+ DUMMY_MASK,
28
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
29
+ ENV_VARS_TRUE_VALUES,
30
+ FEATURE_EXTRACTOR_NAME,
31
+ HF_MODULES_CACHE,
32
+ MODEL_CARD_NAME,
33
+ MULTIPLE_CHOICE_DUMMY_INPUTS,
34
+ S3_BUCKET_PREFIX,
35
+ SENTENCEPIECE_UNDERLINE,
36
+ SPIECE_UNDERLINE,
37
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
38
+ WEIGHTS_INDEX_NAME,
39
+ WEIGHTS_NAME,
40
+ ContextManagers,
41
+ DummyObject,
42
+ EntryNotFoundError,
43
+ ExplicitEnum,
44
+ ModelOutput,
45
+ PaddingStrategy,
46
+ PushToHubMixin,
47
+ RepositoryNotFoundError,
48
+ RevisionNotFoundError,
49
+ TensorType,
50
+ _LazyModule,
51
+ add_code_sample_docstrings,
52
+ add_end_docstrings,
53
+ add_start_docstrings,
54
+ add_start_docstrings_to_model_forward,
55
+ copy_func,
56
+ define_sagemaker_information,
57
+ get_torch_version,
58
+ has_file,
59
+ http_user_agent,
60
+ is_apex_available,
61
+ is_bs4_available,
62
+ is_coloredlogs_available,
63
+ is_datasets_available,
64
+ is_detectron2_available,
65
+ is_faiss_available,
66
+ is_g2p_en_available,
67
+ is_in_notebook,
68
+ is_librosa_available,
69
+ is_onnx_available,
70
+ is_pandas_available,
71
+ is_phonemizer_available,
72
+ is_protobuf_available,
73
+ is_psutil_available,
74
+ is_py3nvml_available,
75
+ is_pyctcdecode_available,
76
+ is_pytesseract_available,
77
+ is_pytorch_quantization_available,
78
+ is_rjieba_available,
79
+ is_sagemaker_dp_enabled,
80
+ is_sagemaker_mp_enabled,
81
+ is_scipy_available,
82
+ is_sentencepiece_available,
83
+ is_seqio_available,
84
+ is_sklearn_available,
85
+ is_soundfile_available,
86
+ is_spacy_available,
87
+ is_speech_available,
88
+ is_tensor,
89
+ is_timm_available,
90
+ is_tokenizers_available,
91
+ is_torch_available,
92
+ is_torch_cuda_available,
93
+ is_torch_fx_proxy,
94
+ is_torch_mps_available,
95
+ is_torch_tf32_available,
96
+ is_torch_xla_available,
97
+ is_torchaudio_available,
98
+ is_training_run_on_sagemaker,
99
+ is_vision_available,
100
+ replace_return_docstrings,
101
+ requires_backends,
102
+ to_numpy,
103
+ to_py_obj,
104
+ torch_only_method,
105
+ )
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/fusion_mapping.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Fusion registration helpers.
16
+
17
+ See `docs/source/en/fusion_mapping.md` for the design overview and extension guide.
18
+ """
19
+
20
+ import math
21
+ import re
22
+ from collections.abc import Mapping
23
+ from typing import TYPE_CHECKING, Any
24
+
25
+ import torch
26
+ from torch import nn
27
+
28
+ from .conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
29
+ from .core_model_loading import Conv3dToLinear, WeightConverter, WeightRenaming, WeightTransform
30
+ from .monkey_patching import register_patch_mapping
31
+ from .utils import logging
32
+
33
+
34
+ if TYPE_CHECKING:
35
+ from .configuration_utils import PretrainedConfig
36
+ from .modeling_utils import PreTrainedModel
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ _FUSION_DISCOVERY_CACHE: dict[str, dict[type, dict[str, type[nn.Module]]]] = {}
42
+
43
+
44
class ModuleFusionSpec:
    """Base recipe for a fusion family.

    A spec answers three questions for one kind of fusion: which modules can be fused, which class replaces
    them at runtime, and which weight transforms map checkpoints between the original and the fused layout.
    Subclasses are expected to implement `is_fusable`, `make_fused_class` and `make_transforms`.
    """

    # Regex patterns matched against module names to pre-filter candidates during discovery.
    # An empty tuple disables the pre-filter.
    target_modules_patterns: tuple[str, ...] = ()

    def get_empty_log(self, model_name: str) -> str:
        """Return the log message emitted when `model_name` has no module this spec can fuse."""
        spec_name = self.__class__.__name__
        return f"No compatible {spec_name} classes found to fuse for {model_name}"

    def is_fusable(self, module: nn.Module) -> bool:
        """Tell whether `module` is compatible with this fusion family. Must be overridden."""
        raise NotImplementedError

    def make_fused_class(self, original_cls: type[nn.Module]) -> type[nn.Module]:
        """Create the runtime replacement class for a compatible module class. Must be overridden."""
        raise NotImplementedError

    def make_transforms(self, config: "PretrainedConfig") -> list[WeightTransform]:
        """Create the weight transforms needed to load and save the fused runtime layout. Must be overridden."""
        raise NotImplementedError
69
+
70
+
71
+ class _FusedPatchEmbeddingMixin:
72
+ def __init__(self, *args, **kwargs):
73
+ # call the original_cls.__init__()
74
+ super().__init__(*args, **kwargs)
75
+ self.patch_volume = self.proj.in_channels * math.prod(self.proj.kernel_size)
76
+
77
+ self.linear_proj = nn.Linear(
78
+ self.patch_volume,
79
+ self.proj.out_channels,
80
+ bias=self.proj.bias is not None,
81
+ device=self.proj.weight.device,
82
+ dtype=self.proj.weight.dtype,
83
+ )
84
+
85
+ del self.proj
86
+
87
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
88
+ target_dtype = self.linear_proj.weight.dtype
89
+ hidden_states = hidden_states.view(-1, self.patch_volume)
90
+ hidden_states = self.linear_proj(hidden_states.to(dtype=target_dtype))
91
+ return hidden_states.view(-1, self.embed_dim)
92
+
93
+
94
class PatchEmbeddingsFusionSpec(ModuleFusionSpec):
    """Fuse compatible Conv3d patch embeddings into flattened Linear projections."""

    # Only modules named `patch_embed` (at the top level or as the last path component) are candidates.
    target_modules_patterns = (r"(^|\.)patch_embed$",)

    def is_fusable(self, module: nn.Module) -> bool:
        """Return `True` when `module.proj` is a `Conv3d` whose patches do not overlap."""
        proj = getattr(module, "proj", None)
        if not isinstance(proj, nn.Conv3d):
            return False

        # No overlap between the patches: the stride equals the kernel size, and there is no padding,
        # no dilation and no channel grouping.
        actual = (proj.stride, proj.padding, proj.dilation, proj.groups)
        required = (proj.kernel_size, (0, 0, 0), (1, 1, 1), 1)
        return actual == required

    def make_fused_class(self, original_cls: type[nn.Module]) -> type[nn.Module]:
        """Create `Fused<original name>`, a subclass of the mixin and of `original_cls`."""
        fused_name = f"Fused{original_cls.__name__}"
        fused_cls = type(fused_name, (_FusedPatchEmbeddingMixin, original_cls), {})
        fused_cls.__qualname__ = f"Fused{original_cls.__qualname__}"
        return fused_cls

    def make_transforms(self, config: "PretrainedConfig") -> list[WeightTransform]:
        """Build the transforms that map `patch_embed.proj.*` checkpoint entries to `patch_embed.linear_proj.*`.

        The kernel shape is read from `config.vision_config` when that attribute exists, and from `config`
        itself otherwise.
        """
        vision_config = getattr(config, "vision_config", config)

        spatial_patch = vision_config.patch_size
        if isinstance(spatial_patch, int):
            # A single int stands for a square patch.
            spatial_patch = (spatial_patch, spatial_patch)
        kernel_size = (vision_config.temporal_patch_size, *tuple(spatial_patch))

        weight_converter = WeightConverter(
            source_patterns=r"patch_embed\.proj\.weight$",
            target_patterns=r"patch_embed\.linear_proj\.weight$",
            operations=[
                Conv3dToLinear(
                    in_channels=vision_config.in_channels,
                    kernel_size=kernel_size,
                )
            ],
        )
        # The bias keeps its values and is only renamed.
        bias_renaming = WeightRenaming(
            source_patterns=r"patch_embed\.proj\.bias$",
            target_patterns=r"patch_embed\.linear_proj\.bias$",
        )
        return [weight_converter, bias_renaming]
140
+
141
+
142
def _discover_fusable_modules(
    cls: "type[PreTrainedModel]",
    config: "PretrainedConfig",
    fusion_name: str,
    spec: ModuleFusionSpec,
) -> dict[str, type[nn.Module]]:
    """Find the module classes of `cls(config)` that one fusion family can replace.

    Steps:
        - build `cls(config)` under the meta device
        - walk `named_modules()`
        - skip modules whose name does not match `spec.target_modules_patterns` (when any are set)
        - keep the modules accepted by `spec.is_fusable(...)`
        - map each accepted class name to the class returned by `spec.make_fused_class(...)`

    The result is cached per `(fusion_name, cls)`, so the model is meta-initialized at most once for a given
    pair. Note that `config` is not part of the cache key. Each accepted module class is processed once and
    maps to one fused replacement class.

    Returns:
        `dict[str, type[nn.Module]]`: Mapping from the original class `__name__` to its fused class.
    """

    per_fusion_cache = _FUSION_DISCOVERY_CACHE.setdefault(fusion_name, {})
    try:
        return per_fusion_cache[cls]
    except KeyError:
        pass

    with torch.device("meta"):
        model = cls(config)

    name_filter = None
    if spec.target_modules_patterns:
        name_filter = re.compile("|".join(spec.target_modules_patterns))

    patch_mapping = {}
    accepted_classes = set()
    for module_name, module in model.named_modules():
        module_cls = type(module)
        # Cheapest checks first: `is_fusable` only runs for new classes whose name passed the filter.
        if (
            module_cls in accepted_classes
            or (name_filter is not None and name_filter.search(module_name) is None)
            or not spec.is_fusable(module)
        ):
            continue

        accepted_classes.add(module_cls)
        patch_mapping[module_cls.__name__] = spec.make_fused_class(module_cls)

    per_fusion_cache[cls] = patch_mapping
    return patch_mapping
188
+
189
+
190
def _register_module_fusion(
    cls: "type[PreTrainedModel]", config: "PretrainedConfig", fusion_name: str, spec: ModuleFusionSpec
) -> None:
    """Register one fusion family for `cls`.

    Two global registries are updated:
    - the monkey-patching registry, with the fused replacement classes found for `cls`
    - the checkpoint conversion mapping of the model type, with the transforms built by `spec`, appended
      after any converters already registered for that model type

    When no compatible module class is found, an info message is logged and nothing is registered.

    Raises:
        ValueError: If `cls` has no `config_class` with a `model_type`, or if one of the new converters has
            the same source patterns as a converter that is already registered for the model type.
    """

    fusable_classes = _discover_fusable_modules(cls, config, fusion_name=fusion_name, spec=spec)
    if not fusable_classes:
        logger.info(spec.get_empty_log(cls.__name__))
        return

    register_patch_mapping(fusable_classes, overwrite=True)

    config_class = getattr(cls, "config_class", None)
    if not hasattr(cls, "config_class") or not hasattr(config_class, "model_type"):
        raise ValueError(f"Model {cls.__name__} has no config class or model type")
    model_type = config_class.model_type
    converters = spec.make_transforms(config)

    already_registered = get_checkpoint_conversion_mapping(model_type)
    if already_registered is not None:
        # WeightConverter matching stops at the first matching source pattern, so
        # conflicting converters must fail fast instead of being appended.
        taken_sources = {tuple(registered.source_patterns) for registered in already_registered}
        for converter in converters:
            source_patterns = tuple(converter.source_patterns)
            if source_patterns in taken_sources:
                raise ValueError(
                    f"Fusion {fusion_name} for model type {model_type} conflicts with an existing conversion mapping "
                    f"for source patterns {source_patterns}."
                )

        # TODO: allow compatible fusions mentioned https://github.com/huggingface/transformers/pull/45041#discussion_r3028989716
        converters = already_registered + converters

    register_checkpoint_conversion_mapping(model_type, converters, overwrite=True)
233
+
234
+
235
+ _FUSION_REGISTRY: dict[str, ModuleFusionSpec] = {"patch_embeddings": PatchEmbeddingsFusionSpec()}
236
+
237
+
238
def _iter_enabled_fusions(fusion_config: Mapping[str, bool | Mapping[str, Any]]) -> list[str]:
    """Validate `fusion_config` and return the enabled fusion names in user-specified order.

    A fusion is enabled when its value is `True` or a mapping of options (an empty mapping included), and
    skipped when its value is `False`.

    Raises:
        ValueError: If a name is not in `_FUSION_REGISTRY`, or if a value is neither a bool nor a mapping.
    """

    enabled = []
    for name, options in fusion_config.items():
        # Unknown names are rejected even when they are switched off.
        if name not in _FUSION_REGISTRY:
            raise ValueError(f"Unknown fusion type: {name}")

        if options is False:
            continue

        if options is True or isinstance(options, Mapping):
            enabled.append(name)
            continue

        raise ValueError(f"Invalid fusion config for {name}: expected `True`, `False`, or a mapping of options.")
    return enabled
253
+
254
+
255
+ def register_fusion_patches(
256
+ cls: "type[PreTrainedModel]", config, fusion_config: Mapping[str, bool | Mapping[str, Any]] | None = None
257
+ ) -> None:
258
+ """Register requested runtime fusions for `cls`.
259
+
260
+ This function:
261
+ - validates `fusion_config` against `_FUSION_REGISTRY`
262
+ - resolves the enabled fusion families in user order
263
+ - registers monkey patches and checkpoint transforms before model instantiation
264
+ """
265
+
266
+ if not fusion_config:
267
+ return
268
+
269
+ for fusion_name in _iter_enabled_fusions(fusion_config):
270
+ _register_module_fusion(cls, config, fusion_name, _FUSION_REGISTRY[fusion_name])
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_backends.py ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections.abc import Iterable
16
+ from functools import lru_cache
17
+ from typing import Any, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from .image_processing_base import BatchFeature
22
+ from .image_processing_utils import BaseImageProcessor
23
+ from .image_transforms import (
24
+ center_crop as np_center_crop,
25
+ )
26
+ from .image_transforms import (
27
+ convert_to_rgb,
28
+ divide_to_patches, # noqa: F401 - re-exported for backward compat with image_processing_utils_fast
29
+ get_resize_output_image_size,
30
+ get_size_with_aspect_ratio,
31
+ group_images_by_shape,
32
+ reorder_images,
33
+ )
34
+ from .image_transforms import (
35
+ normalize as np_normalize,
36
+ )
37
+ from .image_transforms import (
38
+ rescale as np_rescale,
39
+ )
40
+ from .image_transforms import (
41
+ resize as np_resize,
42
+ )
43
+ from .image_utils import (
44
+ ChannelDimension,
45
+ ImageInput,
46
+ ImageType,
47
+ SizeDict,
48
+ get_image_size,
49
+ get_image_size_for_max_height_width,
50
+ get_image_type,
51
+ get_max_height_width,
52
+ infer_channel_dimension_format,
53
+ is_valid_image,
54
+ load_image_as_tensor,
55
+ )
56
+ from .processing_utils import ImagesKwargs, Unpack
57
+ from .utils import (
58
+ TensorType,
59
+ is_torch_available,
60
+ is_torchvision_available,
61
+ is_vision_available,
62
+ logging,
63
+ )
64
+ from .utils.import_utils import is_rocm_platform, is_torchdynamo_compiling, requires
65
+
66
+
67
+ if is_vision_available():
68
+ from .image_utils import PILImageResampling
69
+
70
+ if is_torch_available():
71
+ import torch
72
+
73
+ if is_torchvision_available():
74
+ from torchvision.transforms.v2 import functional as tvF
75
+
76
+ from .image_utils import pil_torch_interpolation_mapping, torch_pil_interpolation_mapping
77
+ else:
78
+ pil_torch_interpolation_mapping = None
79
+ torch_pil_interpolation_mapping = None
80
+
81
+
82
+ logger = logging.get_logger(__name__)
83
+
84
+
85
+ @requires(backends=("torch", "torchvision"))
86
+ class TorchvisionBackend(BaseImageProcessor):
87
+ """Torchvision backend for GPU-accelerated batched image processing."""
88
+
89
+ def __init__(self, **kwargs: Unpack[ImagesKwargs]):
90
+ super().__init__(**kwargs)
91
+ self._set_attributes(**kwargs)
92
+
93
+ @property
94
+ def is_fast(self) -> bool:
95
+ """
96
+ `bool`: Whether or not this image processor is using the fast (Torchvision) backend.
97
+ The `is_fast` property is deprecated and will be removed in v5.3 of Transformers.
98
+ Use the `backend` attribute instead (e.g., `processor.backend == "torchvision"`).
99
+ """
100
+ logger.warning_once(
101
+ "The `is_fast` property is deprecated and will be removed in v5.3 of Transformers. "
102
+ "Use the `backend` attribute instead (e.g., `processor.backend == 'torchvision'`)."
103
+ )
104
+ return True
105
+
106
+ @property
107
+ def backend(self) -> str:
108
+ """
109
+ `str`: The backend used by this image processor.
110
+ """
111
+ return "torchvision"
112
+
113
+ def fetch_images(self, image_url_or_urls: str | list[str] | list[list[str]]):
114
+ """
115
+ Convert a single or a list of URLs / paths into `torch.Tensor` objects.
116
+
117
+ Already-valid image objects (tensors, numpy arrays, PIL Images) are passed through
118
+ unchanged so that callers who pre-load images are unaffected.
119
+ """
120
+ if isinstance(image_url_or_urls, (list, tuple)):
121
+ return [self.fetch_images(x) for x in image_url_or_urls]
122
+ elif isinstance(image_url_or_urls, str):
123
+ return load_image_as_tensor(image_url_or_urls)
124
+ elif is_valid_image(image_url_or_urls):
125
+ return image_url_or_urls
126
+ else:
127
+ raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
128
+
129
+ def process_image(
130
+ self,
131
+ image: ImageInput,
132
+ do_convert_rgb: bool | None = None,
133
+ input_data_format: str | ChannelDimension | None = None,
134
+ device: Optional["torch.device"] = None,
135
+ **kwargs: Unpack[ImagesKwargs],
136
+ ) -> "torch.Tensor":
137
+ """Process a single image for torchvision backend."""
138
+ image_type = get_image_type(image)
139
+ if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
140
+ raise ValueError(f"Unsupported input image type {image_type}")
141
+
142
+ if do_convert_rgb:
143
+ image = self.convert_to_rgb(image)
144
+
145
+ if image_type == ImageType.PIL:
146
+ image = tvF.pil_to_tensor(image)
147
+ elif image_type == ImageType.NUMPY:
148
+ image = torch.from_numpy(image).contiguous()
149
+
150
+ if image.ndim == 2:
151
+ image = image.unsqueeze(0)
152
+
153
+ if input_data_format is None:
154
+ input_data_format = infer_channel_dimension_format(image)
155
+
156
+ if input_data_format == ChannelDimension.LAST:
157
+ image = image.permute(2, 0, 1).contiguous()
158
+
159
+ if device is not None:
160
+ image = image.to(device)
161
+
162
+ return image
163
+
164
+ def convert_to_rgb(self, image: ImageInput) -> ImageInput:
165
+ """Convert an image to RGB format."""
166
+ return convert_to_rgb(image)
167
+
168
def pad(
    self,
    images: list["torch.Tensor"],
    pad_size: SizeDict = None,
    fill_value: int | None = 0,
    padding_mode: str | None = "constant",
    return_mask: bool = False,
    disable_grouping: bool | None = False,
    is_nested: bool | None = False,
    **kwargs,
) -> Union[tuple["torch.Tensor", "torch.Tensor"], "torch.Tensor"]:
    """
    Pad images up to a common size, processing one same-shape group at a time.

    Args:
        images: Images to pad. They are grouped with `group_images_by_shape` so each group is padded in
            a single call.
        pad_size: Target size; must carry non-zero `height` and `width`. When `None`, the target is
            taken from `get_max_height_width(images)`.
        fill_value: Forwarded to `tvF.pad` as `fill`.
        padding_mode: Forwarded to `tvF.pad` as `padding_mode`.
        return_mask: When `True`, also return int64 masks holding 1 over the original image area and 0
            over the padded area.
        disable_grouping: Forwarded to `group_images_by_shape`.
        is_nested: Forwarded to `group_images_by_shape` and `reorder_images`.

    Returns:
        The padded images in their original order, or a `(images, masks)` pair when `return_mask` is set.

    Raises:
        ValueError: If `pad_size` lacks a height or width, or if an image is larger than the target.
    """
    if pad_size is None:
        pad_size = get_max_height_width(images)
    elif pad_size.height and pad_size.width:
        pad_size = (pad_size.height, pad_size.width)
    else:
        raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")

    images_by_shape, original_positions = group_images_by_shape(
        images, disable_grouping=disable_grouping, is_nested=is_nested
    )
    padded_by_shape = {}
    masks_by_shape = {}
    for shape, batch in images_by_shape.items():
        image_size = batch.shape[-2:]
        padding_height = pad_size[0] - image_size[0]
        padding_width = pad_size[1] - image_size[1]
        if padding_height < 0 or padding_width < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
                f"image size. Got pad_size={pad_size}, image_size={image_size}."
            )
        if image_size != pad_size:
            # Sequence order is (left, top, right, bottom): only the right and bottom edges grow.
            batch = tvF.pad(
                batch, (0, 0, padding_width, padding_height), fill=fill_value, padding_mode=padding_mode
            )
        padded_by_shape[shape] = batch

        if return_mask:
            # Index 0 along the third-to-last axis (presumably the channel axis) gives one mask per image.
            mask = torch.zeros_like(batch, dtype=torch.int64)[..., 0, :, :]
            mask[..., : image_size[0], : image_size[1]] = 1
            masks_by_shape[shape] = mask

    processed_images = reorder_images(padded_by_shape, original_positions, is_nested=is_nested)
    if not return_mask:
        return processed_images

    processed_masks = reorder_images(masks_by_shape, original_positions, is_nested=is_nested)
    return processed_images, processed_masks
217
+
218
def resize(
    self,
    image: "torch.Tensor",
    size: SizeDict,
    resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
    antialias: bool = True,
    **kwargs,
) -> "torch.Tensor":
    """
    Resize an image (or a stacked batch of images) with Torchvision.

    Args:
        image: Tensor to resize; its last two dimensions are read as (height, width).
        size: Target specification. Checked in this order: `shortest_edge` + `longest_edge`,
            `shortest_edge` alone, `max_height` + `max_width`, then `height` + `width`.
        resample: PIL resampling value (or its int code), which is mapped through
            `pil_torch_interpolation_mapping`, or a Torchvision interpolation mode used as is.
            Defaults to bilinear when `None`.
        antialias: Forwarded to `tvF.resize`.

    Returns:
        The resized tensor.

    Raises:
        ValueError: If `size` matches none of the supported key combinations.
    """
    # Translate PIL resampling values into a Torchvision interpolation mode.
    if resample is None:
        interpolation = tvF.InterpolationMode.BILINEAR
    elif isinstance(resample, (PILImageResampling, int)):
        interpolation = pil_torch_interpolation_mapping[resample]
    else:
        interpolation = resample

    if interpolation == tvF.InterpolationMode.LANCZOS:
        logger.warning_once(
            "You have used a torchvision backend image processor with LANCZOS resample which not yet supported for torch.Tensor. "
            "BICUBIC resample will be used as an alternative. Please fall back to a pil backend image processor if you "
            "want full consistency with the original model."
        )
        interpolation = tvF.InterpolationMode.BICUBIC

    if size.shortest_edge and size.longest_edge:
        new_size = get_size_with_aspect_ratio(image.size()[-2:], size.shortest_edge, size.longest_edge)
    elif size.shortest_edge:
        new_size = get_resize_output_image_size(
            image,
            size=size.shortest_edge,
            default_to_square=False,
            input_data_format=ChannelDimension.FIRST,
        )
    elif size.max_height and size.max_width:
        new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
    elif size.height and size.width:
        new_size = (size.height, size.width)
    else:
        raise ValueError(
            "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
            f" {size}."
        )

    # Workaround for torch.compile issue with uint8 on AMD GPUs
    needs_compile_workaround = is_torchdynamo_compiling() and is_rocm_platform()
    if needs_compile_workaround:
        return self._compile_friendly_resize(image, new_size, interpolation, antialias)
    return tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
270
+
271
@staticmethod
def _compile_friendly_resize(
    image: "torch.Tensor",
    new_size: tuple[int, int],
    interpolation: Optional["tvF.InterpolationMode"] = None,
    antialias: bool = True,
) -> "torch.Tensor":
    """
    Resize through `tvF.resize`, routing uint8 tensors through float for torch.compile compatibility.

    uint8 input is converted to float and divided by 256, resized, multiplied back by 256, limited to
    [0, 255], rounded and cast back to uint8. Any other dtype is resized directly.
    """
    if image.dtype != torch.uint8:
        return tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)

    as_float = image.float() / 256
    resized = tvF.resize(as_float, new_size, interpolation=interpolation, antialias=antialias)
    resized = resized * 256
    # Keep values inside the uint8 range before casting back.
    resized = torch.where(resized > 255, 255, resized)
    resized = torch.where(resized < 0, 0, resized)
    return resized.round().to(torch.uint8)
289
+
290
def rescale(
    self,
    image: "torch.Tensor",
    scale: float,
    **kwargs,
) -> "torch.Tensor":
    """
    Multiply `image` by `scale` and return the result.

    Extra keyword arguments are accepted and ignored.
    """
    rescaled = image * scale
    return rescaled
298
+
299
def normalize(
    self,
    image: "torch.Tensor",
    mean: float | Iterable[float],
    std: float | Iterable[float],
    **kwargs,
) -> "torch.Tensor":
    """
    Normalize `image` with the given `mean` and `std` by delegating to `tvF.normalize`.

    Extra keyword arguments are accepted and ignored.
    """
    return tvF.normalize(image, mean=mean, std=std)
308
+
309
+ @lru_cache(maxsize=10)
310
+ def _fuse_mean_std_and_rescale_factor(
311
+ self,
312
+ do_normalize: bool | None = None,
313
+ image_mean: float | list[float] | None = None,
314
+ image_std: float | list[float] | None = None,
315
+ do_rescale: bool | None = None,
316
+ rescale_factor: float | None = None,
317
+ device: Optional["torch.device"] = None,
318
+ ) -> tuple:
319
+ if do_rescale and do_normalize:
320
+ # Fused rescale and normalize
321
+ image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
322
+ image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
323
+ do_rescale = False
324
+ return image_mean, image_std, do_rescale
325
+
326
def rescale_and_normalize(
    self,
    images: "torch.Tensor",
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: float | list[float],
    image_std: float | list[float],
) -> "torch.Tensor":
    """
    Apply rescaling and/or normalization, fused into one normalize call when both are requested.

    When `do_normalize` is set, the images are cast to float32 and normalized with the statistics
    returned by `_fuse_mean_std_and_rescale_factor`. Otherwise, when only rescaling is requested, the
    images are multiplied by `rescale_factor`. With neither flag set the input is returned untouched.
    """
    fused_mean, fused_std, still_rescale = self._fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        device=images.device,
    )
    if do_normalize:
        return self.normalize(images.to(dtype=torch.float32), fused_mean, fused_std)
    if still_rescale:
        return self.rescale(images, rescale_factor)
    return images
350
+
351
def center_crop(
    self,
    image: "torch.Tensor",
    size: SizeDict,
    **kwargs,
) -> "torch.Tensor":
    """
    Crop the center of `image` to `size`, zero-padding first when the image is smaller than the crop.

    Args:
        image: Tensor whose last two dimensions are (height, width).
        size: Crop size; both `height` and `width` must be set.

    Returns:
        The cropped (and possibly padded) tensor.

    Raises:
        ValueError: If `size.height` or `size.width` is `None`.
    """
    if size.height is None or size.width is None:
        raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")

    crop_height, crop_width = size.height, size.width
    image_height, image_width = image.shape[-2:]
    missing_width = crop_width - image_width
    missing_height = crop_height - image_height

    if missing_width > 0 or missing_height > 0:
        # (left, top, right, bottom): split the shortfall evenly, the odd pixel goes right/bottom.
        padding_ltrb = [
            missing_width // 2 if missing_width > 0 else 0,
            missing_height // 2 if missing_height > 0 else 0,
            (missing_width + 1) // 2 if missing_width > 0 else 0,
            (missing_height + 1) // 2 if missing_height > 0 else 0,
        ]
        image = tvF.pad(image, padding_ltrb, fill=0)
        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    top = int((image_height - crop_height) / 2.0)
    left = int((image_width - crop_width) / 2.0)
    return tvF.crop(image, top, left, crop_height, crop_width)
378
+
379
def _preprocess(
    self,
    images: list["torch.Tensor"],
    do_resize: bool,
    size: SizeDict,
    resample: "PILImageResampling | tvF.InterpolationMode | int | None",
    do_center_crop: bool,
    crop_size: SizeDict,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: float | list[float] | None,
    image_std: float | list[float] | None,
    do_pad: bool | None,
    pad_size: SizeDict | None,
    disable_grouping: bool | None,
    return_tensors: str | TensorType | None,
    **kwargs,
) -> BatchFeature:
    """
    Run the standard Torchvision pipeline: resize, center crop, rescale/normalize, then pad.

    Images are grouped by shape twice (before resizing, and again afterwards) so each step runs once per
    group of same-shape images. The result is wrapped in a `BatchFeature` under `"pixel_values"`.
    """
    # Stage 1: resize each same-shape group, then restore the original ordering.
    groups, positions = group_images_by_shape(images, disable_grouping=disable_grouping)
    resized_groups = {
        shape: self.resize(image=batch, size=size, resample=resample) if do_resize else batch
        for shape, batch in groups.items()
    }
    resized_images = reorder_images(resized_groups, positions)

    # Stage 2: regroup (shapes may have changed) for cropping and fused rescale/normalize.
    groups, positions = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
    finished_groups = {}
    for shape, batch in groups.items():
        if do_center_crop:
            batch = self.center_crop(batch, crop_size)
        finished_groups[shape] = self.rescale_and_normalize(
            batch, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )
    processed_images = reorder_images(finished_groups, positions)

    if do_pad:
        processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
425
+
426
+
427
@requires(backends=("vision",))
class PilBackend(BaseImageProcessor):
    """
    PIL/NumPy backend for portable CPU-only image processing.

    `process_image` converts every raw input to a channels-first `np.ndarray`; the remaining methods
    (pad, resize, rescale, normalize, center_crop) operate on that representation one image at a time.
    """

    def __init__(self, **kwargs: Unpack[ImagesKwargs]):
        super().__init__(**kwargs)
        # The base class deliberately leaves attribute resolution to the backend subclasses.
        self._set_attributes(**kwargs)

    @property
    def is_fast(self) -> bool:
        """
        `bool`: Whether or not this image processor is using the fast (Torchvision) backend.
        The `is_fast` property is deprecated and will be removed in v5.3 of Transformers.
        Use the `backend` attribute instead (e.g., `processor.backend == "torchvision"`).
        """
        logger.warning_once(
            "The `is_fast` property is deprecated and will be removed in v5.3 of Transformers. "
            "Use the `backend` attribute instead (e.g., `processor.backend == 'torchvision'`)."
        )
        return False

    @property
    def backend(self) -> str:
        """
        `str`: The backend used by this image processor.
        """
        return "pil"

    def process_image(
        self,
        image: ImageInput,
        do_convert_rgb: bool | None = None,
        input_data_format: str | ChannelDimension | None = None,
        **kwargs: Unpack[ImagesKwargs],
    ) -> np.ndarray:
        """
        Convert a single PIL, torch or NumPy image into a channels-first `np.ndarray`.

        Args:
            image: The raw image.
            do_convert_rgb: When truthy, the image is passed through `convert_to_rgb` first.
            input_data_format: Channel layout of the input. When `None` it is set to channels-last for
                multi-channel PIL images and inferred with `infer_channel_dimension_format` otherwise.

        Returns:
            The image as a NumPy array; 2D inputs gain a leading axis of size 1.

        Raises:
            ValueError: If the image is not a PIL image, torch tensor or NumPy array.
        """
        image_type = get_image_type(image)
        if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
            raise ValueError(f"Unsupported input image type {image_type}")

        if do_convert_rgb:
            image = self.convert_to_rgb(image)

        if image_type == ImageType.PIL:
            image = np.array(image)
            # Set LAST only for multi-channel PIL images (H, W, C); for grayscale (H, W), leave as is to
            # avoid shape errors after expand_dims.
            if image.ndim >= 3:
                input_data_format = ChannelDimension.LAST if input_data_format is None else input_data_format
        elif image_type == ImageType.TORCH:
            # `Tensor.numpy()` raises for tensors that require grad or live on a non-CPU device, so
            # detach and move to host memory first. For plain CPU tensors both calls are no-ops.
            image = image.detach().cpu().numpy()

        if image.ndim == 2:
            image = np.expand_dims(image, axis=0)

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        if input_data_format == ChannelDimension.LAST:
            # Convert from channels-last to channels-first
            if isinstance(image, np.ndarray):
                image = np.transpose(image, (2, 0, 1))

        return image

    def convert_to_rgb(self, image: ImageInput) -> ImageInput:
        """Convert an image to RGB format."""
        return convert_to_rgb(image)

    def pad(
        self,
        images: list[np.ndarray],
        pad_size: SizeDict | None = None,
        fill_value: int | None = 0,
        padding_mode: str | None = "constant",
        return_mask: bool = False,
        **kwargs,
    ) -> tuple[list[np.ndarray], list[np.ndarray]] | list[np.ndarray]:
        """
        Pad channels-first images on the bottom and right up to a common size using NumPy.

        Args:
            images: Images to pad.
            pad_size: Target size; must carry non-zero `height` and `width`. When `None`, the target is
                taken from `get_max_height_width(images)`.
            fill_value: Constant used when `padding_mode` is `"constant"`.
            padding_mode: Passed to `np.pad` as `mode`.
            return_mask: When `True`, also return int64 masks holding 1 over the original image area and
                0 over the padded area.

        Returns:
            The padded images, or an `(images, masks)` pair when `return_mask` is set.

        Raises:
            ValueError: If `pad_size` lacks a height or width, or if an image is larger than the target.
        """
        if pad_size is not None:
            if not (pad_size.height and pad_size.width):
                raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
            target_height, target_width = pad_size.height, pad_size.width
        else:
            target_height, target_width = get_max_height_width(images)

        processed_images = []
        processed_masks = []

        for image in images:
            height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
            padding_height = target_height - height
            padding_width = target_width - width

            if padding_height < 0 or padding_width < 0:
                raise ValueError(
                    f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
                    f"image size. Got pad_size=({target_height}, {target_width}), image_size=({height}, {width})."
                )

            if height != target_height or width != target_width:
                # Pad format: ((before_1, after_1), (before_2, after_2), ...)
                # For CHW format: ((0, 0), (0, padding_height), (0, padding_width))
                pad_width = ((0, 0), (0, padding_height), (0, padding_width))
                if padding_mode == "constant":
                    image = np.pad(image, pad_width, mode="constant", constant_values=fill_value)
                else:
                    image = np.pad(image, pad_width, mode=padding_mode)

            processed_images.append(image)

            if return_mask:
                mask = np.zeros((target_height, target_width), dtype=np.int64)
                mask[:height, :width] = 1
                processed_masks.append(mask)

        if return_mask:
            return processed_images, processed_masks
        return processed_images

    def resize(
        self,
        image: np.ndarray,
        size: SizeDict,
        resample: "PILImageResampling | None" = None,
        reducing_gap: int | None = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize a channels-first image using PIL/NumPy.

        Args:
            image: Image to resize.
            size: Target specification. Checked in this order: `shortest_edge` + `longest_edge`,
                `shortest_edge` alone, `max_height` + `max_width`, then `height` + `width`.
            resample: PIL resampling filter. Other values are translated through
                `torch_pil_interpolation_mapping` when possible and fall back to bilinear otherwise;
                `None` also means bilinear.
            reducing_gap: Forwarded to the underlying resize helper.

        Returns:
            The resized channels-first image.

        Raises:
            ValueError: If `size` matches none of the supported key combinations.
        """
        # PIL backend only supports PILImageResampling
        if resample is not None and not isinstance(resample, (PILImageResampling, int)):
            if torch_pil_interpolation_mapping is not None and resample in torch_pil_interpolation_mapping:
                resample = torch_pil_interpolation_mapping[resample]
            else:
                resample = PILImageResampling.BILINEAR
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        if size.shortest_edge and size.longest_edge:
            height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
            new_size = get_size_with_aspect_ratio(
                (height, width),
                size.shortest_edge,
                size.longest_edge,
            )
        elif size.shortest_edge:
            new_size = get_resize_output_image_size(
                image,
                size=size.shortest_edge,
                default_to_square=False,
                input_data_format=ChannelDimension.FIRST,
            )
        elif size.max_height and size.max_width:
            height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
            new_size = get_image_size_for_max_height_width((height, width), size.max_height, size.max_width)
        elif size.height and size.width:
            new_size = (size.height, size.width)
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
                f" {size}."
            )

        return np_resize(
            image,
            size=new_size,
            resample=resample,
            reducing_gap=reducing_gap,
            data_format=ChannelDimension.FIRST,
            input_data_format=ChannelDimension.FIRST,
        )

    def rescale(
        self,
        image: np.ndarray,
        scale: float,
        **kwargs,
    ) -> np.ndarray:
        """Rescale a channels-first image by a scale factor using NumPy."""
        return np_rescale(
            image,
            scale=scale,
            data_format=ChannelDimension.FIRST,
            input_data_format=ChannelDimension.FIRST,
        )

    def normalize(
        self,
        image: np.ndarray,
        mean: float | Iterable[float],
        std: float | Iterable[float],
        **kwargs,
    ) -> np.ndarray:
        """Normalize a channels-first image with the given mean and std using NumPy."""
        return np_normalize(
            image,
            mean=mean,
            std=std,
            data_format=ChannelDimension.FIRST,
            input_data_format=ChannelDimension.FIRST,
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: SizeDict,
        **kwargs,
    ) -> np.ndarray:
        """
        Center crop a channels-first image using NumPy.

        Raises:
            ValueError: If `size.height` or `size.width` is `None`.
        """
        if size.height is None or size.width is None:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")

        return np_center_crop(
            image,
            size=(size.height, size.width),
            data_format=ChannelDimension.FIRST,
            input_data_format=ChannelDimension.FIRST,
        )

    def _preprocess(
        self,
        images: list[np.ndarray],
        do_resize: bool,
        size: SizeDict,
        resample: "PILImageResampling | None",
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        do_pad: bool | None,
        pad_size: SizeDict | None,
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess using PIL backend (portable, CPU-only).

        Each image goes through resize, center crop, rescale and normalize (each step only when its flag
        is set); the whole list is then optionally padded and wrapped in a `BatchFeature` under
        `"pixel_values"`.
        """
        processed_images = []
        for image in images:
            if do_resize:
                image = self.resize(image=image, size=size, resample=resample)
            if do_center_crop:
                image = self.center_crop(image, crop_size)
            if do_rescale:
                image = self.rescale(image, rescale_factor)
            if do_normalize:
                image = self.normalize(image, image_mean, image_std)
            processed_images.append(image)

        if do_pad:
            processed_images = self.pad(processed_images, pad_size=pad_size)

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the processor, reporting `image_processor_type` without its "Pil" suffix."""
        processor_dict = super().to_dict()
        # Remove the "Pil" suffix from the image processor type
        processor_type = processor_dict.get("image_processor_type", "")
        if processor_type.endswith("Pil"):
            processor_dict["image_processor_type"] = processor_type.removesuffix("Pil")
        return processor_dict
686
+
687
+
688
+ # Backward-compatible alias: code that still refers to `BaseImageProcessorFast` resolves to `TorchvisionBackend`.
689
+ BaseImageProcessorFast = TorchvisionBackend
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_utils.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from collections.abc import Iterable
17
+ from copy import deepcopy
18
+ from functools import partial
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+ from huggingface_hub.dataclasses import validate_typed_dict
23
+
24
+ from .image_processing_base import BatchFeature, ImageProcessingMixin
25
+ from .image_transforms import center_crop, normalize, rescale
26
+ from .image_utils import (
27
+ ChannelDimension,
28
+ ImageInput,
29
+ SizeDict,
30
+ get_image_size,
31
+ make_flat_list_of_images,
32
+ validate_preprocess_arguments,
33
+ )
34
+ from .processing_utils import ImagesKwargs, Unpack
35
+ from .utils import (
36
+ auto_docstring,
37
+ is_torchvision_available,
38
+ is_vision_available,
39
+ logging,
40
+ )
41
+
42
+
43
+ if is_vision_available():
44
+ from .image_utils import PILImageResampling
45
+
46
+
47
+ if is_torchvision_available():
48
+ from torchvision.transforms.v2 import functional as tvF
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ # Keyword names grouped as init-time "service" kwargs.
+ # NOTE(review): presumably metadata keys rather than image-processing options; the code that
+ # consumes this list is not visible in this chunk, so confirm against its callers.
+ INIT_SERVICE_KWARGS = [
55
+ "processor_class",
56
+ "image_processor_type",
57
+ ]
58
+
59
+
60
+ class BaseImageProcessor(ImageProcessingMixin):
61
+ r"""
62
+ Base class for image processors with an inheritance-based backend architecture.
63
+
64
+ This class defines the preprocessing pipeline: kwargs validation, input preparation, and dispatching to the
65
+ backend's `_preprocess` method. Backend subclasses (`TorchvisionBackend`, `PilBackend`) inherit from this class
66
+ and implement the actual image operations (resize, crop, rescale, normalize, etc.). Model-specific image
67
+ processors then inherit from the appropriate backend class.
68
+
69
+ Architecture Overview
70
+ ---------------------
71
+
72
+ The class hierarchy is:
73
+
74
+ BaseImageProcessor (this class)
75
+ ├── TorchvisionBackend (GPU-accelerated, torch.Tensor)
76
+ │ └── ModelImageProcessor (e.g. LlavaNextImageProcessor)
77
+ └── PilBackend (portable CPU, np.ndarray)
78
+ └── ModelImageProcessorPil (e.g. CLIPImageProcessorPil)
79
+
80
+ The preprocessing flow is:
81
+
82
+ __call__() → preprocess() → _preprocess_image_like_inputs() → _prepare_image_like_inputs()
83
+ (calls process_image per image)
84
+ → _preprocess()
85
+ (batch operations: resize, crop, etc.)
86
+
87
+ - `process_image`: Implemented by backends. Converts a single raw input (PIL, NumPy, or Tensor) to the
88
+ backend's working format (torch.Tensor or np.ndarray), handles RGB conversion and channel reordering.
89
+ - `_preprocess`: Implemented by backends. Performs the actual batch processing (resize, center crop, rescale,
90
+ normalize, pad) and returns a `BatchFeature`.
91
+
92
+ Basic Implementation
93
+ --------------------
94
+
95
+ For processors that only need standard operations (resize, center crop, rescale, normalize), inherit from
96
+ a backend and define class attributes:
97
+
98
+ from transformers.image_processing_backends import PilBackend
99
+
100
+ class MyImageProcessorPil(PilBackend):
101
+ resample = PILImageResampling.BILINEAR
102
+ image_mean = IMAGENET_DEFAULT_MEAN
103
+ image_std = IMAGENET_DEFAULT_STD
104
+ size = {"height": 224, "width": 224}
105
+ do_resize = True
106
+ do_rescale = True
107
+ do_normalize = True
108
+
109
+ The backend's `_preprocess` method handles the standard pipeline automatically.
110
+
111
+ Custom Processing
112
+ -----------------
113
+
114
+ For processors that need custom logic (e.g., patch-based processing, multiple input types), override
115
+ `_preprocess` in your model-specific processor. The `_preprocess` method receives already-prepared images
116
+ (converted to the backend format with channels-first ordering) and performs the actual processing:
117
+
118
+ class MyImageProcessor(TorchvisionBackend):
119
+ def _preprocess(self, images, do_resize, size, do_normalize, image_mean, image_std, **kwargs):
120
+ # Group images by shape for efficient batched operations
121
+ grouped_images, grouped_images_index = group_images_by_shape(images)
122
+ processed_groups = {}
123
+ for shape, stacked_images in grouped_images.items():
124
+ if do_resize:
125
+ stacked_images = self.resize(stacked_images, size=size)
126
+ if do_normalize:
127
+ stacked_images = self.normalize(stacked_images, mean=image_mean, std=image_std)
128
+ processed_groups[shape] = stacked_images
129
+ processed_images = reorder_images(processed_groups, grouped_images_index)
130
+ return BatchFeature(data={"pixel_values": processed_images})
131
+
132
+ For processors handling multiple input types (e.g., images + segmentation maps), override
133
+ `_preprocess_image_like_inputs`:
134
+
135
+ def _preprocess_image_like_inputs(
136
+ self,
137
+ images: ImageInput,
138
+ segmentation_maps: ImageInput | None = None,
139
+ **kwargs,
140
+ ) -> BatchFeature:
141
+ images = self._prepare_image_like_inputs(images, **kwargs)
142
+ batch_feature = self._preprocess(images, **kwargs)
143
+
144
+ if segmentation_maps is not None:
145
+ maps = self._prepare_image_like_inputs(segmentation_maps, **kwargs)
146
+ batch_feature["labels"] = self._preprocess(maps, **kwargs).pixel_values
147
+
148
+ return batch_feature
149
+
150
+ Extending Backend Behavior
151
+ --------------------------
152
+
153
+ To customize operations for a specific backend, subclass the backend and override its methods:
154
+
155
+ from transformers.image_processing_backends import TorchvisionBackend, PilBackend
156
+
157
+ class MyTorchvisionProcessor(TorchvisionBackend):
158
+ def resize(self, image, size, **kwargs):
159
+ # Custom resize logic for torchvision
160
+ return super().resize(image, size, **kwargs)
161
+
162
+ class MyPilProcessor(PilBackend):
163
+ def resize(self, image, size, **kwargs):
164
+ # Custom resize logic for PIL
165
+ return super().resize(image, size, **kwargs)
166
+
167
+ Custom Parameters
168
+ -----------------
169
+
170
+ To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:
171
+
172
+ class MyImageProcessorKwargs(ImagesKwargs):
173
+ custom_param: int | None = None
174
+
175
+ class MyImageProcessor(TorchvisionBackend):
176
+ valid_kwargs = MyImageProcessorKwargs
177
+ custom_param = 10 # default value
178
+
179
+ Key Notes
180
+ ---------
181
+
182
+ - Backend selection is done at the class level: inherit from `TorchvisionBackend` or `PilBackend`
183
+ - Backends receive images as `torch.Tensor` (Torchvision) or `np.ndarray` (PIL), always channels-first
184
+ - All images have channel dimension first during processing, regardless of backend
185
+ - Arguments not provided by users default to class attribute values
186
+ - Backend classes encapsulate backend-specific logic (resize, normalize, etc.) and can be overridden
187
+ """
188
+
189
+ valid_kwargs = ImagesKwargs
190
+
191
+ default_to_square = True
192
+ rescale_factor = 1 / 255
193
+ model_input_names = ["pixel_values"]
194
+
195
def __init__(self, **kwargs: Unpack[ImagesKwargs]):
    """
    Forward all keyword arguments to the parent initializer.

    `self._set_attributes` is intentionally not called here, for backward compatibility with remote
    code; the backend subclasses call it from their own `__init__` instead.
    """
    super().__init__(**kwargs)
199
+
200
+ def _set_attributes(self, **kwargs):
201
+ """Resolve and set instance attributes from kwargs and class-level defaults for all valid kwargs."""
202
+ attributes = {}
203
+ for key in self.valid_kwargs.__annotations__:
204
+ kwarg = kwargs.pop(key, None)
205
+ if kwarg is not None:
206
+ attributes[key] = kwarg
207
+ else:
208
+ attributes[key] = deepcopy(getattr(self, key, None))
209
+ attributes = self._standardize_kwargs(**attributes)
210
+ for key, value in attributes.items():
211
+ setattr(self, key, value)
212
+
213
+ self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
214
+
215
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
    """Make the processor callable: forward `images` and all other arguments to `preprocess`."""
    batch = self.preprocess(images, *args, **kwargs)
    return batch
218
+
219
def process_image(self, *args, **kwargs):
    """
    Convert one raw image into the backend's working format.

    This base implementation is a placeholder: backend subclasses (`TorchvisionBackend`, `PilBackend`)
    provide the real conversion from a PIL image, NumPy array or torch tensor to their internal
    representation (`torch.Tensor` for Torchvision, `np.ndarray` for PIL), including RGB conversion and
    channels-first ordering.

    Raises:
        NotImplementedError: Always, when called on the base class.
    """
    raise NotImplementedError
228
+
229
+ def _preprocess(self, *args, **kwargs):
230
+ """
231
+ Perform the actual batch image preprocessing (resize, center crop, rescale, normalize, pad).
232
+
233
+ Implemented by backend subclasses (`TorchvisionBackend`, `PilBackend`). Receives a list of
234
+ already-prepared images (in the backend's format, channels-first) and applies the configured
235
+ preprocessing operations. Returns a `BatchFeature` with the processed pixel values.
236
+
237
+ Model-specific processors can override this method to implement custom preprocessing logic
238
+ (e.g., patch-based processing in LLaVA-NeXT).
239
+ """
240
+ raise NotImplementedError
241
+
242
def _prepare_images_structure(
    self,
    images: ImageInput,
    expected_ndims: int = 3,
) -> ImageInput:
    """
    Put the input images into the structure expected by the rest of the pipeline.

    The input is first passed through `self.fetch_images`, then flattened with
    `make_flat_list_of_images`.

    Args:
        images (`ImageInput`):
            The input images to process.
        expected_ndims (`int`, *optional*, defaults to 3):
            Forwarded to `make_flat_list_of_images`.

    Returns:
        `ImageInput`: The images with a valid nesting.
    """
    fetched = self.fetch_images(images)
    flat_images = make_flat_list_of_images(fetched, expected_ndims=expected_ndims)
    return flat_images
259
+
260
+ def _prepare_image_like_inputs(
261
+ self,
262
+ images: ImageInput,
263
+ *args,
264
+ expected_ndims: int = 3,
265
+ **kwargs: Unpack[ImagesKwargs],
266
+ ) -> list[Any]:
267
+ """
268
+ Prepare image-like inputs for processing by converting each image via `process_image`.
269
+
270
+ Flattens the input structure and applies `process_image` (implemented by the backend) to each
271
+ individual image, converting raw inputs (PIL, NumPy, Tensor) into the backend's working format
272
+ with channels-first ordering.
273
+
274
+ Args:
275
+ images (`ImageInput`):
276
+ The image-like inputs to process.
277
+ expected_ndims (`int`, *optional*, defaults to 3):
278
+ The expected number of dimensions for the images.
279
+
280
+ Returns:
281
+ `list[torch.Tensor]` or `list[np.ndarray]`: The prepared images in the backend's format,
282
+ with channels-first ordering.
283
+ """
284
+ images = self._prepare_images_structure(images, expected_ndims=expected_ndims)
285
+
286
+ process_image_partial = partial(self.process_image, *args, **kwargs)
287
+
288
+ has_nested_structure = len(images) > 0 and isinstance(images[0], list | tuple)
289
+
290
+ if has_nested_structure:
291
+ processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
292
+ else:
293
+ processed_images = [process_image_partial(img) for img in images]
294
+
295
+ return processed_images
296
+
297
+ def _preprocess_image_like_inputs(
298
+ self,
299
+ images: ImageInput,
300
+ *args,
301
+ **kwargs: Unpack[ImagesKwargs],
302
+ ) -> BatchFeature:
303
+ """
304
+ Preprocess image-like inputs by preparing them and dispatching to `_preprocess`.
305
+
306
+ This method first calls `_prepare_image_like_inputs` to convert raw inputs into the backend's
307
+ format, then calls `_preprocess` for the actual batch processing. Override this method in
308
+ model-specific processors that need to handle multiple image-like input types (e.g., images
309
+ and segmentation maps) or need custom orchestration of the preprocessing pipeline.
310
+ """
311
+ images = self._prepare_image_like_inputs(images, **kwargs)
312
+ return self._preprocess(images, *args, **kwargs)
313
+
314
def _standardize_kwargs(
    self,
    size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
    crop_size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
    pad_size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
    default_to_square: bool | None = None,
    image_mean: float | list[float] | None = None,
    image_std: float | list[float] | None = None,
    **kwargs,
) -> dict:
    """
    Standardize kwargs to canonical format before validation.
    Can be overridden by subclasses to customize the processing of kwargs.

    - `size`, `crop_size` and `pad_size` are wrapped in a `SizeDict` (through `get_size_dict`) unless
      they are `None` or already a `SizeDict`. Only `size` honours `default_to_square`; the other two
      rely on the default of `get_size_dict`.
    - `image_mean` and `image_std` are turned into tuples when given as lists.

    Returns:
        `dict`: the remaining keyword arguments, with the five standardized entries written into them.
    """
    # `**kwargs` is always bound to a dict (possibly empty), so no `None` guard is needed here.
    if size is not None and not isinstance(size, SizeDict):
        size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
    if crop_size is not None and not isinstance(crop_size, SizeDict):
        crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
    if pad_size is not None and not isinstance(pad_size, SizeDict):
        pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
    if isinstance(image_mean, list):
        image_mean = tuple(image_mean)
    if isinstance(image_std, list):
        image_std = tuple(image_std)

    kwargs.update(
        size=size,
        crop_size=crop_size,
        pad_size=pad_size,
        image_mean=image_mean,
        image_std=image_std,
    )
    return kwargs


# Backwards compatibility for method that was renamed
_further_process_kwargs = _standardize_kwargs
351
+
352
def _validate_preprocess_kwargs(
    self,
    do_rescale: bool | None = None,
    rescale_factor: float | None = None,
    do_normalize: bool | None = None,
    image_mean: float | tuple[float] | None = None,
    image_std: float | tuple[float] | None = None,
    do_resize: bool | None = None,
    size: SizeDict | None = None,
    do_center_crop: bool | None = None,
    crop_size: SizeDict | None = None,
    resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
    **kwargs,
):
    """
    Validate the kwargs for the preprocess method.

    The named arguments are forwarded as keywords to `validate_preprocess_arguments`; any additional
    keyword arguments are accepted and ignored. Nothing is returned.
    """
    checked_arguments = {
        "do_rescale": do_rescale,
        "rescale_factor": rescale_factor,
        "do_normalize": do_normalize,
        "image_mean": image_mean,
        "image_std": image_std,
        "do_center_crop": do_center_crop,
        "crop_size": crop_size,
        "do_resize": do_resize,
        "size": size,
        "resample": resample,
    }
    validate_preprocess_arguments(**checked_arguments)
381
+
382
@auto_docstring
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
    """
    Preprocess an image or a batch of images.
    """
    # Type-check what the caller passed against the declared kwargs schema.
    validate_typed_dict(self.valid_kwargs, kwargs)

    # Options the caller left out fall back to the attribute of the same name on this
    # processor, or to None when no such attribute exists.
    for option_name in self._valid_kwargs_names:
        fallback = getattr(self, option_name, None)
        kwargs.setdefault(option_name, fallback)

    # Bring the kwargs into canonical form first, so validation sees the standardized values.
    standardized = self._standardize_kwargs(**kwargs)
    self._validate_preprocess_kwargs(**standardized)

    return self._preprocess_image_like_inputs(images, *args, **standardized)
401
+
402
def to_dict(self) -> dict[str, Any]:
    """
    Serialize the processor to a dict, leaving out `None` values that are plain class defaults.

    Starts from the parent class' `to_dict()` output. `SizeDict` values are converted to regular
    dicts. A `None` value is kept only when the class defines a non-`None` attribute under the same
    key, i.e. when `None` was set explicitly on the instance. The internal `_valid_processor_keys`
    and `_valid_kwargs_names` entries are removed.

    Returns:
        `dict[str, Any]`: the filtered dictionary.
    """
    # Identity sentinel: unlike a marker string compared with `!=`, it cannot collide with a real
    # attribute value and never triggers a custom `__ne__` on the class default.
    not_found = object()

    filtered_dict = {}
    for key, value in super().to_dict().items():
        if isinstance(value, SizeDict):
            value = dict(value)
        if value is None:
            class_default = getattr(type(self), key, not_found)
            # Drop None when the class has no such attribute or its default is None as well.
            if class_default is not_found or class_default is None:
                continue
        filtered_dict[key] = value

    filtered_dict.pop("_valid_processor_keys", None)
    filtered_dict.pop("_valid_kwargs_names", None)
    return filtered_dict
421
+
422
def rescale(
    self,
    image: np.ndarray,
    scale: float,
    data_format: str | ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Rescale an image by a scale factor (image = image * scale).

    Thin wrapper that forwards all arguments to the module-level `rescale` function.

    Args:
        image (`np.ndarray`):
            Image to rescale.
        scale (`float`):
            The scaling factor to rescale pixel values by.
        data_format (`str` or `ChannelDimension`, *optional*):
            Channel dimension format of the output image. If unset, the format of the input image is used.
            Either `"channels_first"` / `ChannelDimension.FIRST` (num_channels, height, width) or
            `"channels_last"` / `ChannelDimension.LAST` (height, width, num_channels).
        input_data_format (`ChannelDimension` or `str`, *optional*):
            Channel dimension format of the input image, with the same two choices. If unset, it is
            inferred from the input image.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    rescaled_image = rescale(
        image,
        scale=scale,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )
    return rescaled_image
453
+
454
+ # The next methods are kept for backwards compatibility with remote code, but are overridden by backends.
455
def normalize(
    self,
    image: np.ndarray,
    mean: float | Iterable[float],
    std: float | Iterable[float],
    data_format: str | ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Normalize an image (image = (image - image_mean) / image_std).

    Thin wrapper that forwards all arguments to the module-level `normalize` function.

    Args:
        image (`np.ndarray`):
            Image to normalize.
        mean (`float` or `Iterable[float]`):
            Image mean to use for normalization.
        std (`float` or `Iterable[float]`):
            Image standard deviation to use for normalization.
        data_format (`str` or `ChannelDimension`, *optional*):
            Channel dimension format of the output image. If unset, the format of the input image is used.
            Either `"channels_first"` / `ChannelDimension.FIRST` (num_channels, height, width) or
            `"channels_last"` / `ChannelDimension.LAST` (height, width, num_channels).
        input_data_format (`ChannelDimension` or `str`, *optional*):
            Channel dimension format of the input image, with the same two choices. If unset, it is
            inferred from the input image.

    Returns:
        `np.ndarray`: The normalized image.
    """
    normalized_image = normalize(
        image,
        mean=mean,
        std=std,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )
    return normalized_image
491
+
492
def center_crop(
    self,
    image: np.ndarray,
    size: dict[str, int],
    data_format: str | ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
    any edge, the image is padded with 0's and then center cropped.

    `size` is first passed through `get_size_dict`; the result must contain both `"height"` and
    `"width"`. The crop itself is done by the module-level `center_crop` function.

    Args:
        image (`np.ndarray`):
            Image to center crop.
        size (`dict[str, int]`):
            Size of the output image.
        data_format (`str` or `ChannelDimension`, *optional*):
            Channel dimension format of the output image. If unset, the format of the input image is used.
            Either `"channels_first"` / `ChannelDimension.FIRST` (num_channels, height, width) or
            `"channels_last"` / `ChannelDimension.LAST` (height, width, num_channels).
        input_data_format (`ChannelDimension` or `str`, *optional*):
            Channel dimension format of the input image, with the same two choices. If unset, it is
            inferred from the input image.

    Raises:
        ValueError: if the standardized size lacks `"height"` or `"width"`.
    """
    size = get_size_dict(size)
    missing_keys = [key for key in ("height", "width") if key not in size]
    if missing_keys:
        raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")

    target_hw = (size["height"], size["width"])
    return center_crop(
        image,
        size=target_hw,
        data_format=data_format,
        input_data_format=input_data_format,
        **kwargs,
    )
530
+
531
+
532
# Each entry is one complete set of keys a size dictionary may carry; a valid size
# dictionary has exactly the keys of one entry.
VALID_SIZE_DICT_KEYS = (
    {"height", "width"},
    {"shortest_edge"},
    {"shortest_edge", "longest_edge"},
    {"longest_edge"},
    {"max_height", "max_width"},
)


def is_valid_size_dict(size_dict):
    """
    Tell whether `size_dict` is a dict whose keys exactly match one entry of `VALID_SIZE_DICT_KEYS`.

    Anything that is not a `dict` is rejected, as is a dict with missing or extra keys.
    """
    if not isinstance(size_dict, dict):
        return False

    present_keys = set(size_dict)
    return any(present_keys == allowed_keys for allowed_keys in VALID_SIZE_DICT_KEYS)
550
+
551
+
552
+ def convert_to_size_dict(
553
+ size: int | Iterable[int] | None = None,
554
+ max_size: int | None = None,
555
+ default_to_square: bool = True,
556
+ height_width_order: bool = True,
557
+ ) -> dict[str, int]:
558
+ # By default, if size is an int we assume it represents a tuple of (size, size).
559
+ if isinstance(size, int) and default_to_square:
560
+ if max_size is not None:
561
+ raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
562
+ return {"height": size, "width": size}
563
+ # In other configs, if size is an int and default_to_square is False, size represents the length of
564
+ # the shortest edge after resizing.
565
+ elif isinstance(size, int) and not default_to_square:
566
+ size_dict = {"shortest_edge": size}
567
+ if max_size is not None:
568
+ size_dict["longest_edge"] = max_size
569
+ return size_dict
570
+ # Otherwise, if size is a tuple it's either (height, width) or (width, height)
571
+ elif isinstance(size, (tuple, list)) and height_width_order:
572
+ return {"height": size[0], "width": size[1]}
573
+ elif isinstance(size, (tuple, list)) and not height_width_order:
574
+ return {"height": size[1], "width": size[0]}
575
+ elif size is None and max_size is not None:
576
+ if default_to_square:
577
+ raise ValueError("Cannot specify both default_to_square=True and max_size")
578
+ return {"longest_edge": max_size}
579
+
580
+ raise ValueError(f"Could not convert size input to size dict: {size}")
581
+
582
+
583
def get_size_dict(
    size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
    max_size: int | None = None,
    height_width_order: bool = True,
    default_to_square: bool = True,
    param_name="size",
) -> dict:
    """
    Turn a `size` value of any supported form into a validated size dictionary.

    Exists for backwards compatibility with old image processor configs, and removes the ambiguity
    over whether a tuple is (height, width) or (width, height).

    - A `dict` is used as is; a `SizeDict` is converted with `dict(size)`. Neither uses `max_size`.
    - Anything else (int, tuple/list, `None`) goes through `convert_to_size_dict`, and the conversion
      is reported at info level.

    Args:
        size (`int | Iterable[int] | dict[str, int] | SizeDict`, *optional*):
            The `size` parameter to be cast into a size dictionary.
        max_size (`int | None`, *optional*):
            With `default_to_square=False`, sets `longest_edge` when `size` is an int or `None`; unused for
            dict, `SizeDict`, or tuple/list `size`.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple, whether it's in (height, width) or (width, height) order.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, whether to default to a square image or not.
        param_name (`str`, *optional*, defaults to `"size"`):
            Name used for the parameter in log and error messages.

    Raises:
        ValueError: if the resulting dictionary does not have one of the key sets in `VALID_SIZE_DICT_KEYS`.
    """
    if isinstance(size, SizeDict):
        # Some remote code bypasses or overrides `_standardize_kwargs`, so `SizeDict` is handled here too.
        size_dict = dict(size)
    elif isinstance(size, dict):
        size_dict = size
    else:
        size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
        logger.info(
            f"{param_name} should be a dictionary with one of the following sets of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
            f" Converted to {size_dict}.",
        )

    if is_valid_size_dict(size_dict):
        return size_dict
    raise ValueError(
        f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
    )
631
+
632
+
633
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """
    Pick, among `possible_resolutions`, the one that best fits an image of `original_size`.

    For every candidate the image is scaled to fit inside it while keeping its aspect ratio. The
    effective resolution is the scaled pixel count, capped at the original pixel count; the wasted
    resolution is the candidate area minus the effective resolution. The candidate with the highest
    effective resolution wins, ties going to the one with the least waste; among fully equal
    candidates the earliest one is kept.

    Args:
        original_size (tuple):
            The original size of the image in the format (height, width).
        possible_resolutions (list):
            A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].

    Returns:
        tuple: The best fit resolution in the format (height, width), or `None` if no candidate was selected
        (e.g. an empty list).
    """
    original_height, original_width = original_size
    original_area = original_width * original_height

    best_fit = None
    # Score is (effective resolution, -wasted resolution); a larger tuple is a better fit.
    best_score = (0, -float("inf"))

    for height, width in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width = int(original_width * scale)
        downscaled_height = int(original_height * scale)

        effective_resolution = min(downscaled_width * downscaled_height, original_area)
        wasted_resolution = (width * height) - effective_resolution

        candidate_score = (effective_resolution, -wasted_resolution)
        if candidate_score > best_score:
            best_score = candidate_score
            best_fit = (height, width)

    return best_fit
669
+
670
+
671
def get_patch_output_size(image, target_resolution, input_data_format):
    """
    Compute the (height, width) an image takes when fitted into `target_resolution` with its aspect ratio kept.

    The image size is read with `get_image_size`, using `input_data_format` as the channel dimension.
    The dimension with the smaller scale factor is set to the target value; the other one is scaled
    by that factor, rounded up, and capped at its target value.

    Args:
        image: The image whose size is read.
        target_resolution: A `(target_height, target_width)` pair.
        input_data_format: Channel dimension format passed to `get_image_size`.

    Returns:
        A `(new_height, new_width)` tuple.
    """
    original_height, original_width = get_image_size(image, channel_dim=input_data_format)
    target_height, target_width = target_resolution

    width_ratio = target_width / original_width
    height_ratio = target_height / original_height

    if width_ratio < height_ratio:
        # Width is the limiting dimension: fill it and derive the height.
        fitted_height = min(math.ceil(original_height * width_ratio), target_height)
        return fitted_height, target_width

    # Height is the limiting dimension (or both ratios are equal): fill it and derive the width.
    fitted_width = min(math.ceil(original_width * height_ratio), target_width)
    return target_height, fitted_width
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_utils.py ADDED
@@ -0,0 +1,1069 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+ import os
17
+ from collections.abc import Iterable
18
+ from dataclasses import dataclass, fields
19
+ from io import BytesIO
20
+ from typing import Any, Union
21
+
22
+ import httpx
23
+ import numpy as np
24
+
25
+ from .utils import (
26
+ ExplicitEnum,
27
+ is_numpy_array,
28
+ is_torch_available,
29
+ is_torch_tensor,
30
+ is_torchvision_available,
31
+ is_vision_available,
32
+ logging,
33
+ requires_backends,
34
+ to_numpy,
35
+ )
36
+ from .utils.constants import ( # noqa: F401
37
+ IMAGENET_DEFAULT_MEAN,
38
+ IMAGENET_DEFAULT_STD,
39
+ IMAGENET_STANDARD_MEAN,
40
+ IMAGENET_STANDARD_STD,
41
+ OPENAI_CLIP_MEAN,
42
+ OPENAI_CLIP_STD,
43
+ )
44
+ from .utils.import_utils import requires
45
+
46
+
47
+ if is_vision_available():
48
+ import PIL.Image
49
+ import PIL.ImageOps
50
+
51
+ PILImageResampling = PIL.Image.Resampling
52
+
53
+ if is_torchvision_available():
54
+ from torchvision.io import ImageReadMode, decode_image
55
+ from torchvision.transforms import InterpolationMode
56
+ from torchvision.transforms.functional import pil_to_tensor
57
+
58
+ pil_torch_interpolation_mapping = {
59
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT,
60
+ PILImageResampling.BOX: InterpolationMode.BOX,
61
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
62
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
63
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
64
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
65
+ }
66
+ # Create inverse mapping: InterpolationMode -> PILImageResampling
67
+ torch_pil_interpolation_mapping = {v: k for k, v in pil_torch_interpolation_mapping.items()}
68
+ else:
69
+ pil_torch_interpolation_mapping = {}
70
+ torch_pil_interpolation_mapping = {}
71
+
72
+
73
+ if is_torch_available():
74
+ import torch
75
+
76
+
77
+ logger = logging.get_logger(__name__)
78
+
79
+
80
+ ImageInput = Union[
81
+ "PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
82
+ ]
83
+
84
+
85
class ChannelDimension(ExplicitEnum):
    """Position of the channel axis in an image array."""

    # Channels come before height and width: (num_channels, height, width).
    FIRST = "channels_first"
    # Channels come after height and width: (height, width, num_channels).
    LAST = "channels_last"
88
+
89
+
90
class AnnotationFormat(ExplicitEnum):
    """Identifiers of the supported annotation formats."""

    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"
93
+
94
+
95
+ AnnotationType = dict[str, int | str | list[dict]]
96
+
97
+
98
def is_pil_image(img):
    """Tell whether `img` is a `PIL.Image.Image`; falsy whenever the vision backend is unavailable."""
    # `PIL` is only imported when the vision backend is available, so that check must come first.
    vision_ready = is_vision_available()
    return vision_ready and isinstance(img, PIL.Image.Image)
100
+
101
+
102
class ImageType(ExplicitEnum):
    """Kinds of image objects distinguished by `get_image_type`."""

    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
106
+
107
+
108
def get_image_type(image):
    """
    Return the `ImageType` member matching `image`.

    The checks run in the order PIL image, torch tensor, NumPy array; the first match wins.

    Raises:
        ValueError: if `image` is none of the three.
    """
    type_checks = (
        (is_pil_image, ImageType.PIL),
        (is_torch_tensor, ImageType.TORCH),
        (is_numpy_array, ImageType.NUMPY),
    )
    for matches, image_type in type_checks:
        if matches(image):
            return image_type
    raise ValueError(f"Unrecognized image type {type(image)}")
116
+
117
+
118
def is_valid_image(img):
    """Tell whether `img` is a PIL image, a NumPy array or a torch tensor."""
    pil_or_numpy = is_pil_image(img) or is_numpy_array(img)
    return pil_or_numpy or is_torch_tensor(img)
120
+
121
+
122
def is_valid_list_of_images(images: list) -> bool:
    """
    Tell whether `images` is a non-empty collection in which every item is a valid image.

    Always returns a real `bool`: an empty (or otherwise falsy) input yields `False` rather than the
    falsy input object itself, so the result is safe to store, compare or serialize.
    """
    return bool(images) and all(is_valid_image(image) for image in images)
124
+
125
+
126
def concatenate_list(input_list):
    """
    Concatenate the items of `input_list`, dispatching on the type of its first item.

    - lists are flattened by one level into a single list,
    - NumPy arrays are joined with `np.concatenate(..., axis=0)`,
    - torch tensors are joined with `torch.cat(..., dim=0)`.

    Returns `None` when the first item is of any other type. `input_list` must not be empty.
    """
    first_item = input_list[0]
    if isinstance(first_item, list):
        return [item for sublist in input_list for item in sublist]
    if isinstance(first_item, np.ndarray):
        return np.concatenate(input_list, axis=0)
    # `torch` is only imported when it is installed, so test through `is_torch_tensor` rather than
    # naming `torch.Tensor` directly, which would raise NameError without torch.
    if is_torch_tensor(first_item):
        return torch.cat(input_list, dim=0)
    return None
133
+
134
+
135
def valid_images(imgs):
    """
    Recursively check that `imgs` contains only valid images.

    A list or tuple is valid when each of its items is valid (so an empty one is valid); anything
    else is treated as a single image, or a batched tensor of images, and checked with `is_valid_image`.
    """
    if isinstance(imgs, (list, tuple)):
        # Stops at the first invalid item.
        return all(valid_images(img) for img in imgs)
    return bool(is_valid_image(imgs))
145
+
146
+
147
def is_batched(img):
    """
    Tell whether `img` is a list or tuple whose first item is a valid image.

    Only the first item is inspected. An empty list or tuple is reported as not batched instead of
    raising `IndexError`; anything that is not a list or tuple is not batched either.
    """
    if isinstance(img, (list, tuple)) and len(img) > 0:
        return is_valid_image(img[0])
    return False
151
+
152
+
153
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Check whether the pixel values already lie in [0, 1].

    `uint8` images are never considered scaled. For any other dtype the decision is based on the
    actual value range, since a floating-point image may still hold values in [0, 255].
    """
    if image.dtype == np.uint8:
        return False

    lowest, highest = np.min(image), np.max(image)
    return lowest >= 0 and highest <= 1
162
+
163
+
164
def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
    """
    Ensure that the output is a list of images.

    A list/tuple recognised by `is_batched` is returned unchanged. A single PIL image, or a single
    array/tensor with `expected_ndims` dimensions, is wrapped in a list of length 1. An array/tensor
    with `expected_ndims + 1` dimensions is treated as a batch and split along its first axis.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.

    Raises:
        ValueError: if the input is not a supported image type, or has an unexpected number of dimensions.
    """
    if is_batched(images):
        return images

    # PIL images are never batched, so a single one always becomes a one-item list.
    if is_pil_image(images):
        return [images]

    if not is_valid_image(images):
        raise ValueError(
            f"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, or torch.Tensor, but got {type(images)}."
        )

    ndim = images.ndim
    if ndim == expected_ndims + 1:
        # Batch of images: iterate over the leading axis.
        return list(images)
    if ndim == expected_ndims:
        # Single image.
        return [images]
    raise ValueError(
        f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
        f" {images.ndim} dimensions."
    )
200
+
201
+
202
def make_flat_list_of_images(
    images: list[ImageInput] | ImageInput,
    expected_ndims: int = 3,
) -> ImageInput:
    """
    Ensure that the output is a flat list of images.

    A nested list of images is flattened by one level (empty inner lists are allowed). A list of
    single images is returned as is, and a list of batched arrays is flattened into single images.
    A single image becomes a list of length 1, and a single batched array is split along its first axis.

    Args:
        images (`Union[list[ImageInput], ImageInput]`):
            The input image.
        expected_ndims (`int`, *optional*, defaults to 3):
            The expected number of dimensions for a single input image.

    Returns:
        list: A list of images or a 4d array of images.

    Raises:
        ValueError: if no flat list can be built from the input.
    """
    if isinstance(images, (list, tuple)):
        all_sequences = all(isinstance(group, (list, tuple)) for group in images)
        if all_sequences and all(is_valid_list_of_images(group) or not group for group in images):
            # Nested list of images: flatten one level.
            return [img for group in images for img in group]

        if is_valid_list_of_images(images):
            if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
                return images
            if images[0].ndim == expected_ndims + 1:
                # List of batched arrays: unroll every batch.
                return [img for batch in images for img in batch]

    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == expected_ndims:
            return [images]
        if images.ndim == expected_ndims + 1:
            return list(images)

    raise ValueError(f"Could not make a flat list of images from {images}")
238
+
239
+
240
def make_nested_list_of_images(
    images: list[ImageInput] | ImageInput,
    expected_ndims: int = 3,
) -> list[ImageInput]:
    """
    Ensure that the output is a nested list of images.

    A list of lists of images is returned unchanged (empty inner lists are allowed). A flat list of
    single images is wrapped as one batch, and a list of batched arrays becomes one inner list per
    array. A single image becomes `[[image]]`, and a single batched array becomes one inner list.

    Args:
        images (`Union[list[ImageInput], ImageInput]`):
            The input image.
        expected_ndims (`int`, *optional*, defaults to 3):
            The expected number of dimensions for a single input image.

    Returns:
        list: A list of list of images or a list of 4d array of images.

    Raises:
        ValueError: if the input matches none of the supported layouts.
    """
    if isinstance(images, (list, tuple)):
        all_sequences = all(isinstance(group, (list, tuple)) for group in images)
        if all_sequences and all(is_valid_list_of_images(group) or not group for group in images):
            # Already a list of batches.
            return images

        if is_valid_list_of_images(images):
            if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
                # A flat list of images is a single batch.
                return [images]
            if images[0].ndim == expected_ndims + 1:
                return [list(image) for image in images]

    if is_valid_image(images):
        if is_pil_image(images) or images.ndim == expected_ndims:
            return [[images]]
        if images.ndim == expected_ndims + 1:
            return [list(images)]

    raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
277
+
278
+
279
def to_numpy_array(img) -> np.ndarray:
    """
    Convert a valid image to a NumPy array.

    PIL images are converted with `np.array`; every other valid image goes through `to_numpy`.

    Raises:
        ValueError: if `img` is not a valid image.
    """
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    is_pil_input = is_vision_available() and isinstance(img, PIL.Image.Image)
    return np.array(img) if is_pil_input else to_numpy(img)
286
+
287
+
288
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: int | tuple[int, ...] | None = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    For 3-, 4- and 5-dimensional inputs, the candidate channel axes are the third-from-last axis and
    the last axis. The axis whose length is one of `num_channels` is taken as the channel axis; when
    both qualify, a warning is logged and channels-first is assumed.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.

    Raises:
        ValueError: if the image has an unsupported number of dimensions, or neither candidate axis matches.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, int):
        num_channels = (num_channels,)

    # Candidate (channels-first axis, channels-last axis) for each supported rank.
    candidate_axes = {3: (0, 2), 4: (1, 3), 5: (2, 4)}
    if image.ndim not in candidate_axes:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
    first_dim, last_dim = candidate_axes[image.ndim]

    first_matches = image.shape[first_dim] in num_channels
    last_matches = image.shape[last_dim] in num_channels

    if first_matches:
        if last_matches:
            logger.warning(
                f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
            )
        return ChannelDimension.FIRST
    if last_matches:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
325
+
326
+
327
def get_channel_dimension_axis(image: np.ndarray, input_data_format: ChannelDimension | str | None = None) -> int:
    """
    Returns the channel dimension axis of the image.

    The axis is `image.ndim - 3` for channels-first and `image.ndim - 1` for channels-last.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.

    Raises:
        ValueError: if the data format is neither channels-first nor channels-last.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    if input_data_format == ChannelDimension.FIRST:
        axes_from_end = 3
    elif input_data_format == ChannelDimension.LAST:
        axes_from_end = 1
    else:
        raise ValueError(f"Unsupported data format: {input_data_format}")
    return image.ndim - axes_from_end
347
+
348
+
349
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension | None = None) -> tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Height and width are read from the last two axes for channels-first, and from the two axes before
    the last one for channels-last.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.

    Raises:
        ValueError: if `channel_dim` is neither channels-first nor channels-last.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    if channel_dim == ChannelDimension.FIRST:
        height_axis, width_axis = -2, -1
    elif channel_dim == ChannelDimension.LAST:
        height_axis, width_axis = -3, -2
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
    return image.shape[height_axis], image.shape[width_axis]
371
+
372
+
373
def get_image_size_for_max_height_width(
    image_size: tuple[int, int],
    max_height: int,
    max_width: int,
) -> tuple[int, int]:
    """
    Compute the output size that fits within `max_height` x `max_width` while keeping the aspect ratio.

    Both edges are scaled by the smaller of the two ratios `max_height / height` and
    `max_width / width`, then truncated to int. Note that an image smaller than the limits is scaled
    up, so that at least one edge reaches its maximum.

    For example:
    - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
    - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        image_size (`tuple[int, int]`):
            The (height, width) of the image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.

    Returns:
        `tuple[int, int]`: the new (height, width).
    """
    height, width = image_size
    scale = min(max_height / height, max_width / width)
    return int(height * scale), int(width * scale)
402
+
403
+
404
def max_across_indices(values: Iterable[Any]) -> list[Any]:
    """
    Compute the position-wise maximum over an iterable of equally indexed sequences.

    The sequences are zipped together, so the result has as many entries as the shortest sequence.
    """
    # `zip(*values)` regroups the inputs by position; `max` is then taken over each group.
    return list(map(max, zip(*values)))
409
+
410
+
411
def get_max_height_width(
    images: list[Union["torch.Tensor", np.ndarray]], input_data_format: str | ChannelDimension = ChannelDimension.FIRST
) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.

    Args:
        images (`list[torch.Tensor | np.ndarray]`):
            The images to inspect. Each `shape` is unpacked into exactly three values, so every image must
            have three dimensions.
        input_data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel layout shared by all images.

    Returns:
        `tuple[int, int]`: The `(max_height, max_width)` pair.

    Raises:
        ValueError: If `input_data_format` is neither `ChannelDimension.FIRST` nor `ChannelDimension.LAST`.
    """
    if input_data_format == ChannelDimension.FIRST:
        # Shapes are (channels, height, width): drop the leading channel maximum.
        _, max_height, max_width = max_across_indices([img.shape for img in images])
    elif input_data_format == ChannelDimension.LAST:
        # Shapes are (height, width, channels): drop the trailing channel maximum.
        max_height, max_width, _ = max_across_indices([img.shape for img in images])
    else:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
    return (max_height, max_width)
424
+
425
+
426
+ def is_valid_annotation_coco_detection(annotation: dict[str, list | tuple]) -> bool:
427
+ if (
428
+ isinstance(annotation, dict)
429
+ and "image_id" in annotation
430
+ and "annotations" in annotation
431
+ and isinstance(annotation["annotations"], (list, tuple))
432
+ and (
433
+ # an image can have no annotations
434
+ len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
435
+ )
436
+ ):
437
+ return True
438
+ return False
439
+
440
+
441
+ def is_valid_annotation_coco_panoptic(annotation: dict[str, list | tuple]) -> bool:
442
+ if (
443
+ isinstance(annotation, dict)
444
+ and "image_id" in annotation
445
+ and "segments_info" in annotation
446
+ and "file_name" in annotation
447
+ and isinstance(annotation["segments_info"], (list, tuple))
448
+ and (
449
+ # an image can have no segments
450
+ len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
451
+ )
452
+ ):
453
+ return True
454
+ return False
455
+
456
+
457
def valid_coco_detection_annotations(annotations: Iterable[dict[str, list | tuple]]) -> bool:
    """Return `True` when every item passes `is_valid_annotation_coco_detection` (also for an empty iterable)."""
    for ann in annotations:
        if not is_valid_annotation_coco_detection(ann):
            return False
    return True
459
+
460
+
461
def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, list | tuple]]) -> bool:
    """Return `True` when every item passes `is_valid_annotation_coco_panoptic` (also for an empty iterable)."""
    for ann in annotations:
        if not is_valid_annotation_coco_panoptic(ann):
            return False
    return True
463
+
464
+
465
def load_image(
    image: Union[str, "PIL.Image.Image"],
    timeout: float | None = None,
) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    A string is tried, in order, as an `http://`/`https://` URL, as a path to an existing file, and finally
    as base64 data (optionally wrapped in a `data:image/...,` header). The loaded image is EXIF-transposed
    and converted to RGB.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image.

    Raises:
        ValueError: If a string is neither a URL, an existing file, nor decodable base64 image data.
        TypeError: If `image` is neither a string nor a PIL image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            image = PIL.Image.open(BytesIO(httpx.get(image, timeout=timeout, follow_redirects=True).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            if image.startswith("data:image/"):
                # Keep only the payload after the data-URL header. `partition` yields an empty payload when
                # the comma is missing, so a malformed header fails below with the documented ValueError
                # instead of an IndexError.
                image = image.partition(",")[2]

            # Try to load as base64
            try:
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                # Chain the original error so the traceback shows why decoding failed.
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                ) from e
    elif not isinstance(image, PIL.Image.Image):
        raise TypeError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
508
+
509
+
510
@requires(backends=("torchvision",))
def load_image_as_tensor(
    image: Union[str, "PIL.Image.Image"],
    timeout: float | None = None,
) -> "torch.Tensor":
    """
    Loads `image` directly to a `torch.Tensor` using torchvision.

    A string is tried, in order, as an `http://`/`https://` URL, as a path to an existing file, and finally
    as base64 data (optionally wrapped in a `data:image/...,` header). A PIL image is EXIF-transposed and
    converted to RGB before being turned into a tensor.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to load as a tensor.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `torch.Tensor`: A `[C, H, W]` uint8 tensor in RGB channel order.

    Raises:
        ValueError: If a string is neither a URL nor an existing file and cannot be base64-decoded.
        TypeError: If `image` is neither a string nor a PIL image.
    """
    import torch

    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            raw = httpx.get(image, timeout=timeout, follow_redirects=True).content
            buf = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
            return decode_image(buf, mode=ImageReadMode.RGB)
        elif os.path.isfile(image):
            return decode_image(image, mode=ImageReadMode.RGB)
        else:
            if image.startswith("data:image/"):
                # Keep only the payload after the data-URL header. `partition` yields an empty payload when
                # the comma is missing, instead of raising an IndexError.
                image = image.partition(",")[2]
            try:
                raw = base64.decodebytes(image.encode())
            except Exception as e:
                # Chain the original error so the traceback shows why decoding failed.
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                ) from e
            buf = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
            return decode_image(buf, mode=ImageReadMode.RGB)
    elif isinstance(image, PIL.Image.Image):
        image = PIL.ImageOps.exif_transpose(image)
        return pil_to_tensor(image.convert("RGB"))
    else:
        raise TypeError(
            "Incorrect format used for image. Should be a URL, a local path, a base64 string, or a PIL image."
        )
554
+
555
+
556
def load_images(
    images: Union[list, tuple, str, "PIL.Image.Image"], timeout: float | None = None
) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
    """Loads one image, a flat collection of images, or a nested collection, mirroring the input nesting.

    Args:
        images: A single image, a list of images, or a list of lists of images to load.
        timeout: Timeout for loading images.

    Returns:
        A single image, a list of images, a list of lists of images.
    """
    if not isinstance(images, (list, tuple)):
        return load_image(images, timeout=timeout)

    # Nesting is decided from the first element only.
    if len(images) and isinstance(images[0], (list, tuple)):
        return [[load_image(item, timeout=timeout) for item in group] for group in images]

    return [load_image(item, timeout=timeout) for item in images]
575
+
576
+
577
+ def validate_preprocess_arguments(
578
+ do_rescale: bool | None = None,
579
+ rescale_factor: float | None = None,
580
+ do_normalize: bool | None = None,
581
+ image_mean: float | list[float] | None = None,
582
+ image_std: float | list[float] | None = None,
583
+ do_pad: bool | None = None,
584
+ pad_size: dict[str, int] | int | None = None,
585
+ do_center_crop: bool | None = None,
586
+ crop_size: dict[str, int] | None = None,
587
+ do_resize: bool | None = None,
588
+ size: dict[str, int] | None = None,
589
+ resample: Union["PILImageResampling", "InterpolationMode", int] | None = None,
590
+ ):
591
+ """
592
+ Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
593
+ Raises `ValueError` if arguments incompatibility is caught.
594
+ Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
595
+ sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
596
+ existing arguments when possible.
597
+
598
+ """
599
+ if do_rescale and rescale_factor is None:
600
+ raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
601
+
602
+ if do_pad and pad_size is None:
603
+ # Processors pad images using different args depending on the model, so the below check is pointless
604
+ # but we keep it for BC for now. TODO: remove in v5
605
+ # Usually padding can be called with:
606
+ # - "pad_size/size" if we're padding to specific values
607
+ # - "size_divisor" if we're padding to any value divisible by X
608
+ # - "None" if we're padding to the maximum size image in batch
609
+ raise ValueError(
610
+ "Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
611
+ )
612
+
613
+ if do_normalize and (image_mean is None or image_std is None):
614
+ raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
615
+
616
+ if do_center_crop and crop_size is None:
617
+ raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
618
+
619
+ if do_resize and not (size is not None and resample is not None):
620
+ raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
621
+
622
+
623
class ImageFeatureExtractionMixin:
    """
    Mixin that contains utilities for preparing image features.

    Every public method first validates its input with `_ensure_format_supported`, so only
    `PIL.Image.Image`, `np.ndarray` and `torch.Tensor` inputs are accepted.
    """

    def _ensure_format_supported(self, image):
        """Raise a `ValueError` unless `image` is a PIL image, a NumPy array or a torch tensor."""
        if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
            raise ValueError(
                f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and "
                "`torch.Tensor` are."
            )

    def to_pil_image(self, image, rescale=None):
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.

        Returns:
            `PIL.Image.Image`: The converted image; a PIL input is returned unchanged.
        """
        self._ensure_format_supported(image)

        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale defaults to whether the array is of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel has been moved to first dim, we put it back at the end.
            # A leading dimension of size 1 or 3 is taken to be the channel dimension.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def convert_rgb(self, image):
        """
        Converts `PIL.Image.Image` to RGB format.

        Args:
            image (`PIL.Image.Image`):
                The image to convert. Arrays and tensors are returned unchanged.
        """
        self._ensure_format_supported(image)
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: float | int) -> np.ndarray:
        """
        Rescale a numpy image by scale amount (plain multiplication by `scale`).
        """
        self._ensure_format_supported(image)
        return image * scale

    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        if is_torch_tensor(image):
            image = image.numpy()

        # The default is decided from the dtype of the first element: integer data gets rescaled.
        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        # Any 3-dimensional array is transposed here, with the last axis moved to the front.
        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image

    def expand_dims(self, image):
        """
        Adds a leading dimension of size 1 to `image` (e.g. expands a 2-dimensional `image` to 3 dimensions).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand. PIL images are returned unchanged.
        """
        self._ensure_format_supported(image)

        # Do nothing if PIL image
        if isinstance(image, PIL.Image.Image):
            return image

        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            image = np.expand_dims(image, axis=0)
        return image

    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
        if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`list[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`list[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.

        Returns:
            The normalized image, `(image - mean) / std`, as an array or tensor.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If the input image is a PIL image, it automatically gets rescaled. If it's another
        # type it may need rescaling.
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Bring `mean` and `std` into the same framework as `image`.
        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                if isinstance(mean, np.ndarray):
                    mean = torch.from_numpy(mean)
                else:
                    mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                if isinstance(std, np.ndarray):
                    std = torch.from_numpy(std)
                else:
                    std = torch.tensor(std)

        # Channels-first images need the per-channel statistics reshaped so they broadcast over height and width.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std

    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
                this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to use for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`,`size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.

        Raises:
            ValueError: If `max_size` is not strictly greater than the requested smaller-edge size.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # specified size only for the smallest edge
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                # Nothing to do when the short edge already has the requested length.
                if short == requested_new_short:
                    return image

                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    if new_long > max_size:
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                # Map (short, long) back onto the (width, height) order of the original image.
                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)

    def center_crop(self, image, size):
        """
        Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
        size given, it will be padded (so the returned result has the size asked).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
                The image to crop.
            size (`int` or `tuple[int, int]`):
                The size to which crop the image. Anything that is not a tuple is duplicated into `(size, size)`.

        Returns:
            new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
            height, width).
        """
        self._ensure_format_supported(image)

        if not isinstance(size, tuple):
            size = (size, size)

        # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
        if is_torch_tensor(image) or isinstance(image, np.ndarray):
            if image.ndim == 2:
                image = self.expand_dims(image)
            # A leading dimension of size 1 or 3 is taken to be the channel dimension.
            image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
        else:
            image_shape = (image.size[1], image.size[0])

        top = (image_shape[0] - size[0]) // 2
        bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
        left = (image_shape[1] - size[1]) // 2
        right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

        # For PIL Images we have a method to crop directly.
        if isinstance(image, PIL.Image.Image):
            return image.crop((left, top, right, bottom))

        # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
        channel_first = image.shape[0] in [1, 3]

        # Transpose (height, width, n_channels) format images
        if not channel_first:
            if isinstance(image, np.ndarray):
                image = image.transpose(2, 0, 1)
            if is_torch_tensor(image):
                image = image.permute(2, 0, 1)

        # Check if cropped area is within image boundaries
        if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
            return image[..., top:bottom, left:right]

        # Otherwise, we may need to pad if the image is too small. Oh joy...
        new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
        if isinstance(image, np.ndarray):
            new_image = np.zeros_like(image, shape=new_shape)
        elif is_torch_tensor(image):
            new_image = image.new_zeros(new_shape)

        # Paste the original image into the middle of the zero-filled canvas.
        top_pad = (new_shape[-2] - image_shape[0]) // 2
        bottom_pad = top_pad + image_shape[0]
        left_pad = (new_shape[-1] - image_shape[1]) // 2
        right_pad = left_pad + image_shape[1]
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

        # Shift the crop window into the coordinate system of the padded canvas.
        top += top_pad
        bottom += top_pad
        left += left_pad
        right += left_pad

        new_image = new_image[
            ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
        ]

        return new_image

    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
        `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
                be first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        # Reverse along the first (channel) axis.
        return image[::-1, :, :]

    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
        counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
                rotating.
            angle, resample, expand, center, translate, fillcolor:
                Forwarded unchanged to `PIL.Image.Image.rotate`; `resample` falls back to `PIL.Image.NEAREST`
                when `None`.

        Returns:
            image: A rotated `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PIL.Image.NEAREST

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )
972
+
973
+
974
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: tuple[AnnotationFormat, ...],
    annotations: list[dict],
) -> None:
    """
    Check that `annotation_format` is supported and that `annotations` match that format.

    Args:
        annotation_format (`AnnotationFormat`):
            The format the annotations are declared to follow.
        supported_annotation_formats (`tuple[AnnotationFormat, ...]`):
            The formats accepted by the caller.
        annotations (`list[dict]`):
            The annotations to validate, one dict per image.

    Raises:
        ValueError: If the format is not supported, or if the annotations do not have the structure required
            by the COCO detection / COCO panoptic format.
    """
    if annotation_format not in supported_annotation_formats:
        # Report the offending value itself (the message previously interpolated the `format` builtin).
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
997
+
998
+
999
def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]):
    """
    Log a warning listing every captured kwarg that is not a valid processor key.

    Nothing is logged when all captured kwargs are valid. The keys come from a set, so their order in the
    message is not fixed.
    """
    unused_keys = set(captured_kwargs) - set(valid_processor_keys)
    if not unused_keys:
        return
    # TODO raise a warning here instead of simply logging?
    logger.warning(f"Unused or unrecognized kwargs: {', '.join(unused_keys)}.")
1005
+
1006
+
1007
+ @dataclass()
1008
+ class SizeDict:
1009
+ """
1010
+ Hashable dictionary to store image size information.
1011
+ """
1012
+
1013
+ height: int | None = None
1014
+ width: int | None = None
1015
+ longest_edge: int | None = None
1016
+ shortest_edge: int | None = None
1017
+ max_height: int | None = None
1018
+ max_width: int | None = None
1019
+
1020
+ def __getitem__(self, key):
1021
+ if hasattr(self, key):
1022
+ return getattr(self, key)
1023
+ raise KeyError(f"Key {key} not found in SizeDict.")
1024
+
1025
+ def get(self, key, default=None):
1026
+ if hasattr(self, key) and getattr(self, key) is not None:
1027
+ return getattr(self, key)
1028
+ return default
1029
+
1030
+ def __iter__(self):
1031
+ # Yield only non-None (key, value) pairs so dict(self) excludes missing values.
1032
+ for f in fields(self):
1033
+ val = getattr(self, f.name)
1034
+ if val is not None:
1035
+ yield f.name, val
1036
+
1037
+ def __hash__(self):
1038
+ return hash((self.height, self.width, self.longest_edge, self.shortest_edge, self.max_height, self.max_width))
1039
+
1040
+ def __contains__(self, key):
1041
+ return hasattr(self, key) and getattr(self, key) is not None
1042
+
1043
+ def __setitem__(self, key, value):
1044
+ if not hasattr(self, key):
1045
+ raise KeyError(f"Key {key} is not a valid field of SizeDict.")
1046
+ object.__setattr__(self, key, value)
1047
+
1048
+ def __eq__(self, other):
1049
+ if isinstance(other, dict):
1050
+ return dict(self) == other
1051
+ if isinstance(other, SizeDict):
1052
+ return tuple(getattr(self, f.name) for f in fields(self)) == tuple(
1053
+ getattr(other, f.name) for f in fields(self)
1054
+ )
1055
+ return NotImplemented
1056
+
1057
+ def __or__(self, other) -> "SizeDict":
1058
+ if isinstance(other, dict | SizeDict):
1059
+ merged = dict(self)
1060
+ merged.update(dict(other))
1061
+ return SizeDict(**merged)
1062
+ return NotImplemented
1063
+
1064
+ def __ror__(self, other) -> dict:
1065
+ if isinstance(other, dict):
1066
+ merged = dict(other)
1067
+ merged.update(dict(self))
1068
+ return merged
1069
+ return NotImplemented
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/masking_utils.py ADDED
@@ -0,0 +1,1514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections.abc import Callable
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ from .cache_utils import Cache
20
+ from .configuration_utils import PreTrainedConfig
21
+ from .utils import is_torch_xpu_available, logging
22
+ from .utils.generic import GeneralInterface, is_flash_attention_requested
23
+ from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_tracing
24
+
25
+
26
# The flex attention symbols are only imported when the installed torch provides them
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
else:
    # Register a fake type to avoid crashing for annotations and `isinstance` checks
    BlockMask = torch.Tensor

# Version/backend flags, computed a single time when the module is imported
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
_is_torch_xpu_available = is_torch_xpu_available()

# Context manager needed by the vmap-based mask creation in `sdpa_mask` (only imported for torch>=2.6)
if _is_torch_greater_or_equal_than_2_6:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex


logger = logging.get_logger(__name__)
42
+
43
+
44
def and_masks(*mask_functions: Callable) -> Callable:
    """
    Build a mask function that keeps a position only if every provided mask function keeps it (intersection).

    Args:
        *mask_functions (`Callable`):
            Mask functions taking `(batch_idx, head_idx, q_idx, kv_idx)`.

    Returns:
        `Callable`: A mask function with the same 4-argument signature.

    Raises:
        RuntimeError: If at least one of the inputs is not callable.
    """
    for candidate in mask_functions:
        if not callable(candidate):
            raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def and_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Neutral element of `&`: a scalar `True` created from `q_idx` (hence on its device)
        combined = q_idx.new_ones((), dtype=torch.bool)
        for fn in mask_functions:
            partial = fn(batch_idx, head_idx, q_idx, kv_idx)
            # Each partial result is moved to the accumulator's device before being merged
            combined = combined & partial.to(combined.device)
        return combined

    return and_mask
56
+
57
+
58
def or_masks(*mask_functions: Callable) -> Callable:
    """
    Build a mask function that keeps a position as soon as one of the provided mask functions keeps it (union).

    Args:
        *mask_functions (`Callable`):
            Mask functions taking `(batch_idx, head_idx, q_idx, kv_idx)`.

    Returns:
        `Callable`: A mask function with the same 4-argument signature.

    Raises:
        RuntimeError: If at least one of the inputs is not callable.
    """
    for candidate in mask_functions:
        if not callable(candidate):
            raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

    def or_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Neutral element of `|`: a scalar `False` created from `q_idx` (hence on its device)
        combined = q_idx.new_zeros((), dtype=torch.bool)
        for fn in mask_functions:
            partial = fn(batch_idx, head_idx, q_idx, kv_idx)
            # Each partial result is moved to the accumulator's device before being merged
            combined = combined | partial.to(combined.device)
        return combined

    return or_mask
70
+
71
+
72
def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    """
    Basic lower-diagonal causal pattern: a query position may only look at key/value positions that are
    not located after it. `batch_idx` and `head_idx` are not used.
    """
    return q_idx >= kv_idx
77
+
78
+
79
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    """
    Full bidirectional pattern: nothing is masked.

    NOTE: The result is deliberately derived from `q_idx` instead of being a constant, as an index-based
    version has to be kept for the non-vmap expansion.
    """
    return 0 <= q_idx
86
+
87
+
88
def sliding_window_overlay(sliding_window: int) -> Callable:
    """
    Overlay describing a sliding window pattern: a key/value position is kept when it is located strictly
    less than `sliding_window` positions before the query. It does not mask future positions by itself, so it
    has to be combined with a causal mask to obtain a proper sliding window mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        lookback = q_idx - kv_idx
        return lookback < sliding_window

    return inner_mask
98
+
99
+
100
def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
    """
    Overlay describing a chunked attention pattern: query and key/value positions are kept together only when
    they fall in the same chunk of `chunk_size` positions, the positions being first shifted by the
    `left_padding` value of the batch item. Combine it with a causal mask for a proper chunked attention mask.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        shift = left_padding[batch_idx]
        q_chunk = (q_idx - shift) // chunk_size
        kv_chunk = (kv_idx - shift) // chunk_size
        return kv_chunk == q_chunk

    return inner_mask
110
+
111
+
112
def blockwise_overlay(block_sequence_ids: torch.Tensor) -> Callable:
    """
    Overlay describing a blockwise masking pattern, where a block is a group of tokens of arbitrary length
    rather than a single token. Positions sharing the same non-negative id in `block_sequence_ids` are
    unmasked for each other, i.e. the attention inside a block is bidirectional. In a causal setup, a block
    thus attends causally to the previous blocks and cannot attend to the future ones.
    Mostly used in MLLMs when non-text data attends bidirectionally to itself.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        query_block = block_sequence_ids[batch_idx, q_idx]
        key_block = block_sequence_ids[batch_idx, kv_idx]
        # Negative ids (e.g. -1 for text) never form a block
        is_real_block = query_block >= 0
        return (query_block == key_block) & is_real_block

    return inner_mask
128
+
129
+
130
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
    """
    Return the mask function of a causal sliding window mask, i.e. the intersection of the sliding window
    overlay and the causal mask.
    """
    window_overlay = sliding_window_overlay(sliding_window)
    return and_masks(window_overlay, causal_mask_function)
135
+
136
+
137
def sliding_window_bidirectional_overlay(sliding_window: int) -> Callable:
    """
    Overlay describing a bidirectional sliding window pattern, in which the window extends on both sides
    of the query position.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        """Keep the pair when the absolute distance between both positions does not exceed
        `sliding_window` (the bound is inclusive)."""
        distance = abs(q_idx - kv_idx)
        return distance <= sliding_window

    return inner_mask
148
+
149
+
150
def sliding_window_bidirectional_mask_function(sliding_window: int) -> Callable:
    """
    Return the mask function of a bidirectional sliding window mask, i.e. the intersection of the
    bidirectional sliding window overlay and the full bidirectional mask.
    """
    window_overlay = sliding_window_bidirectional_overlay(sliding_window)
    return and_masks(window_overlay, bidirectional_mask_function)
155
+
156
+
157
def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> Callable:
    """
    Return the mask function of a chunked attention mask, i.e. the intersection of the chunked overlay and
    the causal mask.
    """
    chunk_overlay = chunked_overlay(chunk_size, left_padding)
    return and_masks(chunk_overlay, causal_mask_function)
162
+
163
+
164
def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
    """
    Return the mask function corresponding to a 2D padding mask: a pair is kept according to the value of
    `padding_mask` at the batch index and key/value position.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # The 2D mask must ALWAYS cover the largest `kv_idx` along dimension 1. It cannot be padded here as the
        # final size is unknown inside the mask function, and a try/except is not an option either since it is
        # not vectorizable on accelerator devices
        keep = padding_mask[batch_idx, kv_idx]
        return keep

    return inner_mask
176
+
177
+
178
def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
    """
    Return the mask function corresponding to a 2D packed sequence mask: a pair is kept only when the query and
    key/value positions carry the same value in `packed_sequence_mask` for the batch item.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        query_sequence = packed_sequence_mask[batch_idx, q_idx]
        key_sequence = packed_sequence_mask[batch_idx, kv_idx]
        return query_sequence == key_sequence

    return inner_mask
187
+
188
+
189
def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
    """
    Wrap `mask_function` so that `q_offset` and `kv_offset` are added to the query and key/value indices before
    it is evaluated. This is needed because the torch API only accepts lengths, not start and end indices.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        shifted_q = q_idx + q_offset
        shifted_kv = kv_idx + kv_offset
        return mask_function(batch_idx, head_idx, shifted_q, shifted_kv)

    return inner_mask
199
+
200
+
201
+ def prepare_padding_mask(attention_mask: torch.Tensor | None, kv_length: int, kv_offset: int) -> torch.Tensor | None:
202
+ """
203
+ From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
204
+ """
205
+ local_padding_mask = attention_mask
206
+ if attention_mask is not None:
207
+ # Pad it if necessary
208
+ if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
209
+ local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
210
+ return local_padding_mask
211
+
212
+
213
+ def maybe_pad_block_sequence_ids(
214
+ block_sequence_ids: torch.Tensor, attention_mask: torch.Tensor | None, kv_length: int, kv_offset: int
215
+ ) -> torch.Tensor:
216
+ """
217
+ Pads the `block_sequence_ids` in case the total length is less than `kv_length`.
218
+ Usually that happens with `StaticCache` generation or generating without cache.
219
+ Pads to the right with `-1`.
220
+ """
221
+ if (padding_length := kv_length + kv_offset - block_sequence_ids.shape[-1]) > 0:
222
+ block_sequence_ids = F.pad(block_sequence_ids, pad=(0, padding_length), value=-1)
223
+ return block_sequence_ids
224
+
225
+
226
def _can_skip_causal_mask_xpu(
    padding_mask: torch.Tensor | None,
    query_length: int,
    kv_length: int,
    local_attention_size: int | None,
) -> bool:
    """
    XPU-specific decision on whether the causal mask creation can be skipped.

    The mask is never skipped when tracing, nor when a local attention size is set and `kv_length` reaches it.
    Otherwise:
        - without padding mask, skip for a single query token or when `kv_length == query_length`
        - with a padding mask and a single query token, skip only if the mask holds no padding
        - with a padding mask and several query tokens, skip if the mask is all True on the first
          `query_length` positions and all False afterwards
    """
    if is_tracing(padding_mask):
        return False

    # Local attention constraint (same as CUDA)
    local_pattern_needed = local_attention_size is not None and kv_length >= local_attention_size
    if local_pattern_needed:
        return False

    if padding_mask is None:
        # No padding information: only the query/key lengths matter
        return query_length == 1 or kv_length == query_length

    if query_length == 1:
        # Single query token: any padding token forces the mask creation
        return padding_mask.all()

    # XPU-specific: query window fully attended and nothing attended after it
    # This allows XPU to optimize the 1st token in static cache
    query_window_full = padding_mask[:, :query_length].all()
    return query_window_full and not padding_mask[:, query_length:].any()
260
+
261
+
262
def _ignore_causal_mask_sdpa(
    padding_mask: torch.Tensor | None,
    query_length: int,
    kv_length: int,
    kv_offset: int,
    local_attention_size: int | None = None,
) -> bool:
    """
    Tell whether the causal mask can be dropped when PyTorch's SDPA is used, relying on the `is_causal` argument
    of SDPA instead.

    When no token is masked in the 2D `padding_mask` and `query_length == 1` or `kv_length == query_length`,
    relying on `is_causal` allows to dispatch to the flash attention kernel (which cannot be used if a custom
    `attn_mask` is passed).
    """
    # Only keep the `kv_length` columns starting at `kv_offset` when the mask is wider than needed
    if padding_mask is not None and padding_mask.shape[-1] > kv_length:
        kept_columns = torch.arange(kv_length, device=padding_mask.device) + kv_offset
        padding_mask = padding_mask[:, kept_columns]

    if _is_torch_xpu_available:
        # XPU devices follow their own rules:
        # - Single query tokens use the same logic as CUDA
        # - Multi-query tokens can skip if padding_mask is provided and correctly structured
        #   (all True in query window, all False after)
        return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)

    # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal`
    # behavior is hard-coded to the forward. If a user exports a model with query_length > 1, the exported model
    # will hard-code `is_causal=True` which is in general wrong
    # (see https://github.com/pytorch/pytorch/issues/108108). Thus the mask is never ignored when tracing
    if is_tracing(padding_mask):
        return False

    # Only the cases where lower and upper diagonals are the same can be handled by `is_causal`,
    # see https://github.com/pytorch/pytorch/issues/108108
    if query_length != 1 and kv_length != query_length:
        return False

    # A local attention pattern has to be added to the mask in this case
    if local_attention_size is not None and kv_length >= local_attention_size:
        return False

    # Padding has to be added to the mask in this case
    if padding_mask is not None and not padding_mask.all():
        return False

    return True
304
+
305
+
306
def _can_skip_bidirectional_mask_xpu(
    padding_mask: torch.Tensor | None,
    kv_length: int,
    local_attention_size: int | None,
) -> bool:
    """
    XPU-specific decision on whether the bidirectional mask creation can be skipped.

    The mask is never skipped when tracing, nor when a local attention size is set and `kv_length` reaches it.
    Otherwise it is skipped when there is no padding mask, or when the padding mask holds no padding token.
    """
    if is_tracing(padding_mask):
        return False

    # Local attention constraint (same as CUDA)
    local_pattern_needed = local_attention_size is not None and kv_length >= local_attention_size
    if local_pattern_needed:
        return False

    # Full bidirectional attention without padding information never needs a mask
    if padding_mask is None:
        return True

    # Otherwise any padding token forces the mask creation
    return padding_mask.all()
331
+
332
+
333
def _ignore_bidirectional_mask_sdpa(
    padding_mask: torch.Tensor | None,
    kv_length: int,
    local_attention_size: int | None = None,
) -> bool:
    """
    Tell whether the bidirectional mask can be dropped when PyTorch's SDPA is used.

    When no token is masked in the 2D `padding_mask` and no local attention constraint applies (i.e.
    `local_attention_size` is None or `kv_length < local_attention_size`), the mask creation is skipped, which
    allows to dispatch to the flash attention kernel (that cannot be used if a custom `attn_mask` is passed).
    """
    if _is_torch_xpu_available:
        # XPU devices follow their own rule: skip if no padding and no local attention constraint
        return _can_skip_bidirectional_mask_xpu(padding_mask, kv_length, local_attention_size)

    # When using `torch.export` or `torch.onnx.dynamo_export`, looking at the content of the mask must be
    # avoided, as it would introduce dynamic control flows
    if is_tracing(padding_mask):
        return False

    # Padding has to be added to the mask in this case
    if padding_mask is not None and not padding_mask.all():
        return False

    # A local attention pattern has to be added to the mask in this case
    if local_attention_size is not None and kv_length >= local_attention_size:
        return False

    return True
362
+
363
+
364
+ def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
365
+ """
366
+ Used to vmap our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
367
+ Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
368
+ functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
369
+ """
370
+ # We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
371
+ dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
372
+ for dims in dimensions:
373
+ mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
374
+ return mask_function
375
+
376
+
377
+ def _non_vmap_expansion_sdpa(
378
+ batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
379
+ ):
380
+ """
381
+ Used to broadcast our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
382
+ Allows the usage of any index-based mask function without relying on vmap.
383
+
384
+ NOTE: This is limited to index based functions only and is not guaranteed to work otherwise.
385
+
386
+ Reference:
387
+ - https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
388
+ """
389
+ batch_indices = batch_indices[:, None, None, None]
390
+ head_indices = head_indices[None, :, None, None]
391
+ q_indices = q_indices[None, None, :, None]
392
+ kv_indices = kv_indices[None, None, None, :]
393
+ return batch_indices, head_indices, q_indices, kv_indices
394
+
395
+
396
def sdpa_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    local_size: int | None = None,
    allow_is_causal_skip: bool = True,
    allow_is_bidirectional_skip: bool = False,
    allow_torch_fix: bool = True,
    use_vmap: bool = False,
    device: torch.device | str = "cpu",
    **kwargs,
) -> torch.Tensor | None:
    """
    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)`, in which True means that the
    element takes part in the attention computation and False that it does not. `None` is returned instead
    when the mask creation can be skipped (see `allow_is_causal_skip` and `allow_is_bidirectional_skip`).
    The vmap-based construction (`use_vmap=True`) requires torch>=2.6, as the context manager it relies on is
    otherwise not available.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        local_size (`int`, optional):
            The size of the local attention, if we do not use full attention. It is only used to decide whether
            the mask creation can be skipped.
        allow_is_causal_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal`
            argument in `torch.sdpa` instead. Default to `True`.
        allow_is_bidirectional_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
            i.e. full attention without any padding. Default to `False`.
        allow_torch_fix (`bool`, optional):
            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's
            older versions. We need an arg to skip it when using eager. By default `True`.
        use_vmap (`bool`, optional):
            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may
            not be index-based (for the cost of speed performance). By default `False`.
        device (`torch.device` or `str`, optional):
            An optional device to create the mask on.

    Raises:
        ValueError: If `use_vmap=True` is requested with torch<2.6.


    ## Creating a simple causal mask:

    To create the following causal mask:

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ■ ■ ■ ■ ⬚
        4 ■ ■ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, q_length=5, kv_length=5)
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [ True,  True,  True,  True, False],
                  [ True,  True,  True,  True,  True]]]])
    ```

    ## Creating a sliding window mask:

    To create the following sliding window mask (`sliding_window=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ■ ■ ■ ⬚
        4 ⬚ ⬚ ■ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, q_length=5, kv_length=5, mask_function=sliding_window_causal_mask_function(3))
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [False,  True,  True,  True, False],
                  [False, False,  True,  True,  True]]]])
    ```

    ## Creating a chunked attention mask

    To create the following chunked attention mask (`chunk_size=3`):

        0 ■ ⬚ ⬚ ⬚ ⬚
        1 ■ ■ ⬚ ⬚ ⬚
        2 ■ ■ ■ ⬚ ⬚
        3 ⬚ ⬚ ⬚ ■ ⬚
        4 ⬚ ⬚ ⬚ ■ ■

    You can do

    ```python
    >>> sdpa_mask(batch_size=1, q_length=5, kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
    >>> tensor([[[[ True, False, False, False, False],
                  [ True,  True, False, False, False],
                  [ True,  True,  True, False, False],
                  [False, False, False,  True, False],
                  [False, False, False,  True,  True]]]])
    ```

    """
    # The 2D mask may be shorter than the key/value range, in which case it gets padded
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)

    # The mask does not always have to be materialized:
    # 1. Causal masks can rely on the `is_causal` argument
    # 2. Bidirectional masks do not need any further processing (no bias)
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
        return None
    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, local_size):
        return None

    # The padding information is merged into the pattern itself
    if padding_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    # Index ranges over which the pattern is evaluated (a single head is used)
    batch_arange = torch.arange(batch_size, device=device)
    head_arange = torch.arange(1, device=device)
    q_arange = torch.arange(q_length, device=device) + q_offset
    kv_arange = torch.arange(kv_length, device=device) + kv_offset

    if use_vmap:
        # Custom (vmap) patterns are only supported from torch 2.6, anything else means that the user did
        # something custom which they shouldn't have
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "The vmap functionality for mask creation is only supported from torch>=2.6. "
                "Please update your torch version or use `use_vmap=False` with index-based masks."
            )
        # This context manager is needed as vmap cannot handle slicing a tensor from a scalar tensor (it
        # internally calls `.item()` which vmap does not allow). No offset has to be added to the mask_function
        # either, as the correct q and kv indices are directly vmapped
        with TransformGetItemToIndex():
            attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, q_arange, kv_arange)
    else:
        # Fast default path: the pattern is evaluated element-wise through broadcasting
        broadcast_indices = _non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange)
        attention_mask = mask_function(*broadcast_indices)
        # The batch and query dimensions are expanded in case the mask function did not depend on them
        attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)

    # Due to a bug in versions of torch<2.5, the mask has to be updated in case a query is not attending to any
    # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
    if allow_torch_fix and not _is_torch_greater_or_equal_than_2_5:
        attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)

    return attention_mask
564
+
565
+
566
def eager_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    dtype: torch.dtype = torch.float32,
    allow_is_bidirectional_skip: bool = False,
    use_vmap: bool = False,
    device: torch.device | str = "cpu",
    **kwargs,
) -> torch.Tensor | None:
    """
    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
    it should not.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        dtype (`torch.dtype`, optional):
            The dtype to use for the mask. By default, `torch.float32`.
        allow_is_bidirectional_skip (`bool`, optional):
            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
            i.e. full attention without any padding. Default to `False`.
        use_vmap (`bool`, optional):
            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
            index-based (for the cost of speed performance). By default `False`.
        device (`torch.device` or `str`, optional):
            An optional device to create the mask on.

    Returns:
        `torch.Tensor` or `None`: The float mask, or `None` if `allow_is_bidirectional_skip=True` and the
        underlying boolean mask creation was skipped.
    """
    # The masks for eager attention are simply the boolean masks from sdpa, casted to 0 and -inf.
    # Both options below are forced to `False` in the call to `sdpa_mask`, so any caller-provided value is dropped
    _ = kwargs.pop("allow_is_causal_skip", None)
    _ = kwargs.pop("allow_torch_fix", None)
    mask = sdpa_mask(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
        allow_torch_fix=False,
        use_vmap=use_vmap,
        device=device,
        **kwargs,
    )
    # Only bidirectional masks can be skipped (`mask is None`), otherwise we convert bool -> float
    if mask is not None:
        min_dtype = torch.finfo(dtype).min
        # We need 0s where the tokens should be taken into account, and the minimum value of `dtype` otherwise
        # (the mask is already of boolean type)
        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
    return mask
635
+
636
+
637
def flash_attention_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    **kwargs,
):
    """
    Create the attention mask necessary to use FA2. Since FA2 is un-padded by definition, no 4D mask is ever
    built: `None` is returned when there is no 2D mask or when it holds no padding, otherwise the 2D mask
    (restricted to its last `kv_length` positions) is returned and will then be used to extract the seq_lens.

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
    """
    if attention_mask is None:
        return None

    # Slice from the right, which is needed with sliding or chunked attention (for full attention, this is
    # equivalent to doing nothing)
    trimmed_mask = attention_mask[:, -kv_length:]

    # An actual mask is only returned if there is at least 1 padding token, otherwise `None` is returned and
    # `is_causal` is used in FA2 (note that the attention_mask is a boolean dtype here)
    if trimmed_mask.all():
        return None
    return trimmed_mask
677
+
678
+
679
def flex_attention_mask(
    batch_size: int,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    device: torch.device | str = "cpu",
    **kwargs,
) -> BlockMask:
    """
    Create a 4D block mask, which is a compressed representation of the full 4D mask. BlockMask is essential
    for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/

    Args:
        batch_size (`int`):
            The batch size of the input sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int`, optional):
            An optional offset to indicate at which first position the query states will refer to.
        kv_offset (`int`, optional):
            An optional offset to indicate at which first position the key and values states will refer to.
        mask_function (`Callable`):
            The mask factory function describing the mask pattern.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        device (`torch.device` or `str`, optional):
            An optional device to create the mask on.
    """
    # Merge the 2D padding mask into the pattern when one is provided
    if attention_mask is not None:
        # Older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size), so the mask
        # is right-padded with zeros up to the next multiple of the block size as a minimum
        current_len = attention_mask.shape[1]
        pad_len = flex_default_block_size - current_len % flex_default_block_size
        if pad_len > 0 and not _is_torch_greater_or_equal_than_2_6:
            attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))

        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    # The offsets are added on top, as the flex interface only allows lengths, not start and end indices
    offset_mask_function = add_offsets_to_mask_function(mask_function, q_offset, kv_offset)

    # Finally create the block mask (compiled only for torch>=2.6)
    return create_block_mask(
        mask_mod=offset_mask_function,
        B=batch_size,
        H=None,
        Q_LEN=q_length,
        KV_LEN=kv_length,
        device=device,
        _compile=_is_torch_greater_or_equal_than_2_6,
    )
738
+
739
+
740
class AttentionMaskInterface(GeneralInterface):
    # Registry associating each attention implementation name with the function creating its mask.
    # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
    # a new instance is created (in order to locally override a given function)
    _global_mapping = {
        "sdpa": sdpa_mask,
        "eager": eager_mask,
        # All the flash attention versions share the same mask function
        "flash_attention_2": flash_attention_mask,
        "flash_attention_3": flash_attention_mask,
        "flash_attention_4": flash_attention_mask,
        "flex_attention": flex_attention_mask,
    }


# Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
755
+
756
+
757
def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor | None:
    """
    Find the index of the sequence to which each new query token belongs when using the packed tensor format
    (i.e. several sequences packed in the same batch dimension).

    Args:
        position_ids (`torch.Tensor`)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.

    Returns:
        A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example,
        if we pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return
        [[0, 0, 1, 1, 1, 2]].

        If there is only one sequence in each batch item (and we are not tracing), `None` is returned instead,
        indicating no packed sequences. This is the same as [[0, 0, 0, 0, 0, 0]] for the example above.
    """
    # A new sequence starts wherever 2 consecutive position ids do not differ by exactly 1. The first value - 1
    # is prepended so that the diff of the first token is 1 and the indexing stays aligned; the cumulative sum
    # of the "breaks" then gives the sequence indices.
    # Note that a single sequence is assumed not to span several batch dimensions, i.e. it cannot be part of
    # the end of the first batch dim and the start of the 2nd one for example
    leading_value = position_ids[:, :1] - 1
    step_sizes = torch.diff(position_ids, prepend=leading_value, dim=-1)
    sequence_indices = (step_sizes != 1).cumsum(-1)

    # Checking the content is a dynamic control flow, so it cannot be done on anything compile related
    if is_tracing(sequence_indices):
        return sequence_indices

    # The last index being 0 everywhere means that each batch item holds a single sequence
    if (sequence_indices[:, -1] == 0).all():
        return None

    return sequence_indices
787
+
788
+
789
def _preprocess_mask_arguments(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | BlockMask | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None,
    layer_idx: int | None,
    encoder_hidden_states: torch.Tensor | None = None,
) -> tuple[
    bool,
    torch.Tensor | BlockMask | None,
    torch.Tensor | None,
    int | None,
    int | None,
    int | torch.Tensor | None,
    int | None,
]:
    """
    Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the
    key-value length and offsets, and if we should early exit or not.

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        layer_idx (`int`, optional):
            If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
            length and offset. Indeed, for hybrid caches, different layers may return different lengths.
        encoder_hidden_states (`torch.Tensor`, optional):
            The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
            `inputs_embeds` to infer the kv length.

    Returns:
        A 7-tuple `(early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset)`.
        When `early_exit` is `True`, every entry after `attention_mask` is `None`.

        early_exit (`bool`):
            Whether we should early exit mask creation, and return the mask as-is.
        attention_mask (`torch.Tensor` or `BlockMask` or `None`):
            The attention mask to either return immediately, or to use in downstream mask creation.
        packed_sequence_mask (`torch.Tensor`, optional):
            In case we detected packed sequence format, this is a tensor where each similar integer indicates that
            the tokens belong to the same sequence.
        q_length (`int`):
            The size that the query states will have during the attention computation.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        q_offset (`int` or `torch.Tensor`):
            An offset to indicate at which first position the query states will refer to. It is `0` without a cache,
            and may be a tensor when the cache reports its sequence length as a tensor.
        kv_offset (`int`):
            An offset to indicate at which first position the key and values states will refer to.
    """
    # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
    if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
        return True, attention_mask, None, None, None, None, None

    # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
    # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
    # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
    # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
    # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
        return True, None, None, None, None, None, None

    # Move the mask to correct device, and potentially switch dtype for efficiency
    if attention_mask is not None and attention_mask.ndim == 2:
        attention_mask = attention_mask.to(device=inputs_embeds.device, dtype=torch.bool)

    q_length = inputs_embeds.shape[1]
    # If using a cache, it can give all information about mask sizes based on seen tokens
    if past_key_values is not None:
        q_offset = past_key_values.get_seq_length()
        # To avoid graph breaks, StaticLayer return a tensor instead of int -> this has no impact on the ops, but we
        # need the correct device
        q_offset = q_offset.to(inputs_embeds.device) if isinstance(q_offset, torch.Tensor) else q_offset
        kv_length, kv_offset = past_key_values.get_mask_sizes(q_length, layer_idx)
    # Otherwise, we infer based on our input
    else:
        q_offset = 0
        # 1. Rely on input directly
        if attention_mask is None:
            # For encoder-decoders, use encoder_hidden_states to infer kv_length if provided
            kv_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else q_length
            kv_offset = 0
        # 2. Rely on the mask instead - needed for special cases like prefix tuning in PEFT
        #
        # This is a very unique and special case where an encoder utilizes a cache and expects its length
        # to be accounted for (usually, they should never use a cache). In general, the mask should always
        # match with the input sizes nonetheless (i.e. it does not affect others).
        # Conclusion: "prefix tuning is evil"
        else:
            kv_length, kv_offset = attention_mask.shape[-1], 0

    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
    # and we don't have past_key_values, i.e. generally a training setup)
    packed_sequence_mask = None
    if position_ids is not None and attention_mask is None and past_key_values is None:
        batch_size = inputs_embeds.shape[0]
        # The position ids are sometimes just unsqueezed, without being expanded
        if batch_size != position_ids.shape[0]:
            position_ids = position_ids.expand(batch_size, -1)
        packed_sequence_mask = find_packed_sequence_indices(position_ids)

    return False, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset
891
+
892
+
893
def create_causal_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    block_sequence_ids: torch.Tensor | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Build a standard causal mask in the format expected by the attention implementation stored in the config. If
    `past_key_values` has an hybrid cache structure, the returned mask corresponds to one of the "full_attention"
    layers (to align to what is needed in the `modeling_xxx.py` files).

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). Only used to infer the batch size,
            query length, dtype and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function combined with the causal mask function by union. Useful to overlay another
            mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function combined with the causal mask function by intersection. Useful to overlay
            another mask on top of the causal one, for example for image tokens handling.
        block_sequence_ids (`torch.Tensor`, *optional*):
            A tensor of same shape as input IDs indicating to which block or group each token belongs to. Tokens from
            the same block will keep a bidirectional mask within the block, attending causally to the past. Index `-1`
            can be used for blocks that have to keep complete causality within itself.
    """
    # Power feature: a config with `is_causal=False` switches to a bi-directional mask, which allows decoder-only
    # models to be used with bi-directional attention as well
    if not getattr(config, "is_causal", True):
        return create_bidirectional_mask(
            config,
            inputs_embeds,
            attention_mask,
            past_key_values=past_key_values,
            or_mask_function=or_mask_function,
            and_mask_function=and_mask_function,
        )

    # With an hybrid cache structure, the mask must be sized for the full (non-sliding) layers
    layer_idx = 0
    if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(False)

    preprocessed = _preprocess_mask_arguments(
        config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx
    )
    early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = preprocessed
    if early_exit:
        return attention_mask

    mask_factory_function = causal_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Do not allow skip if we are compiling (this is to match BC)
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    cache_is_compileable = getattr(past_key_values, "is_compileable", False)
    if _is_torch_xpu_available:
        # On XPU, only decoding (single query token) forbids the skip; prefill may still skip the mask in order to
        # optimize the perf of 1st token generation
        allow_is_causal_skip = not (cache_is_compileable and q_length == 1)
    else:
        allow_is_causal_skip = not cache_is_compileable

    # Allow slight deviations from causal mask
    # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence
    # mask, padding mask, etc) as the resulting mask may otherwise not be correct!
    has_custom_overlay = or_mask_function is not None or and_mask_function is not None
    if has_custom_overlay and not _is_torch_greater_or_equal_than_2_6:
        raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
    if or_mask_function is not None:
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
    if and_mask_function is not None:
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
    if has_custom_overlay:
        allow_is_causal_skip = False

    # Packing format and blockwise overlay both deviate from pure causality, hence forbid the skip as well
    if packed_sequence_mask is not None:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False
    if block_sequence_ids is not None:
        block_sequence_ids = maybe_pad_block_sequence_ids(block_sequence_ids, attention_mask, kv_length, kv_offset)
        mask_factory_function = or_masks(mask_factory_function, blockwise_overlay(block_sequence_ids))
        allow_is_causal_skip = False

    # Non-vmap mask creation is the default; vmap is only used for user-provided mask functions (as we cannot
    # guarantee that they are properly index-based as required by our implementation)
    return mask_interface(
        batch_size=inputs_embeds.shape[0],
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=inputs_embeds.dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=has_custom_overlay,  # Short-circuit to non-vmap expansions for the mask
        device=inputs_embeds.device,
    )
1016
+
1017
+
1018
def create_bidirectional_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    encoder_hidden_states: torch.Tensor | None = None,
    past_key_values: Cache | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    **kwargs,
) -> torch.Tensor | BlockMask | None:
    """
    Build a standard bidirectional mask in the format expected by the attention implementation stored in the config.

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is only used to infer metadata
            such as the batch size, query length, dtype, and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, kv_length).
            It can also be an already prepared 4D mask of shape (batch_size, 1, query_length, kv_length),
            in which case it is returned as-is.
        encoder_hidden_states (`torch.Tensor`, optional):
            The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
            `inputs_embeds` to infer the batch size, kv length, dtype and device.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function combined with the base mask function by union. Useful to overlay another
            mask on top, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function combined with the base mask function by intersection. Useful to overlay
            another mask on top, for example for image tokens handling.
    """
    # No position ids and a fixed layer index of 0 are forwarded; the packed sequence entry of the result is unused
    preprocessed = _preprocess_mask_arguments(
        config, inputs_embeds, attention_mask, past_key_values, None, 0, encoder_hidden_states
    )
    early_exit, attention_mask, _, q_length, kv_length, q_offset, kv_offset = preprocessed
    if early_exit:
        return attention_mask

    # Metadata is read from the encoder states when they are provided, from the input embeddings otherwise
    reference_states = inputs_embeds if encoder_hidden_states is None else encoder_hidden_states
    mask_factory_function = bidirectional_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Allow slight deviations from the base mask
    # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence
    # mask, padding mask, etc) as the resulting mask may otherwise not be correct!
    has_custom_overlay = or_mask_function is not None or and_mask_function is not None
    if has_custom_overlay and not _is_torch_greater_or_equal_than_2_6:
        raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
    if or_mask_function is not None:
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
    if and_mask_function is not None:
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)

    # Skipping the mask creation is allowed unless additional masking operators (and/or masks) were given, and
    # vmap is only used for those user-provided functions (as we cannot guarantee that they are properly
    # index-based as required by our implementation)
    return mask_interface(
        batch_size=reference_states.shape[0],
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        # Additional kwargs for sdpa
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=not has_custom_overlay,
        dtype=reference_states.dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=has_custom_overlay,  # Short-circuit to non-vmap expansions for the mask
        device=reference_states.device,
    )
1106
+
1107
+
1108
def create_sliding_window_causal_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    block_sequence_ids: torch.Tensor | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Build a sliding window causal mask in the format expected by the attention implementation stored in the config.
    This type of attention pattern was mostly democratized by Mistral. If `past_key_values` has an hybrid cache
    structure, the returned mask corresponds to one of the "sliding_attention" layers (to align to what is needed in
    the `modeling_xxx.py` files).

    Args:
        config (`PreTrainedConfig`):
            The model config. It must provide a non-None `sliding_window`.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). Only used to infer the batch size,
            query length, dtype and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function combined with the sliding causal mask function by union. Useful to overlay
            another mask on top of the sliding causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function combined with the sliding causal mask function by intersection. Useful to
            overlay another mask on top of the sliding causal one, for example for image tokens handling.
        block_sequence_ids (`torch.Tensor`, *optional*):
            A tensor of same shape as input IDs indicating to which block or group each token belongs to. Tokens from
            the same block will keep a bidirectional mask within the block, attending causally to the past. Index `-1`
            can be used for blocks that have to keep complete causality within itself.

    Raises:
        ValueError: If the config has no `sliding_window` set, or if a custom mask function is used with torch<2.6.
    """
    # Power feature: a config with `is_causal=False` switches to a bi-directional mask, which allows decoder-only
    # models to be used with bi-directional attention as well
    if not getattr(config, "is_causal", True):
        return create_bidirectional_sliding_window_mask(
            config,
            inputs_embeds,
            attention_mask,
            past_key_values=past_key_values,
            or_mask_function=or_mask_function,
            and_mask_function=and_mask_function,
        )

    # With an hybrid cache structure, the mask must be sized for the sliding layers
    layer_idx = 0
    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(True)

    preprocessed = _preprocess_mask_arguments(
        config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx
    )
    early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = preprocessed
    if early_exit:
        return attention_mask

    # The window size is only validated once we know that a mask really has to be created
    sliding_window = getattr(config, "sliding_window", None)
    if sliding_window is None:
        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")

    mask_factory_function = sliding_window_causal_mask_function(sliding_window)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Do not allow skip if we are compiling (this is to match BC)
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)

    # Allow slight deviations from causal mask
    # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence
    # mask, padding mask, etc) as the resulting mask may otherwise not be correct!
    has_custom_overlay = or_mask_function is not None or and_mask_function is not None
    if has_custom_overlay and not _is_torch_greater_or_equal_than_2_6:
        raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
    if or_mask_function is not None:
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
    if and_mask_function is not None:
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
    if has_custom_overlay:
        allow_is_causal_skip = False

    # Packing format and blockwise overlay both deviate from the plain sliding pattern, hence forbid the skip as well
    if packed_sequence_mask is not None:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False
    if block_sequence_ids is not None:
        block_sequence_ids = maybe_pad_block_sequence_ids(block_sequence_ids, attention_mask, kv_length, kv_offset)
        mask_factory_function = or_masks(mask_factory_function, blockwise_overlay(block_sequence_ids))
        allow_is_causal_skip = False

    # Non-vmap mask creation is the default; vmap is only used for user-provided mask functions (as we cannot
    # guarantee that they are properly index-based as required by our implementation)
    return mask_interface(
        batch_size=inputs_embeds.shape[0],
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        local_size=sliding_window,  # Additional kwarg for sdpa
        dtype=inputs_embeds.dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=has_custom_overlay,  # Short-circuit to non-vmap expansions for the mask
        device=inputs_embeds.device,
    )
1232
+
1233
+
1234
def create_bidirectional_sliding_window_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    encoder_hidden_states: torch.Tensor | None = None,
    past_key_values: Cache | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    **kwargs,
) -> torch.Tensor | BlockMask | None:
    """
    Build a bidirectional sliding window mask in the format expected by the attention implementation stored in the
    config.

    Args:
        config (`PreTrainedConfig`):
            The model config. It must provide a non-None `sliding_window`.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is only used to infer metadata
            such as the batch size, query length, dtype, and device.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, kv_length).
            It can also be an already prepared 4D mask of shape (batch_size, 1, query_length, kv_length),
            in which case it is returned as-is.
        encoder_hidden_states (`torch.Tensor`, optional):
            The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
            `inputs_embeds` to infer the batch size, kv length, dtype and device.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function combined with the base mask function by union. Useful to overlay another
            mask on top, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function combined with the base mask function by intersection. Useful to overlay
            another mask on top, for example for image tokens handling.

    Raises:
        ValueError: If the config has no `sliding_window` set, or if a custom mask function is used with torch<2.6.
    """
    # No position ids and a fixed layer index of 0 are forwarded; the packed sequence entry of the result is unused
    preprocessed = _preprocess_mask_arguments(
        config, inputs_embeds, attention_mask, past_key_values, None, 0, encoder_hidden_states
    )
    early_exit, attention_mask, _, q_length, kv_length, q_offset, kv_offset = preprocessed
    if early_exit:
        return attention_mask

    # The window size is only validated once we know that a mask really has to be created
    sliding_window = getattr(config, "sliding_window", None)
    if sliding_window is None:
        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")

    # Metadata is read from the encoder states when they are provided, from the input embeddings otherwise
    reference_states = inputs_embeds if encoder_hidden_states is None else encoder_hidden_states
    mask_factory_function = sliding_window_bidirectional_mask_function(sliding_window)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Custom mask functions require torch>=2.6; they disable the skip and switch to vmap-based mask creation
    has_custom_overlay = or_mask_function is not None or and_mask_function is not None
    if has_custom_overlay and not _is_torch_greater_or_equal_than_2_6:
        raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
    if or_mask_function is not None:
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
    if and_mask_function is not None:
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)

    return mask_interface(
        batch_size=reference_states.shape[0],
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=not has_custom_overlay,
        local_size=sliding_window,  # Additional kwarg for sdpa
        dtype=reference_states.dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=has_custom_overlay,  # Short-circuit to non-vmap expansions for the mask
        device=reference_states.device,
    )
1318
+
1319
+
1320
def create_chunked_causal_mask(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    block_sequence_ids: torch.Tensor | None = None,
) -> torch.Tensor | BlockMask | None:
    """
    Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
    of attention pattern was mostly popularized by Llama4. If `past_key_values` has a hybrid cache structure, this
    function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
    `modeling_xxx.py` files).

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
        block_sequence_ids (`torch.Tensor`, optional):
            Accepted so that `create_masks_for_generate` can forward the same keyword arguments to every mask
            creation function. Block-wise bidirectional attention is not implemented for the chunked pattern, so
            only `None` is supported.

    Raises:
        ValueError: If `block_sequence_ids` is not `None`, if the config has no `attention_chunk_size`, if flash
            attention is requested with a key-value length exceeding the chunk size, or if custom mask functions
            are used with torch<2.6.
    """
    # Fail loudly instead of silently building a mask that ignores the requested block structure
    if block_sequence_ids is not None:
        raise ValueError("`block_sequence_ids` is not supported when creating a chunked causal mask")

    # If we have a hybrid cache structure, here we want to create the mask for the first layer flagged as sliding
    # in the cache (this is the flag this function looks up to find the layer index to use)
    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
        layer_idx = past_key_values.is_sliding.index(True)
    else:
        layer_idx = 0

    early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
        _preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
    )
    if early_exit:
        return attention_mask

    chunk_size = getattr(config, "attention_chunk_size", None)
    if chunk_size is None:
        raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")

    # Raise if using chunked attention on context too large with FA
    if is_flash_attention_requested(config) and kv_length + kv_offset > chunk_size:
        raise ValueError(
            "Flash attention cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
            "chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
        )

    batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
    # For chunked attention and batched inputs, we need to take the number of left padding tokens into account
    # to start the chunk from the actual start of the sequence for the padded sequence
    if attention_mask is not None:
        # Only count the left padding tokens, not all of them: positions whose running sum is still 0
        left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
    else:
        left_padding_tokens = torch.zeros(batch_size, device=device, dtype=int)
    mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]

    # Defaulting to using non-vmap based mask creations except when detecting
    # users passing custom mask functions (as we cannot guarantee that they
    # are properly index-based as required by our implementation).
    use_vmap = False
    # Do not allow skip if we are compiling (this is to match BC)
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)

    # Allow slight deviations from causal mask
    # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
    # padding mask, etc) as the resulting mask may otherwise not be correct!
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
        use_vmap = True
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False
        use_vmap = True

    # If we detected packing format
    if packed_sequence_mask is not None:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        q_length=q_length,
        kv_length=kv_length,
        q_offset=q_offset,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        local_size=chunk_size,  # Additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
        device=device,
    )
    return causal_mask
1437
+
1438
+
1439
# Dispatch table from a layer-type name (as found in `config.layer_types`) to the function that builds the
# attention mask for that kind of layer. Both compressed patterns reuse the sliding-window mask builder.
LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
    # Plain causal attention over the whole context
    "full_attention": create_causal_mask,
    # Local attention patterns
    "sliding_attention": create_sliding_window_causal_mask,
    "chunked_attention": create_chunked_causal_mask,
    # Compressed patterns, mapped to the sliding-window builder
    "compressed_sparse_attention": create_sliding_window_causal_mask,
    "heavily_compressed_attention": create_sliding_window_causal_mask,
}
1446
+
1447
+
1448
def create_masks_for_generate(
    config: PreTrainedConfig,
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | None,
    past_key_values: Cache | None,
    position_ids: torch.Tensor | None = None,
    or_mask_function: Callable | None = None,
    and_mask_function: Callable | None = None,
    block_sequence_ids: torch.Tensor | None = None,
    **kwargs,
):
    """
    Build, ahead of time, the attention mask(s) that the `modeling_xxx.py` files would otherwise create themselves.
    This is used in places like `generate`, when the forwards are compiled with Static caches.

    Args:
        config (`PreTrainedConfig`):
            The model config.
        inputs_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        position_ids (`torch.Tensor`, optional)
            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        block_sequence_ids (`torch.Tensor`, *optional*):
            A tensor of same shape as input IDs indicating to which block or group each token belongs to. Tokens from
            the same block will keep a bidirectional mask within the block, attending causally to the past. Index `-1`
            can be used for blocks that have to keep complete causality within itself.

    Returns:
        A dict mapping each distinct layer type to its mask when the text config defines `layer_types`, otherwise
        the single mask produced by the sliding-window, chunked or plain causal mask function. Extra `**kwargs`
        are accepted and not used.
    """
    # For composite models, the relevant attributes live in the text config
    text_config = config.get_text_config()

    # Every mask creation function below receives this same set of keyword arguments
    mask_kwargs = {
        "config": text_config,
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "position_ids": position_ids,
        "or_mask_function": or_mask_function,
        "and_mask_function": and_mask_function,
        "block_sequence_ids": block_sequence_ids,
    }

    # Per-layer attention patterns: build one mask per distinct layer type
    if hasattr(text_config, "layer_types"):
        return {
            layer_pattern: LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)
            for layer_pattern in set(text_config.layer_types)
        }

    # No per-layer pattern and a sliding window is set: all layers are sliding
    if getattr(text_config, "sliding_window", None) is not None:
        return create_sliding_window_causal_mask(**mask_kwargs)

    # No sliding window but a chunk size is set: all layers are chunked
    if getattr(text_config, "attention_chunk_size", None) is not None:
        return create_chunked_causal_mask(**mask_kwargs)

    # Otherwise all layers use standard causal attention
    return create_causal_mask(**mask_kwargs)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
16
+ `masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
17
+ and will be removed in the future.
18
+ """
19
+
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Union
23
+
24
+ import torch
25
+
26
+ from .utils.import_utils import is_torchdynamo_compiling, is_tracing
27
+
28
+
29
# Warning text shared by every deprecated entry point of this module (emitted as a `FutureWarning`)
DEPRECATION_MESSAGE = (
    "The attention mask API under `transformers.modeling_attn_mask_utils` (`AttentionMaskConverter`) is deprecated "
    "and will be removed in Transformers v5.10. "
    "Please use the new API in `transformers.masking_utils`."
)
33
+
34
+
35
+ @dataclass
36
+ class AttentionMaskConverter:
37
+ """
38
+ A utility attention mask class that allows one to:
39
+ - Create a causal 4d mask
40
+ - Create a causal 4d mask with slided window
41
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
42
+ key_value_length) that can be multiplied with attention scores
43
+
44
+ Examples:
45
+
46
+ ```python
47
+ >>> import torch
48
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
49
+
50
+ >>> converter = AttentionMaskConverter(True)
51
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
52
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
53
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
54
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
55
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
56
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
57
+ ```
58
+
59
+ Parameters:
60
+ is_causal (`bool`):
61
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
62
+
63
+ sliding_window (`int`, *optional*):
64
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
65
+ """
66
+
67
+ is_causal: bool
68
+ sliding_window: int
69
+
70
+ def __init__(self, is_causal: bool, sliding_window: int | None = None):
71
+ warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
72
+
73
+ self.is_causal = is_causal
74
+ self.sliding_window = sliding_window
75
+
76
+ if self.sliding_window is not None and self.sliding_window <= 0:
77
+ raise ValueError(
78
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
79
+ )
80
+
81
+ def to_causal_4d(
82
+ self,
83
+ batch_size: int,
84
+ query_length: int,
85
+ key_value_length: int,
86
+ dtype: torch.dtype,
87
+ device: Union[torch.device, "str"] = "cpu",
88
+ ) -> torch.Tensor | None:
89
+ """
90
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
91
+ bias to upper right hand triangular matrix (causal mask).
92
+ """
93
+ if not self.is_causal:
94
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
95
+
96
+ # If shape is not cached, create a new causal mask and cache it
97
+ input_shape = (batch_size, query_length)
98
+ past_key_values_length = key_value_length - query_length
99
+
100
+ # create causal mask
101
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
102
+ causal_4d_mask = None
103
+ if input_shape[-1] > 1 or self.sliding_window is not None:
104
+ causal_4d_mask = self._make_causal_mask(
105
+ input_shape,
106
+ dtype,
107
+ device=device,
108
+ past_key_values_length=past_key_values_length,
109
+ sliding_window=self.sliding_window,
110
+ )
111
+
112
+ return causal_4d_mask
113
+
114
+ def to_4d(
115
+ self,
116
+ attention_mask_2d: torch.Tensor,
117
+ query_length: int,
118
+ dtype: torch.dtype,
119
+ key_value_length: int | None = None,
120
+ ) -> torch.Tensor:
121
+ """
122
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
123
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
124
+ causal, a causal mask will be added.
125
+ """
126
+ input_shape = (attention_mask_2d.shape[0], query_length)
127
+
128
+ # create causal mask
129
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
130
+ causal_4d_mask = None
131
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
132
+ if key_value_length is None:
133
+ raise ValueError(
134
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
135
+ )
136
+
137
+ past_key_values_length = key_value_length - query_length
138
+ causal_4d_mask = self._make_causal_mask(
139
+ input_shape,
140
+ dtype,
141
+ device=attention_mask_2d.device,
142
+ past_key_values_length=past_key_values_length,
143
+ sliding_window=self.sliding_window,
144
+ )
145
+ elif self.sliding_window is not None:
146
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
147
+
148
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
149
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
150
+ attention_mask_2d.device
151
+ )
152
+
153
+ if causal_4d_mask is not None:
154
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
155
+
156
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
157
+ expanded_4d_mask = expanded_attn_mask
158
+
159
+ return expanded_4d_mask
160
+
161
+ @staticmethod
162
+ def _make_causal_mask(
163
+ input_ids_shape: torch.Size,
164
+ dtype: torch.dtype,
165
+ device: torch.device,
166
+ past_key_values_length: int = 0,
167
+ sliding_window: int | None = None,
168
+ ):
169
+ """
170
+ Make causal mask used for bi-directional self-attention.
171
+ """
172
+ warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
173
+
174
+ bsz, tgt_len = input_ids_shape
175
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
176
+ mask_cond = torch.arange(mask.size(-1), device=device)
177
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
178
+
179
+ mask = mask.to(dtype)
180
+
181
+ if past_key_values_length > 0:
182
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
183
+
184
+ # add lower triangular sliding window mask if necessary
185
+ if sliding_window is not None:
186
+ diagonal = past_key_values_length - sliding_window - 1
187
+
188
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
189
+ # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
190
+ # See https://github.com/pytorch/pytorch/issues/127571
191
+ if is_torchdynamo_compiling():
192
+ mask = mask.clone()
193
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
194
+
195
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
196
+
197
+ @staticmethod
198
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
199
+ """
200
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
201
+ """
202
+ warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
203
+
204
+ bsz, src_len = mask.size()
205
+ tgt_len = tgt_len if tgt_len is not None else src_len
206
+
207
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
208
+
209
+ inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask
210
+
211
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
212
+
213
+ @staticmethod
214
+ def _unmask_unattended(
215
+ expanded_mask: torch.FloatTensor,
216
+ min_dtype: float,
217
+ ):
218
+ # fmt: off
219
+ """
220
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
221
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
222
+ Details: https://github.com/pytorch/pytorch/issues/110213
223
+
224
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
225
+ `attention_mask` is [bsz, src_seq_len].
226
+
227
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
228
+
229
+ For example, if `expanded_mask` is (e.g. here left-padding case)
230
+ ```
231
+ [[[[0, 0, 0],
232
+ [0, 0, 0],
233
+ [0, 0, 1]]],
234
+ [[[1, 0, 0],
235
+ [1, 1, 0],
236
+ [1, 1, 1]]],
237
+ [[[0, 0, 0],
238
+ [0, 1, 0],
239
+ [0, 1, 1]]]]
240
+ ```
241
+ then the modified `expanded_mask` will be
242
+ ```
243
+ [[[[1, 1, 1], <-- modified
244
+ [1, 1, 1], <-- modified
245
+ [0, 0, 1]]],
246
+ [[[1, 0, 0],
247
+ [1, 1, 0],
248
+ [1, 1, 1]]],
249
+ [[[1, 1, 1], <-- modified
250
+ [0, 1, 0],
251
+ [0, 1, 1]]]]
252
+ ```
253
+ """
254
+ warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
255
+
256
+ # fmt: on
257
+ if expanded_mask.dtype == torch.bool:
258
+ raise ValueError(
259
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
260
+ )
261
+
262
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
263
+
264
+ @staticmethod
265
+ def _ignore_causal_mask_sdpa(
266
+ attention_mask: torch.Tensor | None,
267
+ inputs_embeds: torch.Tensor,
268
+ past_key_values_length: int,
269
+ sliding_window: int | None = None,
270
+ is_training: bool = False,
271
+ ) -> bool:
272
+ """
273
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
274
+ ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
275
+
276
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
277
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
278
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
279
+ passed).
280
+ """
281
+ warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
282
+
283
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
284
+ key_value_length = query_length + past_key_values_length
285
+
286
+ is_tracing_ = is_tracing(inputs_embeds)
287
+
288
+ ignore_causal_mask = False
289
+
290
+ if attention_mask is None:
291
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
292
+ # shape, thus SDPA's `is_causal` argument is rightfully updated
293
+ # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
294
+ # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
295
+ # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
296
+ # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
297
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
298
+ #
299
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
300
+ # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
301
+ if (
302
+ (is_training or not is_tracing_)
303
+ and (query_length == 1 or key_value_length == query_length)
304
+ and (sliding_window is None or key_value_length < sliding_window)
305
+ ):
306
+ ignore_causal_mask = True
307
+ elif sliding_window is None or key_value_length < sliding_window:
308
+ if len(attention_mask.shape) == 4:
309
+ return False
310
+ elif not is_tracing_ and torch.all(attention_mask == 1):
311
+ if query_length == 1 or key_value_length == query_length:
312
+ # For query_length == 1, causal attention and bi-directional attention are the same.
313
+ ignore_causal_mask = True
314
+
315
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
316
+ # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
317
+ # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
318
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
319
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
320
+
321
+ return ignore_causal_mask
322
+
323
+
324
def _prepare_4d_causal_attention_mask(
    attention_mask: torch.Tensor | None,
    input_shape: torch.Size | tuple | list,
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: int | None = None,
):
    """
    Build the causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`, starting from a 2D mask of
    shape `(batch_size, key_value_length)`, from an already 4D mask, or from no mask at all.

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.

    Raises:
        ValueError: If a 4D `attention_mask` does not have the shape
            `(batch_size, 1, query_length, key_value_length)`.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    query_length = input_shape[-1]
    key_value_length = query_length + past_key_values_length
    embeds_dtype = inputs_embeds.dtype
    mask_rank = None if attention_mask is None else len(attention_mask.shape)

    # 2D padding mask: expand it and combine it with the causal mask
    if mask_rank == 2:
        return converter.to_4d(attention_mask, query_length, key_value_length=key_value_length, dtype=embeds_dtype)

    # Already 4D: check the shape, then invert it and fill the masked positions with the dtype minimum
    if mask_rank == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        inverted_mask = 1.0 - attention_mask
        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(embeds_dtype).min)

    # Any other case (no mask included): fall back to the pure causal mask
    return converter.to_causal_4d(
        input_shape[0], query_length, key_value_length, dtype=embeds_dtype, device=inputs_embeds.device
    )
374
+
375
+
376
# Adapted from _prepare_4d_causal_attention_mask
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: torch.Tensor | None,
    input_shape: torch.Size | tuple | list,
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: int | None = None,
):
    """
    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.

    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    query_length = input_shape[-1]
    key_value_length = query_length + past_key_values_length
    embeds_dtype = inputs_embeds.dtype

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
    is_tracing_ = is_tracing(inputs_embeds)

    # The mask can be dropped entirely, SDPA's `is_causal` argument is enough
    if AttentionMaskConverter._ignore_causal_mask_sdpa(
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values_length=past_key_values_length,
        sliding_window=sliding_window,
    ):
        return None

    # No padding information: only the causal mask is needed
    if attention_mask is None:
        return converter.to_causal_4d(
            input_shape[0], query_length, key_value_length, dtype=embeds_dtype, device=inputs_embeds.device
        )

    if attention_mask.dim() == 4:
        # An already 4D mask is used as provided
        expanded_4d_mask = attention_mask
    else:
        expanded_4d_mask = converter.to_4d(
            attention_mask,
            query_length,
            dtype=embeds_dtype,
            key_value_length=key_value_length,
        )

    # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
    # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
    # Details: https://github.com/pytorch/pytorch/issues/110213
    if not is_tracing_ and expanded_4d_mask.device.type in ["cuda", "xpu"]:
        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
            expanded_4d_mask, min_dtype=torch.finfo(embeds_dtype).min
        )

    return expanded_4d_mask
433
+
434
+
435
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
    """
    Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
    `(batch_size, 1, query_length, key_value_length)`.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    # Thin wrapper: all the work is done by the converter's expansion helper
    expanded_4d_mask = AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
    return expanded_4d_mask
449
+
450
+
451
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
    """
    Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
    `(batch_size, 1, query_length, key_value_length)`, or return `None` when, outside of tracing, every entry of
    `mask` is 1.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
    warnings.warn(DEPRECATION_MESSAGE, FutureWarning)

    _batch_size, source_length = mask.shape
    if tgt_len is None:
        tgt_len = source_length

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
    if not is_tracing(mask) and torch.all(mask == 1):
        return None

    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
474
+
475
+
476
def _create_4d_causal_attention_mask(
    input_shape: torch.Size | tuple | list,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: int | None = None,
) -> torch.Tensor | None:
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`torch.device`):
            The torch device the created mask shall have.
        past_key_values_length (`int`, *optional*, defaults to 0):
            The length of the key value cache. It is added to the query length to obtain `key_value_length`.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.

    Returns:
        The value returned by `AttentionMaskConverter.to_causal_4d` for these arguments.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = past_key_values_length + input_shape[-1]
    attention_mask = attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
    )

    return attention_mask
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/blt/configuration_blt.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Blt model configuration"""
15
+
16
+ from huggingface_hub.dataclasses import strict
17
+
18
+ from ...configuration_utils import PreTrainedConfig
19
+ from ...modeling_rope_utils import RopeParameters
20
+ from ...utils import auto_docstring, logging
21
+
22
+
23
# Module-level logger; used below to report when a default sub-config is substituted.
logger = logging.get_logger(__name__)
24
+
25
+
26
@auto_docstring(checkpoint="itazap/blt-1b-hf")
@strict
class BltLocalEncoderConfig(PreTrainedConfig):
    r"""
    cross_attn_all_layers (`bool`, *optional*, defaults to `False`):
        Whether all attention layers have cross attention.
    cross_attn_k (`int`, *optional*, defaults to 2):
        Number of cross-attention heads used in the model.
    hidden_size_global (`int`, *optional*, defaults to 2048):
        Hidden size of the global transformer layer.
    """

    model_type = "blt_local_encoder"
    default_theta = 500000.0

    # NOTE(review): the order of the annotated fields below is presumably part of the config's
    # constructor/serialization contract (the class is wrapped by `@strict`) - do not reorder.
    vocab_size: int = 260
    cross_attn_all_layers: bool | None = False
    cross_attn_k: int | None = 2
    hidden_size_global: int | None = 2048
    hidden_size: int = 1024
    num_attention_heads: int = 16
    num_key_value_heads: int | None = None
    num_hidden_layers: int = 1
    rms_norm_eps: float = 1e-5
    dropout: float | int | None = 0.0
    max_position_embeddings: int = 24576
    rope_parameters: RopeParameters | dict | None = None
    hidden_act: str = "silu"
    intermediate_size: int | None = None
    initializer_range: float = 0.02

    def __post_init__(self, **kwargs):
        """Fill in derived defaults, then defer to `PreTrainedConfig.__post_init__`."""
        # A falsy `num_key_value_heads` (None or 0) falls back to `num_attention_heads`.
        self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
        # A falsy `intermediate_size` falls back to 8/3 of the hidden size, truncated to an int.
        self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
        # Always overwritten with False, whatever value the instance carried before.
        self.tie_word_embeddings = False
        super().__post_init__(**kwargs)
62
+
63
+
64
@auto_docstring(checkpoint="itazap/blt-1b-hf")
@strict
class BltLocalDecoderConfig(PreTrainedConfig):
    r"""
    cross_attn_all_layers (`bool`, *optional*, defaults to `True`):
        Whether all attention layers have cross attention.
    cross_attn_k (`int`, *optional*, defaults to 2):
        Number of cross-attention heads used in the model.
    hidden_size_global (`int`, *optional*, defaults to 2048):
        Hidden size of the global transformer layer.
    """

    model_type = "blt_local_decoder"
    default_theta = 500000.0

    # NOTE(review): the order of the annotated fields below is presumably part of the config's
    # constructor/serialization contract (the class is wrapped by `@strict`) - do not reorder.
    vocab_size: int = 260
    cross_attn_all_layers: bool | None = True
    cross_attn_k: int | None = 2
    hidden_size_global: int | None = 2048
    hidden_size: int = 1024
    num_attention_heads: int = 16
    num_key_value_heads: int | None = None
    num_hidden_layers: int = 9
    rms_norm_eps: float = 1e-5
    dropout: float | int | None = 0.0
    max_position_embeddings: int = 24576
    rope_parameters: RopeParameters | dict | None = None
    hidden_act: str = "silu"
    intermediate_size: int = 2816
    initializer_range: float = 0.02
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | list[int] | None = None
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
        """Fill in derived defaults, then defer to `PreTrainedConfig.__post_init__`."""
        # A falsy `num_key_value_heads` (None or 0) falls back to `num_attention_heads`.
        self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
        # `head_dim` is always derived here; it is not a declared field of this class.
        self.head_dim = self.hidden_size // self.num_attention_heads
        # Only takes effect when `intermediate_size` is falsy (e.g. explicitly passed as 0 or None).
        self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
        self.tie_word_embeddings = False  # Force-set to False for BC
        super().__post_init__(**kwargs)
105
+
106
+
107
# Configuration of the Blt global transformer component (registered under `BltConfig.sub_configs["global_config"]`).
@auto_docstring(checkpoint="itazap/blt-1b-hf")
@strict
class BltGlobalTransformerConfig(PreTrainedConfig):
    model_type = "blt_global_transformer"
    default_theta = 500000.0

    # NOTE(review): the order of the annotated fields below is presumably part of the config's
    # constructor/serialization contract (the class is wrapped by `@strict`) - do not reorder.
    hidden_size: int = 2048
    num_attention_heads: int = 16
    num_key_value_heads: int | None = None
    num_hidden_layers: int = 25
    rms_norm_eps: float = 1e-5
    dropout: float | int | None = 0.0
    max_position_embeddings: int = 4096
    rope_parameters: RopeParameters | dict | None = None
    hidden_act: str = "silu"
    intermediate_size: int = 5632
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
        """Fill in derived defaults, then defer to `PreTrainedConfig.__post_init__`."""
        # A falsy `num_key_value_heads` (None or 0) falls back to `num_attention_heads`.
        self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
        # `head_dim` is always derived here; it is not a declared field of this class.
        self.head_dim = self.hidden_size // self.num_attention_heads
        # Only takes effect when `intermediate_size` is falsy (e.g. explicitly passed as 0 or None).
        self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
        # Always overwritten with False, whatever value the instance carried before.
        self.tie_word_embeddings = False

        super().__post_init__(**kwargs)
133
+
134
+
135
# Configuration of the Blt patcher component (registered under `BltConfig.sub_configs["patcher_config"]`).
@auto_docstring(checkpoint="itazap/blt-1b-hf")
@strict
class BltPatcherConfig(PreTrainedConfig):
    model_type = "blt_patcher"

    # NOTE(review): the order of the annotated fields below is presumably part of the config's
    # constructor/serialization contract (the class is wrapped by `@strict`) - do not reorder.
    vocab_size: int = 260
    hidden_size: int = 768
    num_hidden_layers: int = 14
    num_attention_heads: int = 12
    num_key_value_heads: int | None = None
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    dropout: float | int | None = 0.0
    intermediate_size: int = 2048
    rope_parameters: RopeParameters | dict | None = None
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
        """Fill in derived defaults, then defer to `PreTrainedConfig.__post_init__`."""
        # A falsy `num_key_value_heads` (None or 0) falls back to `num_attention_heads`.
        self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
        # `head_dim` is always derived here; it is not a declared field of this class.
        self.head_dim = self.hidden_size // self.num_attention_heads
        # Only takes effect when `intermediate_size` is falsy (e.g. explicitly passed as 0 or None).
        self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
        # Always overwritten with False, whatever value the instance carried before.
        self.tie_word_embeddings = False
        # `hidden_act` is not a declared field on this class; it is set unconditionally here.
        self.hidden_act = "silu"  # Blt uses silu activation

        super().__post_init__(**kwargs)
161
+
162
+
163
@auto_docstring(checkpoint="itazap/blt-1b-hf")
@strict
class BltConfig(PreTrainedConfig):
    r"""
    patch_in_forward (`bool`, *optional*, defaults to `True`):
        Whether to perform patching during the forward pass.
    patch_size (`int`, *optional*, defaults to 4):
        Size of the patches used in the patching mechanism.
    patching_mode (`str`, *optional*, defaults to `"entropy"`):
        The mode used for patching, such as entropy-based patching.
    patching_threshold (`float`, *optional*, defaults to 1.34):
        Threshold value used for determining when to apply patches.
    patching_batch_size (`int`, *optional*, defaults to 1):
        Batch size used during the patching process.
    max_patch_length (`int`, *optional*):
        Maximum length of patches that can be generated.
    cross_attn_k (`int`, *optional*, defaults to 2):
        Number of cross-attention heads used in the model.
    encoder_hash_byte_group_size (`list`, *optional*):
        List of byte group sizes used in the encoder hash function.
    encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 500002):
        Vocabulary size for the encoder hash byte groups.
    encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 1):
        Number of hash functions used in the encoder byte grouping.
    patcher_config (`BltPatcherConfig`, *optional*):
        Configuration for the patcher component of the model.
    encoder_config (`BltLocalEncoderConfig`, *optional*):
        Configuration for the local encoder component of the model.
    decoder_config (`BltLocalDecoderConfig`, *optional*):
        Configuration for the local decoder component of the model.
    global_config (`BltGlobalTransformerConfig`, *optional*):
        Configuration for the global transformer component of the model.

    Example:
    ```python
    >>> from transformers import BltModel, BltConfig

    >>> # Initializing a Blt configuration
    >>> configuration = BltConfig()

    >>> # Initializing a model from the configuration
    >>> model = BltModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "blt"
    keys_to_ignore_at_inference = ["past_key_values"]
    default_theta = 500000.0
    sub_configs = {
        "patcher_config": BltPatcherConfig,
        "encoder_config": BltLocalEncoderConfig,
        "decoder_config": BltLocalDecoderConfig,
        "global_config": BltGlobalTransformerConfig,
    }

    vocab_size: int = 260
    max_position_embeddings: int = 4096
    patch_in_forward: bool | None = True
    patch_size: int | None = 4
    patching_mode: str | None = "entropy"
    patching_threshold: float | None = 1.335442066192627
    patching_batch_size: int | None = 1
    max_patch_length: int | None = None
    cross_attn_k: int | None = 2
    encoder_hash_byte_group_size: list[int] | None = None
    encoder_hash_byte_group_vocab: int | None = 500002
    encoder_hash_byte_group_nb_functions: int | None = 1
    patcher_config: dict | PreTrainedConfig | None = None
    encoder_config: dict | PreTrainedConfig | None = None
    decoder_config: dict | PreTrainedConfig | None = None
    global_config: dict | PreTrainedConfig | None = None
    tie_word_embeddings: bool = False
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | list[int] | None = None
    initializer_range: float = 0.02
    rope_parameters: RopeParameters | dict | None = None

    def __post_init__(self, **kwargs):
        """
        Resolve defaults and materialize the four component configurations.

        Each component attribute may arrive as `None` (a default config is built), as a `dict` (a config is built
        from it) or as anything else (left untouched). In the first two cases the top-level `initializer_range` is
        used unless the dict already carries its own value.
        """
        default_group_sizes = [3, 4, 5, 6, 7, 8]
        self.encoder_hash_byte_group_size = self.encoder_hash_byte_group_size or default_group_sizes

        # Initialize component configurations. The table is spelled out here (rather than read from
        # `sub_configs`) so the classes used are fixed, in this exact order.
        component_classes = (
            ("patcher_config", BltPatcherConfig),
            ("encoder_config", BltLocalEncoderConfig),
            ("decoder_config", BltLocalDecoderConfig),
            ("global_config", BltGlobalTransformerConfig),
        )
        for attr_name, config_cls in component_classes:
            current = getattr(self, attr_name)
            if current is None:
                setattr(self, attr_name, config_cls(initializer_range=self.initializer_range))
                component = attr_name.removesuffix("_config")
                logger.info(f"{attr_name} is None, using default Blt {component} config")
            elif isinstance(current, dict):
                # Note: this writes the default into the caller-supplied dict.
                current.setdefault("initializer_range", self.initializer_range)
                setattr(self, attr_name, config_cls(**current))

        # Determine if token embedding projection is needed based on dimension mismatch (7b)
        cross_output_size = self.encoder_config.hidden_size * self.cross_attn_k
        sizes_differ = cross_output_size != self.global_config.hidden_size
        self.global_config.encoder_cross_output_size = cross_output_size if sizes_differ else None

        super().__post_init__(**kwargs)
278
+
279
+
280
# Names exported by this module when it is star-imported.
__all__ = [
    "BltConfig",
    "BltPatcherConfig",
    "BltLocalEncoderConfig",
    "BltLocalDecoderConfig",
    "BltGlobalTransformerConfig",
]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
# Static type checkers take the first branch and see the real symbols of both submodules.
if TYPE_CHECKING:
    from .configuration_jetmoe import *
    from .modeling_jetmoe import *
else:
    import sys

    # At runtime, this package's entry in `sys.modules` is replaced by a `_LazyModule` built from the import
    # structure that `define_import_structure` derives for this file.
    # NOTE(review): presumably this defers importing the submodules until an attribute is accessed - confirm
    # against `_LazyModule`.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py ADDED
@@ -0,0 +1,830 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/jetmoe/modular_jetmoe.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_jetmoe.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from collections.abc import Callable
22
+ from typing import Optional
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import functional as F
27
+
28
+ from ... import initialization as init
29
+ from ...activations import ACT2FN
30
+ from ...cache_utils import Cache, DynamicCache
31
+ from ...generation import GenerationMixin
32
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
33
+ from ...masking_utils import create_causal_mask
34
+ from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
35
+ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
36
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from ...processing_utils import Unpack
39
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
40
+ from ...utils.generic import maybe_autocast, merge_with_config_defaults
41
+ from ...utils.output_capturing import OutputRecorder, capture_outputs
42
+ from .configuration_jetmoe import JetMoeConfig
43
+
44
+
45
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
46
+
47
+
48
@use_kernel_forward_from_hub("RMSNorm")
class JetMoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        JetMoeRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the normalized (last) dimension; the learned scale has this shape.
            eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        # The scale starts at one, so a freshly built layer only normalizes.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize `hidden_states` by the root mean square of its last dimension and apply the learned scale."""
        original_dtype = hidden_states.dtype
        # The statistics are computed in float32 regardless of the input dtype.
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back before scaling, so the multiplication happens against the input dtype.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's printed representation."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
67
+
68
+
69
class JetMoeRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: JetMoeConfig, device=None):
        """
        Set up the inverse-frequency buffers for the RoPE variant selected by `config.rope_parameters["rope_type"]`.

        Args:
            config: Model configuration; `max_position_embeddings` and `rope_parameters` are read from it.
            device: Device passed on to the RoPE initialization function.
        """
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.rope_type = self.config.rope_parameters["rope_type"]
        # The "default" type uses the static method below; any other type is looked up in the registry.
        rope_init_fn: Callable = (
            self.compute_default_rope_parameters
            if self.rope_type == "default"
            else ROPE_INIT_FUNCTIONS[self.rope_type]
        )
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Neither buffer is saved in the state dict; a separate copy of the initial frequencies is kept.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: JetMoeConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_theta = config.rope_parameters["rope_theta"]
        # An explicit, truthy `head_dim` on the config wins; otherwise it is derived from the hidden size.
        rotary_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Even indices 0, 2, ... are built as int64, then moved to the target device as float.
        even_indices = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (rope_theta ** (even_indices / rotary_dim))

        # The scaling factor is fixed to 1.0: unused in this type of RoPE.
        return inv_freq, 1.0

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Compute the cosine and sine tables for `position_ids`.

        Args:
            x: Tensor whose device and dtype determine the device and dtype of the outputs.
            position_ids: Position indices; the first dimension is used as the batch size.

        Returns:
            Tuple `(cos, sin)`, both cast to `x.dtype`.
        """
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is disabled for the device type of `x`; "mps" (or a non-string type) falls back to "cpu".
        usable_device_type = isinstance(x.device.type, str) and x.device.type != "mps"
        device_type = x.device.type if usable_device_type else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = torch.matmul(inv_freq_expanded.float(), position_ids_expanded.float()).transpose(1, 2)
            # The frequencies are duplicated along the last dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
132
+
133
+
134
class JetMoeParallelExperts(nn.Module):
    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
        """
        Initialize the JetMoeParallelExperts module.
        The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
        many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
        [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
        [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
        used in vllm.

        Args:
            num_experts (int):
                Number of experts.
            input_size (int):
                Size of the input.
            output_size (int):
                Size of the output.
        """
        super().__init__()
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        # Allocated without initialization; the values must be filled in before use.
        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))

    def forward(self, inputs, expert_size):
        """
        Apply each expert's bias-free linear map to its own contiguous slice of `inputs`.

        Args:
            inputs (Tensor):
                Input tensor, split along dimension 0 into one slice per expert.
            expert_size:
                Split sizes for dimension 0 of `inputs`; entry `i` is the number of rows handled by expert `i`.

        Returns:
            Tensor: The per-expert outputs concatenated along dimension 0, in expert order.
        """
        per_expert_inputs = inputs.split(expert_size, dim=0)
        per_expert_outputs = [
            F.linear(per_expert_inputs[expert_idx], self.weight[expert_idx])
            for expert_idx in range(self.num_experts)
        ]
        return torch.cat(per_expert_outputs, dim=0)
177
+
178
+
179
class JetMoeTopKGating(nn.Module):
    def __init__(self, input_size: int, num_experts: int, top_k: int):
        """
        Initialize the top-k gating mechanism.

        Args:
            input_size (`int`):
                Size of the input.
            num_experts (`int`):
                Number of experts.
            top_k (`int`):
                Number of top experts to select.
        """
        super().__init__()

        self.top_k = top_k
        self.input_size = input_size
        self.num_experts = num_experts

        # Bias-free projection producing one routing logit per expert.
        self.layer = nn.Linear(input_size, num_experts, bias=False)

    def forward(self, hidden_states):
        """
        Route every token to its `top_k` experts and describe the resulting grouping.

        Args:
            hidden_states: 2D tensor of token features, `[num_tokens, input_size]`.

        Returns:
            Tuple `(index_sorted_experts, batch_index, batch_gates, expert_size, logits)`:
            the permutation that sorts the flattened `[num_tokens * top_k]` assignments by expert id, the source
            token index of each sorted assignment, the gate value of each sorted assignment, a Python list with
            the number of assignments per expert, and the float router logits `[num_tokens, num_experts]`.
        """
        # compute the top_k routing decision
        logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=1)  # [num_tokens, top_k]
        # The softmax is taken over the selected logits only, then cast back to the input dtype.
        top_k_gates = top_k_logits.softmax(dim=1).type_as(hidden_states)  # [num_tokens, top_k]

        # compute number of input given to each expert
        assignment = top_k_gates.new_zeros((top_k_gates.size(0), self.num_experts))  # [num_tokens, num_experts]
        assignment = assignment.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
        # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
        # (and `DataDependentOutputException`)
        expert_size = assignment.long().sum(0).tolist()  # [num_experts,] as a Python list

        # sort and group input tokens according to expert assignment
        flat_expert_ids = top_k_indices.flatten()  # [num_tokens * top_k]
        _, index_sorted_experts = torch.sort(flat_expert_ids, dim=0)  # [num_tokens * top_k]
        # Integer division by top_k recovers the token each flattened slot came from.
        batch_index = torch.div(index_sorted_experts, self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]

        # gather the gate values for grouped input tokens
        batch_gates = top_k_gates.flatten()[index_sorted_experts]  # [num_tokens * top_k]

        return index_sorted_experts, batch_index, batch_gates, expert_size, logits
226
+
227
+
228
class JetMoeMoE(nn.Module):
    """
    A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: JetMoeConfig):
        super().__init__()

        num_experts = config.num_local_experts
        self.input_size = config.hidden_size
        self.hidden_size = config.intermediate_size
        self.activation = ACT2FN[config.activation_function]
        # Shared output bias; allocated without initialization, so it must be filled in before use.
        self.bias = torch.nn.Parameter(torch.empty(self.input_size))
        # The input projection is twice as wide because its output is split into two halves in `forward`.
        self.input_linear = JetMoeParallelExperts(num_experts, self.input_size, self.hidden_size * 2)
        self.output_linear = JetMoeParallelExperts(num_experts, self.hidden_size, self.input_size)

        self.router = JetMoeTopKGating(
            input_size=self.input_size,
            num_experts=num_experts,
            top_k=config.num_experts_per_tok,
        )

    def forward(self, layer_input):
        """
        Forward pass of the mixture of experts layer.

        Args:
            layer_input (Tensor):
                Input tensor of shape `[bsz, length, emb_size]`.

        Returns:
            Tensor:
                Output tensor of shape `[bsz, length, input_size]`. The router logits are computed by the router
                but are not returned by this method.
        """
        bsz, length, emb_size = layer_input.size()
        flat_input = layer_input.reshape(-1, emb_size)
        _, batch_index, batch_gates, expert_size, _ = self.router(flat_input)

        # Gather one input row per (token, selected expert) pair, grouped by expert.
        projected = self.input_linear(flat_input[batch_index], expert_size)
        # First half goes through the activation and multiplies the second half.
        halves = projected.chunk(2, dim=-1)
        activated = self.activation(halves[0]) * halves[1]

        # Weight every expert output by its gate value.
        weighted_outputs = self.output_linear(activated, expert_size) * batch_gates[:, None]

        # Sum the contributions of all selected experts back onto their source token.
        accumulator = torch.zeros(
            (bsz * length, self.input_size), dtype=weighted_outputs.dtype, device=weighted_outputs.device
        )
        merged = accumulator.index_add(0, batch_index, weighted_outputs)
        return merged.view(bsz, length, self.input_size) + self.bias
284
+
285
+
286
class JetMoeMoA(nn.Module):
    """
    A Sparsely gated mixture of attention layer with pairs of query- and output-projections as experts.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: JetMoeConfig):
        super().__init__()

        self.num_experts = config.num_local_experts
        self.input_size = config.hidden_size
        self.hidden_size = config.kv_channels * config.num_key_value_heads
        self.top_k = config.num_experts_per_tok
        # Shared output bias; allocated without initialization, so it must be filled in before use.
        self.bias = torch.nn.Parameter(torch.empty(self.input_size))

        self.input_linear = JetMoeParallelExperts(self.num_experts, self.input_size, self.hidden_size)
        self.output_linear = JetMoeParallelExperts(self.num_experts, self.hidden_size, self.input_size)

        self.router = JetMoeTopKGating(
            input_size=self.input_size,
            num_experts=self.num_experts,
            top_k=self.top_k,
        )

    def map(self, layer_input):
        """
        Map inputs to attention experts according to routing decision and compute query projection inside each experts.

        Returns the projected queries as `[bsz, length, top_k, hidden_size]`, the router logits, and the routing
        information `(index_sorted_experts, batch_index, batch_gates, expert_size)` that `reduce` expects.
        """
        # Compute gating topology
        bsz, length, emb_size = layer_input.size()
        flat_input = layer_input.reshape(-1, emb_size)  # [bsz * length, emb_size]
        index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(flat_input)

        # Group inputs according to topology and compute query projection
        grouped_queries = self.input_linear(
            flat_input[batch_index], expert_size
        )  # [bsz * length * top_k, hidden_size]

        # Ungroup queries back to original order
        restored = grouped_queries.new_zeros((bsz * length * self.top_k, self.hidden_size))
        restored = restored.index_add(0, index_sorted_experts, grouped_queries)

        topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size)
        # [bsz, length, top_k, hidden_size]
        return restored.view(bsz, length, self.top_k, -1), router_logits, topo_info

    def reduce(self, layer_input, topo_info):
        """
        Compute output projection inside each attention experts and merge the outputs of different experts.

        `layer_input` is `[bsz, length, k, hidden_size]` and `topo_info` is the routing tuple returned by `map`.
        Returns a tensor of shape `[bsz, length, input_size]`.
        """
        bsz, length, k, hidden_size = layer_input.size()
        flat_input = layer_input.reshape(-1, hidden_size)  # [bsz * length * k, hidden_size]
        index_sorted_experts, batch_index, batch_gates, expert_size = topo_info

        # Group inputs according to topology and compute output projection
        grouped_outputs = self.output_linear(
            flat_input[index_sorted_experts], expert_size
        )  # [bsz * length * top_k, emb_size]

        # Apply gates to attention expert outputs
        grouped_outputs = grouped_outputs * batch_gates[:, None]

        # Ungroup and merge outputs to original order
        merged = grouped_outputs.new_zeros((bsz * length, self.input_size))
        merged = merged.index_add(0, batch_index, grouped_outputs)
        return merged.view(bsz, length, self.input_size) + self.bias

    def forward(self, layer_input):
        """Always raises: this module is only meant to be used through `map` and `reduce`."""
        raise NotImplementedError("This module doesn't support call and forward.")
360
+
361
+
362
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is cut at `x.shape[-1] // 2`; the result is the negated second part followed by the
    first part.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
367
+
368
+
369
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, when `cos` and `sin` have the shape [batch_size, seq_len, head_dim]: use
            `unsqueeze_dim=1` if `q` and `k` have the shape [batch_size, heads, seq_len, head_dim], and
            `unsqueeze_dim=2` if they have the shape [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    # The same rotation is applied to the query first, then to the key.
    q_embed, k_embed = ((tensor * cos_b) + (rotate_half(tensor) * sin_b) for tensor in (q, k))
    return q_embed, k_embed
393
+
394
+
395
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times, keeping the copies of a head adjacent.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking first means a non-4D input fails here even when no repetition is requested.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
405
+
406
+
407
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain matmul/softmax attention.

    Args:
        module: Attention module; its `num_key_value_groups` and `training` attributes are read.
        query, key, value: Attention inputs. The code transposes dims 2 and 3 of the keys and dims 1 and 2 of the
            output, so 4D tensors are expected (presumably (batch, heads, seq, head_dim) - confirm with callers).
        attention_mask: Optional additive mask, added to the scaled scores before the softmax.
        scaling: Factor multiplied into the raw query-key scores.
        dropout: Dropout probability applied to the attention weights (only active when `module.training`).
        **kwargs: Accepted for interface compatibility and ignored here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped relative to the matmul result
        and is made contiguous.
    """
    repeats = module.num_key_value_groups
    key_states = repeat_kv(key, repeats)
    value_states = repeat_kv(value, repeats)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # The softmax is evaluated in float32 and cast back to the query dtype afterwards.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
430
+
431
+
432
class JetMoeAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, with the query and output projections routed
    through a `JetMoeMoA` expert module (`self.experts`) and a single shared key/value projection (`self.kv_proj`).
    """

    def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
        """
        Initialize the JetMoeAttention module.

        Args:
            config:
                Configuration object with model hyperparameters.
            layer_idx:
                Index of the layer in the model. It is handed to the key/value cache in `forward`, so it should be
                provided whenever caching is used.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.is_causal = True
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.num_key_value_groups = 1  # We ignore this by setting it to 1 as we have different repeat patterns
        self.top_k = config.num_experts_per_tok
        self.attention_dropout = config.attention_dropout
        self.kv_projection_size = config.kv_channels * config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_heads = config.num_attention_heads
        self.head_dim = config.kv_channels
        self.scaling = self.head_dim**-0.5
        self.experts = JetMoeMoA(config)

        # One fused projection produces keys and values; it is split in two halves in `forward`.
        self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        past_key_values: Cache | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        """
        Run expert-routed self-attention.

        Args:
            hidden_states: Layer input; every dimension except the last is treated as the batch/sequence shape.
            attention_mask: Mask forwarded unchanged to the selected attention implementation.
            position_embeddings: `(cos, sin)` pair for the rotary embedding. Although it defaults to `None`, it is
                unpacked unconditionally, so callers must provide it.
            past_key_values: Optional cache; when given, it is updated with this layer's keys/values and its
                returned (accumulated) keys/values are used for attention.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights, router_logits)`.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Queries come from the routed experts; `topo_info` is needed again by `experts.reduce` below.
        query_states, router_logits, topo_info = self.experts.map(hidden_states)
        key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)

        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        # This is different from other models where we repeat k/v heads
        # instead of repeat interleaving them
        key_states = key_states.repeat(1, self.top_k, 1, 1)
        value_states = value_states.repeat(1, self.top_k, 1, 1)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Regroup the heads per selected expert so the experts can apply their output projections and merge.
        attn_output = attn_output.view(*input_shape, self.top_k, -1)
        attn_output = self.experts.reduce(attn_output, topo_info)
        attn_output = attn_output.view(*input_shape, -1)
        return attn_output, attn_weights, router_logits
518
+
519
+
520
class JetMoeDecoderLayer(GradientCheckpointingLayer):
    """Decoder block: RMS-normed self-attention and RMS-normed MoE feed-forward, each with a residual connection."""

    def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.mlp = JetMoeMoE(config)
        self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
        self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
        self.self_attention = JetMoeAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply the attention sub-block followed by the feed-forward sub-block.

        All arguments other than `hidden_states` are passed through to the attention module by keyword. Only the
        first element of the attention module's 3-tuple result is used here.

        Returns:
            The updated hidden states.
        """
        # Attention sub-block: normalise, attend, add back onto the un-normalised input.
        attn_output, _, _ = self.self_attention(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attention = hidden_states + attn_output

        # Feed-forward sub-block: same pattern around the MoE layer.
        mlp_output = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + mlp_output
559
+
560
+
561
@auto_docstring
class JetMoePreTrainedModel(PreTrainedModel):
    # Shared base for the JetMoe models: capability flags read by the framework plus custom weight initialisation.
    config: JetMoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["JetMoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
    _supports_attention_backend = True
    # Which module outputs are recorded under which output name (tuple position given by `index`).
    _can_record_outputs = {
        "router_logits": [OutputRecorder(JetMoeAttention, index=2), OutputRecorder(JetMoeTopKGating, index=4)],
        "hidden_states": JetMoeDecoderLayer,
        "attentions": OutputRecorder(JetMoeAttention, index=1),
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialise `module`: run the parent initialisation, then apply the JetMoe-specific rules.

        Parallel expert weights are drawn from a normal distribution with std `config.initializer_range`;
        the biases of the MoA / MoE wrappers are zeroed.
        """
        super()._init_weights(module)
        if isinstance(module, JetMoeParallelExperts):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            return
        if isinstance(module, (JetMoeMoA, JetMoeMoE)):
            init.zeros_(module.bias)
587
+
588
+
589
@auto_docstring
class JetMoeModel(JetMoePreTrainedModel):
    def __init__(self, config: JetMoeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [JetMoeDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = JetMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = JetMoeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of the two input forms must be given: reject both-missing and both-present alike.
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids == has_embeds:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            # Default positions continue after whatever the cache already holds.
            seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + seen_tokens
            position_ids = position_ids.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer in active_layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_ids=position_ids,
                **kwargs,
            )

        final_states = self.norm(hidden_states)

        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=final_states,
            past_key_values=past_key_values,
        )
665
+
666
+
667
+ def load_balancing_loss_func(
668
+ gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
669
+ num_experts: int | None = None,
670
+ top_k=2,
671
+ attention_mask: torch.Tensor | None = None,
672
+ ) -> torch.Tensor | int:
673
+ r"""
674
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
675
+
676
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
677
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
678
+ experts is too unbalanced.
679
+
680
+ Args:
681
+ gate_logits:
682
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
683
+ shape [batch_size X sequence_length, num_experts].
684
+ num_experts:
685
+ Number of experts
686
+ top_k:
687
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
688
+ parameter.
689
+ attention_mask (`torch.Tensor`, *optional*):
690
+ The attention_mask used in forward function
691
+ shape [batch_size X sequence_length] if not None.
692
+
693
+ Returns:
694
+ The auxiliary loss.
695
+ """
696
+ if gate_logits is None or not isinstance(gate_logits, tuple):
697
+ return 0
698
+
699
+ if isinstance(gate_logits, tuple):
700
+ compute_device = gate_logits[0].device
701
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
702
+
703
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
704
+
705
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
706
+
707
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
708
+
709
+ if attention_mask is None:
710
+ # Compute the percentage of tokens routed to each experts
711
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
712
+
713
+ # Compute the average probability of routing to these experts
714
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
715
+ else:
716
+ batch_size, sequence_length = attention_mask.shape
717
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
718
+
719
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
720
+ expert_attention_mask = (
721
+ attention_mask[None, :, :, None, None]
722
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
723
+ .reshape(-1, top_k, num_experts)
724
+ .to(compute_device)
725
+ )
726
+
727
+ # Compute the percentage of tokens routed to each experts
728
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
729
+ expert_attention_mask, dim=0
730
+ )
731
+
732
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
733
+ router_per_expert_attention_mask = (
734
+ attention_mask[None, :, :, None]
735
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
736
+ .reshape(-1, num_experts)
737
+ .to(compute_device)
738
+ )
739
+
740
+ # Compute the average probability of routing to these experts
741
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
742
+ router_per_expert_attention_mask, dim=0
743
+ )
744
+
745
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
746
+ return overall_loss * num_experts
747
+
748
+
749
class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin):
    """JetMoe backbone with a bias-free linear language-modelling head and an optional router load-balancing loss."""

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.model = JetMoeModel(config)
        self.vocab_size = config.vocab_size
        self.aux_loss_coef = config.aux_loss_coef
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.tie_word_embeddings = config.tie_word_embeddings
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        output_router_logits: bool | None = False,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 yields `slice(0, None)`, i.e. every position);
        # a tensor is used directly as the index along the sequence dimension.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # `load_balancing_loss_func` returns the plain int 0 when no tuple of router logits is available;
                # an int has no `.to`, so only tensors are moved to the loss device before being added.
                aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss += self.aux_loss_coef * aux_term

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
825
+
826
+
827
class JetMoeForSequenceClassification(GenericForSequenceClassification, JetMoePreTrainedModel):
    # No members of its own: everything is inherited from the two base classes.
    pass


# Names exported by this module.
__all__ = [
    "JetMoeForCausalLM",
    "JetMoeModel",
    "JetMoePreTrainedModel",
    "JetMoeForSequenceClassification",
]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Only static type checkers take this branch, so they see the real symbols of every submodule.
    from .configuration_vitmatte import *
    from .image_processing_pil_vitmatte import *
    from .image_processing_vitmatte import *
    from .modeling_vitmatte import *
else:
    import sys

    # At runtime this package module is swapped for a `_LazyModule` built from this file's import structure.
    _file = __file__
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/image_processing_pil_vitmatte.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Image processor class for ViTMatte."""
15
+
16
+ import numpy as np
17
+
18
+ from ...image_processing_backends import PilBackend
19
+ from ...image_processing_utils import BatchFeature
20
+ from ...image_transforms import PaddingMode
21
+ from ...image_transforms import pad as np_pad
22
+ from ...image_utils import (
23
+ IMAGENET_STANDARD_MEAN,
24
+ IMAGENET_STANDARD_STD,
25
+ ChannelDimension,
26
+ ImageInput,
27
+ get_image_size,
28
+ )
29
+ from ...processing_utils import ImagesKwargs, Unpack
30
+ from ...utils import TensorType, auto_docstring, is_torch_available
31
+
32
+
33
+ if is_torch_available():
34
+ pass
35
+
36
+
37
+ # Adapted from transformers.models.vitmatte.image_processing_vitmatte.VitMatteImageProcessorKwargs
38
class VitMatteImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
        The width and height of the image will be padded to be divisible by this number.
    """

    # The only key added on top of `ImagesKwargs`; `total=False` makes it optional.
    size_divisor: int
45
+
46
+
47
@auto_docstring
class VitMatteImageProcessorPil(PilBackend):
    # Class-level defaults for the preprocessing options.
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = True
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_pad = True
    size_divisor = 32
    valid_kwargs = VitMatteImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[VitMatteImageProcessorKwargs]) -> None:
        # `size_divisibility` is accepted as an alternative spelling: it is removed from the kwargs and used as
        # `size_divisor` unless `size_divisor` was passed explicitly.
        alias_value = kwargs.pop("size_divisibility", None)
        if alias_value is not None:
            kwargs.setdefault("size_divisor", alias_value)
        super().__init__(**kwargs)

    def pad_image(
        self,
        image: np.ndarray,
        size_divisor: int = 32,
    ) -> np.ndarray:
        """
        Zero-pad a channels-first image so that its height and width are multiples of `size_divisor`.

        Padding is appended after the existing rows/columns only; an image that already has suitable dimensions
        is returned unchanged.

        Args:
            image (`np.ndarray`):
                Image to pad.
            size_divisor (`int`, *optional*, defaults to 32):
                The width and height of the image will be padded to be divisible by this number.
        """
        height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)

        height_rest = height % size_divisor
        width_rest = width % size_divisor
        pad_height = size_divisor - height_rest if height_rest else 0
        pad_width = size_divisor - width_rest if width_rest else 0

        if pad_width + pad_height > 0:
            image = np_pad(
                image,
                padding=((0, pad_height), (0, pad_width)),
                mode=PaddingMode.CONSTANT,
                constant_values=0,
                data_format=ChannelDimension.FIRST,
                input_data_format=ChannelDimension.FIRST,
            )

        return image

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        trimaps: ImageInput,
        **kwargs: Unpack[VitMatteImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        trimaps (`ImageInput`):
            The trimaps to preprocess.
        """
        return super().preprocess(images, trimaps, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        trimaps: ImageInput,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: str | None = None,
        **kwargs: Unpack[VitMatteImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare the images and the trimaps, then hand both to `_preprocess`.

        Images go through `_prepare_image_like_inputs` with the RGB-conversion and data-format options; trimaps
        go through the same helper with `expected_ndims=2`.
        """
        prepared_images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        prepared_trimaps = self._prepare_image_like_inputs(images=trimaps, expected_ndims=2, device=device)

        return self._preprocess(prepared_images, prepared_trimaps, **kwargs)

    def _preprocess(
        self,
        images: list[np.ndarray],
        trimaps: list[np.ndarray],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        do_pad: bool | None,
        size_divisor: int | None,
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> BatchFeature:
        """
        Rescale/normalise each image, append its trimap as an extra channel, optionally pad, and batch the result.

        Rescaling is applied to both image and trimap; normalisation to the image only. Images and trimaps are
        paired with `zip`, so surplus entries on either side are ignored.

        Returns:
            A `BatchFeature` whose `"pixel_values"` entry holds the processed arrays.
        """
        pixel_values = []
        for image, trimap in zip(images, trimaps):
            if do_rescale:
                image = self.rescale(image, rescale_factor)
                trimap = self.rescale(trimap, rescale_factor)
            if do_normalize:
                image = self.normalize(image, image_mean, image_std)

            # Concatenate along the channel axis; a trimap that is not already (1, H, W) gets a leading axis.
            already_channel_first = trimap.ndim == 3 and trimap.shape[0] == 1
            trimap_plane = trimap if already_channel_first else np.expand_dims(trimap, axis=0)
            combined = np.concatenate([image, trimap_plane], axis=0)

            if do_pad:
                combined = self.pad_image(combined, size_divisor)
            pixel_values.append(combined)

        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)


__all__ = ["VitMatteImageProcessorPil"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/optimization.py ADDED
@@ -0,0 +1,1342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch optimization for BERT model."""
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ import warnings
20
+ from functools import partial
21
+ from typing import Any
22
+
23
+ import torch
24
+ from torch.optim import Optimizer
25
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
26
+
27
+ from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
28
+ from .trainer_utils import SchedulerType
29
+ from .utils import logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
def _get_constant_lambda(_=None):
    """Return the multiplier `1`, ignoring the (optional) step argument."""
    return 1


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Build a scheduler that keeps the learning rate at the value configured in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # The multiplier is a named module-level function rather than an inline lambda.
    scheduler = LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
    return scheduler
54
+
55
+
56
def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
    """
    Build a scheduler that holds the learning rate constant and lowers it once a monitored metric stops improving.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        kwargs (`dict`, *optional*):
            Extra parameters forwarded unchanged to `torch.optim.lr_scheduler.ReduceLROnPlateau`; see that class
            for the accepted names.

    Return:
        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
    """
    scheduler = ReduceLROnPlateau(optimizer, **kwargs)
    return scheduler
72
+
73
+
74
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
    """Return the LR multiplier at `current_step`: a linear ramp from 0 during warmup, then 1.0."""
    if current_step >= num_warmup_steps:
        return 1.0
    # The denominator is clamped to at least 1.0.
    return float(current_step) / float(max(1.0, num_warmup_steps))


def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Build a scheduler whose learning rate rises linearly from 0 to the optimizer's initial lr during warmup and
    then stays constant.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Bind the warmup length up front so that LambdaLR only has to supply the step.
    multiplier = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
    return LambdaLR(optimizer, multiplier, last_epoch=last_epoch)
99
+
100
+
101
def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
    """Return the LR multiplier: linear ramp up during warmup, then linear decay to 0 at `num_training_steps`."""
    if current_step >= num_warmup_steps:
        steps_left = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        # Never go negative once training runs past the planned number of steps.
        return max(0.0, steps_left / decay_span)
    return float(current_step) / float(max(1, num_warmup_steps))
105
+
106
+
107
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a scheduler whose learning rate rises linearly from 0 to the optimizer's initial lr during warmup and then
    falls linearly back to 0 by the end of training.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
    }
    linear_multiplier = partial(_get_linear_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, linear_multiplier, last_epoch)
132
+
133
+
134
def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
):
    """Return the LR multiplier: linear warmup followed by a cosine curve with `num_cycles` waves."""
    if current_step >= num_warmup_steps:
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / decay_span
        angle = math.pi * float(num_cycles) * 2.0 * progress
        return max(0.0, 0.5 * (1.0 + math.cos(angle)))
    return float(current_step) / float(max(1, num_warmup_steps))
141
+
142
+
143
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Build a scheduler whose learning rate rises linearly from 0 to the optimizer's initial lr during warmup and then
    follows a cosine curve from the initial lr down to 0.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine waves in the schedule (the default is a single half-cosine, i.e. a monotonic decrease
            from the max value to 0).
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
    }
    cosine_multiplier = partial(_get_cosine_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, cosine_multiplier, last_epoch)
175
+
176
+
177
def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
):
    """Return the LR multiplier: linear warmup, then `num_cycles` half-cosine decays that each restart at 1.0."""
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    decay_span = float(max(1, num_training_steps - num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / decay_span
    if not progress < 1.0:
        # Past the planned end of training the multiplier stays at zero.
        return 0.0

    # The fractional part of (cycles * progress) is the position inside the current cycle.
    position_in_cycle = (float(num_cycles) * progress) % 1.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * position_in_cycle)))
186
+
187
+
188
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Build a scheduler whose learning rate rises linearly from 0 to the optimizer's initial lr during warmup and then
    follows a cosine decay from the initial lr to 0 that is hard-restarted several times.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            Number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
    }
    restart_multiplier = partial(_get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, restart_multiplier, last_epoch)
219
+
220
+
221
def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float,
    power: float,
    lr_init: float,
):
    """
    Return the LR multiplier for a polynomial decay with linear warmup.

    Args:
        current_step (`int`):
            The step the multiplier is computed for.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        lr_end (`float`):
            Learning rate reached at the end of the decay.
        power (`float`):
            Exponent of the polynomial.
        lr_init (`float`):
            Initial learning rate; the result is expressed as a ratio of it.

    Return:
        `float`: the factor by which the initial learning rate is multiplied at `current_step`.
    """
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))

    decay_steps = num_training_steps - num_warmup_steps
    # `decay_steps <= 0` can only be reached here when warmup covers the whole run and we are exactly at the last
    # step; there is no decay phase to interpolate over, so report the final learning rate instead of dividing by 0.
    if current_step > num_training_steps or decay_steps <= 0:
        return lr_end / lr_init  # as LambdaLR multiplies by lr_init

    lr_range = lr_init - lr_end
    pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
    decay = lr_range * pct_remaining**power + lr_end
    return decay / lr_init  # as LambdaLR multiplies by lr_init
240
+
241
+
242
def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Build a scheduler whose learning rate rises linearly from 0 to the optimizer's initial lr during warmup and then
    decays polynomially from the initial lr down to *lr_end*.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """
    # The decay is defined relative to the optimizer's default lr, so the end value must lie strictly below it.
    lr_init = optimizer.defaults["lr"]
    if not lr_init > lr_end:
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "lr_end": lr_end,
        "power": power,
        "lr_init": lr_init,
    }
    polynomial_multiplier = partial(_get_polynomial_decay_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, polynomial_multiplier, last_epoch)
286
+
287
+
288
+ def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int | None = None):
289
+ if current_step < num_warmup_steps:
290
+ return float(current_step) / float(max(1, num_warmup_steps))
291
+ shift = timescale - num_warmup_steps
292
+ decay = 1.0 / math.sqrt((current_step + shift) / timescale)
293
+ return decay
294
+
295
+
296
def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: int | None = None, last_epoch: int = -1
):
    """
    Build a scheduler whose learning rate rises linearly from 0 to the optimizer's initial lr during warmup and then
    decays with the inverse square root of the step.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
            Time scale.
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Note: this implementation is adapted from
    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930

    # Without an explicit timescale use the warmup length, or 10,000 when there is no warmup at all.
    resolved_timescale = timescale if timescale is not None else (num_warmup_steps or 10_000)

    inverse_sqrt_multiplier = partial(
        _get_inverse_sqrt_schedule_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        timescale=resolved_timescale,
    )
    return LambdaLR(optimizer, inverse_sqrt_multiplier, last_epoch=last_epoch)
324
+
325
+
326
+ def _get_cosine_schedule_with_warmup_lr_lambda(
327
+ current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
328
+ ):
329
+ if current_step < num_warmup_steps:
330
+ return float(current_step) / float(max(1, num_warmup_steps))
331
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
332
+ factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
333
+ factor = factor * (1 - min_lr_rate) + min_lr_rate
334
+ return max(0, factor)
335
+
336
+
337
+ def get_cosine_with_min_lr_schedule_with_warmup(
338
+ optimizer: Optimizer,
339
+ num_warmup_steps: int,
340
+ num_training_steps: int,
341
+ num_cycles: float = 0.5,
342
+ last_epoch: int = -1,
343
+ min_lr: float | None = None,
344
+ min_lr_rate: float | None = None,
345
+ ):
346
+ """
347
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
348
+ initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
349
+ initial lr set in the optimizer.
350
+
351
+ Args:
352
+ optimizer ([`~torch.optim.Optimizer`]):
353
+ The optimizer for which to schedule the learning rate.
354
+ num_warmup_steps (`int`):
355
+ The number of steps for the warmup phase.
356
+ num_training_steps (`int`):
357
+ The total number of training steps.
358
+ num_cycles (`float`, *optional*, defaults to 0.5):
359
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
360
+ following a half-cosine).
361
+ last_epoch (`int`, *optional*, defaults to -1):
362
+ The index of the last epoch when resuming training.
363
+ min_lr (`float`, *optional*):
364
+ The minimum learning rate to reach after the cosine schedule.
365
+ min_lr_rate (`float`, *optional*):
366
+ The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
367
+
368
+ Return:
369
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
370
+ """
371
+
372
+ if min_lr is not None and min_lr_rate is not None:
373
+ raise ValueError("Only one of min_lr or min_lr_rate should be set")
374
+ elif min_lr is not None:
375
+ min_lr_rate = min_lr / optimizer.defaults["lr"]
376
+ elif min_lr_rate is None:
377
+ raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
378
+
379
+ lr_lambda = partial(
380
+ _get_cosine_schedule_with_warmup_lr_lambda,
381
+ num_warmup_steps=num_warmup_steps,
382
+ num_training_steps=num_training_steps,
383
+ num_cycles=num_cycles,
384
+ min_lr_rate=min_lr_rate,
385
+ )
386
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
387
+
388
+
389
+ def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
390
+ current_step: int,
391
+ *,
392
+ num_warmup_steps: int,
393
+ num_training_steps: int,
394
+ num_cycles: float,
395
+ min_lr_rate: float = 0.0,
396
+ warmup_lr_rate: float | None = None,
397
+ ):
398
+ current_step = float(current_step)
399
+ num_warmup_steps = float(num_warmup_steps)
400
+ num_training_steps = float(num_training_steps)
401
+
402
+ if current_step < num_warmup_steps:
403
+ if warmup_lr_rate is None:
404
+ return (current_step + 1.0) / max(1.0, num_warmup_steps)
405
+ else:
406
+ warmup_lr_rate = float(warmup_lr_rate)
407
+ return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1))
408
+ progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps))
409
+ factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
410
+ factor = factor * (1 - min_lr_rate) + min_lr_rate
411
+ return max(0, factor)
412
+
413
+
414
def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    min_lr: float | None = None,
    min_lr_rate: float | None = None,
    warmup_lr_rate: float | None = None,
):
    """
    Build a scheduler whose learning rate rises linearly to the optimizer's initial lr during warmup and then follows
    a cosine curve from the initial lr down to a minimum learning rate.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_training_steps (`int`):
            Total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine waves in the schedule (the default is a single half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.
        min_lr (`float`, *optional*):
            The minimum learning rate to reach after the cosine schedule.
        min_lr_rate (`float`, *optional*):
            The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
        warmup_lr_rate (`float`, *optional*):
            The learning rate at the start of warmup, as a ratio of the initial learning rate. If not set, the warmup
            starts at `1 / num_warmup_steps`.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Exactly one of `min_lr` / `min_lr_rate` must be given; normalise both spellings to a ratio.
    if min_lr is not None:
        if min_lr_rate is not None:
            raise ValueError("Only one of min_lr or min_lr_rate should be set")
        min_lr_rate = min_lr / optimizer.defaults["lr"]
    elif min_lr_rate is None:
        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")

    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
        "min_lr_rate": min_lr_rate,
        "warmup_lr_rate": warmup_lr_rate,
    }
    cosine_multiplier = partial(_get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, cosine_multiplier, last_epoch)
468
+
469
+
470
def _get_wsd_scheduler_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    warmup_type: str,
    decay_type: str,
    min_lr_ratio: float,
    num_cycles: float,
):
    """
    Return the LR multiplier for a warmup / stable / decay schedule.

    Args:
        current_step (`int`):
            The step the multiplier is computed for.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_stable_steps (`int`):
            Length of the constant phase, in steps.
        num_decay_steps (`int`):
            Length of the decay phase, in steps.
        warmup_type (`str`):
            Shape of the warmup: 'linear', 'cosine' or '1-sqrt'.
        decay_type (`str`):
            Shape of the decay: 'linear', 'cosine' or '1-sqrt'.
        min_lr_ratio (`float`):
            Floor of the multiplier, as a ratio of the initial learning rate.
        num_cycles (`float`):
            Number of cosine waves used by the 'cosine' decay.

    Return:
        `float`: the factor by which the initial learning rate is multiplied at `current_step`.

    Raises:
        `ValueError`: if the phase being evaluated uses an unknown `warmup_type` / `decay_type`. The type is only
        checked in the phase where it is used.
    """
    if current_step < num_warmup_steps:
        progress = float(current_step) / float(max(1, num_warmup_steps))
        if warmup_type == "linear":
            factor = progress
        elif warmup_type == "cosine":
            factor = 0.5 * (1.0 - math.cos(math.pi * progress))
        elif warmup_type == "1-sqrt":
            factor = 1.0 - math.sqrt(1.0 - progress)
        else:
            # Fail with a clear message instead of an UnboundLocalError on `factor`.
            raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
        # Rescale the [0, 1] shape into [min_lr_ratio, 1].
        factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
        return max(0.0, factor)

    if current_step < num_warmup_steps + num_stable_steps:
        return 1.0

    if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
        progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
        if decay_type == "linear":
            factor = 1.0 - progress
        elif decay_type == "cosine":
            factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        elif decay_type == "1-sqrt":
            factor = 1.0 - math.sqrt(progress)
        else:
            raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
        factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
        return max(0.0, factor)

    # Past the end of the decay phase the multiplier stays at its floor.
    return min_lr_ratio
506
+
507
+
508
def get_wsd_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_decay_steps: int,
    num_training_steps: int | None = None,
    num_stable_steps: int | None = None,
    warmup_type: str = "linear",
    decay_type: str = "cosine",
    min_lr_ratio: float = 0,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Build a warmup-stable-decay scheduler. The learning rate goes through three stages:
    1. warmup: rise from min_lr_ratio times the initial learning rate to the initial learning rate, shaped by
       warmup_type.
    2. stable: constant learning rate.
    3. decay: fall from the initial learning rate to min_lr_ratio times the initial learning rate, shaped by
       decay_type.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (`int`):
            Length of the warmup phase, in steps.
        num_decay_steps (`int`):
            Length of the decay phase, in steps.
        num_training_steps (`int`, *optional*):
            Total number of training steps, i.e. the sum of the warmup, stable and decay steps. When
            `num_stable_steps` is not given, the stable phase lasts
            `num_training_steps - num_warmup_steps - num_decay_steps` steps.
        num_stable_steps (`int`, *optional*):
            Length of the stable phase, in steps. Takes precedence over `num_training_steps` when both are given.
            Steps beyond `num_warmup_steps + num_stable_steps + num_decay_steps` use the minimum learning rate.
        warmup_type (`str`, *optional*, defaults to "linear"):
            The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
        decay_type (`str`, *optional*, defaults to "cosine"):
            The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
        min_lr_ratio (`float`, *optional*, defaults to 0):
            The minimum learning rate as a ratio of the initial learning rate.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine waves used by the 'cosine' decay (the default is a single half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    supported_shapes = ("linear", "cosine", "1-sqrt")
    has_total = num_training_steps is not None
    has_stable = num_stable_steps is not None

    if not has_total and not has_stable:
        raise ValueError("Either num_training_steps or num_stable_steps must be specified.")

    if has_total and has_stable:
        warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")

    if warmup_type not in supported_shapes:
        raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")

    if decay_type not in supported_shapes:
        raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")

    if not has_stable:
        # Whatever is left of the run after warmup and decay is the stable phase.
        num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps

    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_stable_steps": num_stable_steps,
        "num_decay_steps": num_decay_steps,
        "warmup_type": warmup_type,
        "decay_type": decay_type,
        "min_lr_ratio": min_lr_ratio,
        "num_cycles": num_cycles,
    }
    wsd_multiplier = partial(_get_wsd_scheduler_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, wsd_multiplier, last_epoch)
579
+
580
+
581
class StreamingAverage:
    """Rolling window average for smoothing metric values.

    Maintains a sliding window of values and computes their average,
    useful for smoothing noisy metric values before making learning rate decisions.

    Args:
        window_size (`int`):
            The maximum number of values to keep in the rolling window. Must be at least 1.

    Raises:
        `ValueError`: if `window_size` is smaller than 1.
    """

    def __init__(self, window_size: int) -> None:
        # An empty window would make the very first average divide by zero, so reject it up front.
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size: int = window_size
        self.values: list[float] = []
        self.sum: float = 0.0

    def streamavg(self, value: float) -> float:
        """Add a value and return the current rolling average."""
        self.values.append(value)
        self.sum += value

        # Evict until the window fits: after `load_state_dict` (or a manual change of `window_size`) the stored
        # history can exceed the window by more than one element.
        while len(self.values) > self.window_size:
            removed = self.values.pop(0)
            self.sum -= removed

        return self.sum / len(self.values)

    def state_dict(self) -> dict[str, Any]:
        """Return a snapshot of the window; the value list is copied so later updates do not leak into it."""
        return {
            "window_size": self.window_size,
            "values": self.values.copy(),
            "sum": self.sum,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the window from `state_dict`; a missing `window_size` keeps the current one."""
        self.window_size = state_dict.get("window_size", self.window_size)
        self.values = state_dict.get("values", []).copy()
        self.sum = state_dict.get("sum", 0.0)
619
+
620
+
621
+ class GreedyLR:
622
+ """Adaptive learning rate scheduler that responds to training metrics.
623
+
624
+ GreedyLR dynamically adjusts the learning rate based on training performance:
625
+ - Increases LR when metrics improve consistently (divides by factor)
626
+ - Decreases LR when metrics plateau (multiplies by factor)
627
+
628
+ This differs from traditional schedulers like cosine annealing by responding
629
+ to actual training dynamics rather than following a predetermined schedule.
630
+
631
+ Reference: `GreedyLR: A Novel Adaptive Learning Rate Scheduler <https://arxiv.org/abs/2512.14527>`_
632
+
633
+ Args:
634
+ optimizer ([`~torch.optim.Optimizer`]):
635
+ The optimizer for which to schedule the learning rate.
636
+ mode (`str`, *optional*, defaults to `"min"`):
637
+ One of 'min' or 'max'. In 'min' mode, LR will be reduced when the
638
+ metric has stopped decreasing; in 'max' mode when it has stopped increasing.
639
+ factor (`float`, *optional*, defaults to 0.95):
640
+ Factor by which the learning rate will be adjusted. LR is multiplied by
641
+ factor on plateau and divided by factor on improvement. Must be < 1.0.
642
+ patience (`int`, *optional*, defaults to 10):
643
+ Number of epochs with no improvement after which learning rate will be adjusted.
644
+ threshold (`float`, *optional*, defaults to 1e-06):
645
+ Threshold for measuring the new optimum.
646
+ threshold_mode (`str`, *optional*, defaults to `"abs"`):
647
+ One of 'rel' or 'abs'.
648
+ cooldown (`int`, *optional*, defaults to 0):
649
+ Number of epochs to wait before resuming normal operation after LR has been reduced.
650
+ warmup (`int`, *optional*, defaults to 0):
651
+ Number of epochs to wait before resuming normal operation after LR has been increased.
652
+ min_lr (`float` or `list[float]`, *optional*, defaults to 0.001):
653
+ A lower bound on the learning rate.
654
+ max_lr (`float` or `list[float]`, *optional*, defaults to 1.0):
655
+ An upper bound on the learning rate.
656
+ eps (`float`, *optional*, defaults to 1e-08):
657
+ Minimal decay applied to lr.
658
+ verbose (`bool`, *optional*, defaults to `False`):
659
+ If True, prints a message to stdout for each update.
660
+ smooth (`bool`, *optional*, defaults to `False`):
661
+ If True, applies streaming average smoothing to metrics.
662
+ window_size (`int`, *optional*, defaults to 50):
663
+ The window size for the streaming average when smooth=True.
664
+ reset_start (`int`, *optional*, defaults to 500):
665
+ Number of steps to wait at min_lr before resetting to initial state.
666
+
667
+ Example:
668
+ ```python
669
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
670
+ >>> scheduler = GreedyLR(optimizer, mode="min", patience=10)
671
+ >>> for epoch in range(100):
672
+ ... train(...)
673
+ ... val_loss = validate(...)
674
+ ... scheduler.step(val_loss)
675
+ ```
676
+ """
677
+
678
+ def __init__(
679
+ self,
680
+ optimizer: Optimizer,
681
+ mode: str = "min",
682
+ factor: float = 0.95,
683
+ patience: int = 10,
684
+ threshold: float = 1e-6,
685
+ threshold_mode: str = "abs",
686
+ cooldown: int = 0,
687
+ warmup: int = 0,
688
+ min_lr: float | list[float] = 1e-3,
689
+ max_lr: float | list[float] = 1.0,
690
+ eps: float = 1e-8,
691
+ verbose: bool = False,
692
+ smooth: bool = False,
693
+ window_size: int = 50,
694
+ reset_start: int = 500,
695
+ ) -> None:
696
+ if factor >= 1.0:
697
+ raise ValueError("Factor should be < 1.0.")
698
+ if not isinstance(optimizer, Optimizer):
699
+ raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
700
+
701
+ self.optimizer = optimizer
702
+ self.factor = factor
703
+ self.patience = patience
704
+ self.verbose = verbose
705
+ self.cooldown = cooldown
706
+ self.warmup = warmup
707
+ self.cooldown_counter = 0
708
+ self.warmup_counter = 0
709
+ self.mode = mode
710
+ self.threshold = threshold
711
+ self.threshold_mode = threshold_mode
712
+ self.eps = eps
713
+ self.smooth = smooth
714
+ self.window_size = window_size
715
+ self.reset_start = reset_start
716
+ self.reset_start_original = reset_start
717
+ self.last_epoch = 0
718
+
719
+ if isinstance(min_lr, (list, tuple)):
720
+ if len(min_lr) != len(optimizer.param_groups):
721
+ raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
722
+ self.min_lrs = list(min_lr)
723
+ else:
724
+ self.min_lrs = [min_lr] * len(optimizer.param_groups)
725
+
726
+ if isinstance(max_lr, (list, tuple)):
727
+ if len(max_lr) != len(optimizer.param_groups):
728
+ raise ValueError(f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}")
729
+ self.max_lrs = list(max_lr)
730
+ else:
731
+ self.max_lrs = [max_lr] * len(optimizer.param_groups)
732
+
733
+ self._init_lrs = [group["lr"] for group in optimizer.param_groups]
734
+ self._last_lr = self._init_lrs.copy()
735
+
736
+ self.best: float = float("inf") if mode == "min" else float("-inf")
737
+ self.num_bad_epochs = 0
738
+ self.num_good_epochs = 0
739
+
740
+ if mode not in ("min", "max"):
741
+ raise ValueError(f"mode {mode} is unknown!")
742
+ if threshold_mode not in ("rel", "abs"):
743
+ raise ValueError(f"threshold mode {threshold_mode} is unknown!")
744
+
745
+ self._streaming_avg: StreamingAverage | None = None
746
+ if smooth:
747
+ self._streaming_avg = StreamingAverage(window_size)
748
+
749
+ def step(self, metrics: float, epoch: int | None = None) -> None:
750
+ """Perform a scheduler step based on the given metrics.
751
+
752
+ Args:
753
+ metrics (`float`):
754
+ The metric value to use for LR adjustment decisions.
755
+ epoch (`int`, *optional*):
756
+ The current epoch number. If None, uses internal counter.
757
+ """
758
+ current = float(metrics)
759
+
760
+ if self.smooth and self._streaming_avg is not None:
761
+ current = self._streaming_avg.streamavg(current)
762
+
763
+ if epoch is None:
764
+ epoch = self.last_epoch + 1
765
+ self.last_epoch = epoch
766
+
767
+ if self.cooldown_counter > 0:
768
+ self.cooldown_counter -= 1
769
+ self.num_bad_epochs = 0
770
+ self.num_good_epochs = 0
771
+ elif self.warmup_counter > 0:
772
+ self.warmup_counter -= 1
773
+ self.num_bad_epochs = 0
774
+ self.num_good_epochs = 0
775
+ else:
776
+ if self.is_better(current, self.best):
777
+ self.best = current
778
+ self.num_bad_epochs = 0
779
+ self.num_good_epochs += 1
780
+ else:
781
+ self.num_bad_epochs += 1
782
+ self.num_good_epochs = 0
783
+
784
+ if self.num_good_epochs > self.patience:
785
+ self._increase_lr(epoch)
786
+ self.warmup_counter = self.warmup
787
+ self.num_good_epochs = 0
788
+ elif self.num_bad_epochs > self.patience:
789
+ self._reduce_lr(epoch)
790
+ self.cooldown_counter = self.cooldown
791
+ self.num_bad_epochs = 0
792
+
793
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
794
+
795
+ def is_better(self, current: float, best: float) -> bool:
796
+ if self.mode == "min":
797
+ if self.threshold_mode == "rel":
798
+ return current < best * (1.0 - self.threshold)
799
+ else:
800
+ return current < best - self.threshold
801
+ else:
802
+ if self.threshold_mode == "rel":
803
+ return current > best * (1.0 + self.threshold)
804
+ else:
805
+ return current > best + self.threshold
806
+
807
+ def _reduce_lr(self, epoch: int) -> None:
808
+ all_at_min = True
809
+ for i, param_group in enumerate(self.optimizer.param_groups):
810
+ old_lr = float(param_group["lr"])
811
+ new_lr = max(old_lr * self.factor, self.min_lrs[i])
812
+
813
+ if old_lr - new_lr > self.eps:
814
+ param_group["lr"] = new_lr
815
+ if self.verbose:
816
+ print(f"Epoch {epoch}: reducing learning rate of group {i} to {new_lr:.4e}.")
817
+
818
+ if param_group["lr"] > self.min_lrs[i]:
819
+ all_at_min = False
820
+
821
+ if all_at_min:
822
+ self.reset_start -= 1
823
+ if self.reset_start <= 0:
824
+ self._reset()
825
+
826
+ def _increase_lr(self, epoch: int) -> None:
827
+ for i, param_group in enumerate(self.optimizer.param_groups):
828
+ old_lr = float(param_group["lr"])
829
+ new_lr = min(old_lr / self.factor, self.max_lrs[i])
830
+
831
+ if new_lr - old_lr > self.eps:
832
+ param_group["lr"] = new_lr
833
+ if self.verbose:
834
+ print(f"Epoch {epoch}: increasing learning rate of group {i} to {new_lr:.4e}.")
835
+
836
+ self.reset_start = self.reset_start_original
837
+
838
+ def _reset(self) -> None:
839
+ for i, param_group in enumerate(self.optimizer.param_groups):
840
+ param_group["lr"] = self._init_lrs[i]
841
+
842
+ self.best = float("inf") if self.mode == "min" else float("-inf")
843
+ self.num_bad_epochs = 0
844
+ self.num_good_epochs = 0
845
+ self.cooldown_counter = 0
846
+ self.warmup_counter = 0
847
+ self.reset_start = self.reset_start_original
848
+
849
+ if self.smooth and self._streaming_avg is not None:
850
+ self._streaming_avg = StreamingAverage(self.window_size)
851
+
852
+ if self.verbose:
853
+ print("Scheduler reset to initial state.")
854
+
855
+ def get_last_lr(self) -> list[float]:
856
+ """Return last computed learning rate by current scheduler."""
857
+ return self._last_lr
858
+
859
def state_dict(self) -> dict[str, Any]:
    """
    Snapshot the scheduler as a plain dictionary.

    The dictionary holds the hyper-parameters and running counters listed below, keyed by attribute name.
    When smoothing is enabled and a streaming average exists, its own state is stored under
    `"_streaming_avg"`.
    """
    saved_attributes = (
        "factor",
        "min_lrs",
        "max_lrs",
        "patience",
        "verbose",
        "cooldown",
        "warmup",
        "cooldown_counter",
        "warmup_counter",
        "mode",
        "threshold",
        "threshold_mode",
        "best",
        "num_bad_epochs",
        "num_good_epochs",
        "eps",
        "last_epoch",
        "smooth",
        "window_size",
        "reset_start",
        "reset_start_original",
        "_last_lr",
        "_init_lrs",
    )
    state = {attr: getattr(self, attr) for attr in saved_attributes}

    if self.smooth and self._streaming_avg is not None:
        state["_streaming_avg"] = self._streaming_avg.state_dict()

    return state
891
+
892
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """
    Restore the scheduler from a dictionary such as the one built by `state_dict()`.

    Keys missing from `state_dict` leave the corresponding attribute at its current value. If the
    dictionary carries `"_last_lr"`, those learning rates are also written back into the optimizer's
    parameter groups.
    """
    restorable_attributes = (
        "factor",
        "min_lrs",
        "max_lrs",
        "patience",
        "verbose",
        "cooldown",
        "warmup",
        "cooldown_counter",
        "warmup_counter",
        "mode",
        "threshold",
        "threshold_mode",
        "best",
        "num_bad_epochs",
        "num_good_epochs",
        "eps",
        "last_epoch",
        "smooth",
        "window_size",
        "reset_start",
        "reset_start_original",
        "_last_lr",
        "_init_lrs",
    )
    for attr in restorable_attributes:
        # Fall back to the current value so absent keys are a no-op.
        setattr(self, attr, state_dict.get(attr, getattr(self, attr)))

    if "_streaming_avg" in state_dict:
        # Create the smoother lazily (with the window size restored above) before loading into it.
        if self._streaming_avg is None:
            self._streaming_avg = StreamingAverage(self.window_size)
        self._streaming_avg.load_state_dict(state_dict["_streaming_avg"])

    if "_last_lr" in state_dict:
        # Keep the optimizer in sync with the restored learning rates.
        for group, restored_lr in zip(self.optimizer.param_groups, self._last_lr):
            group["lr"] = restored_lr
926
+
927
+
928
def get_greedy_schedule(optimizer: Optimizer, **kwargs):
    """
    Build a [`GreedyLR`] scheduler, which adapts the learning rate from the training metrics it is fed.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer whose learning rate is scheduled.
        kwargs (`dict`, *optional*):
            Forwarded unchanged to the [`GreedyLR`] constructor; see that class for the accepted parameters.

    Return:
        The [`GreedyLR`] instance wrapping `optimizer`.
    """
    scheduler = GreedyLR(optimizer, **kwargs)
    return scheduler
942
+
943
+
944
# Lookup table from each `SchedulerType` member to the factory function that builds that scheduler.
# `get_scheduler` indexes this table with the requested scheduler type.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
    SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
    SchedulerType.GREEDY: get_greedy_schedule,
}
958
+
959
+
960
def get_scheduler(
    name: str | SchedulerType,
    optimizer: Optimizer,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    scheduler_specific_kwargs: dict | None = None,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. Not every scheduler needs it (hence optional); a `ValueError` is
            raised if it is unset and the requested scheduler requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. Not every scheduler needs it (hence optional); a `ValueError` is
            raised if it is unset and the requested scheduler requires it.
        scheduler_specific_kwargs (`dict`, *optional*):
            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
            parameters will cause the scheduler function to raise a TypeError.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # A `LayerWiseDummyOptimizer` wraps one real optimizer per parameter: build one scheduler per
    # parameter by recursing, then step each scheduler from a post-accumulate-grad hook.
    if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
        optimizer_dict = optimizer.optimizer_dict
        scheduler_dict = {
            param: get_scheduler(
                name,
                optimizer=optimizer_dict[param],
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                scheduler_specific_kwargs=scheduler_specific_kwargs,
            )
            for param in optimizer_dict
        }

        def scheduler_hook(param):
            # The optimizer hook is already attached and has zeroed the gradients at this point,
            # so only the scheduler needs stepping here.
            scheduler_dict[param].step()

        for param in optimizer_dict:
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(scheduler_hook)

        return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])

    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    extra_kwargs = {} if scheduler_specific_kwargs is None else scheduler_specific_kwargs

    # Metric-driven schedulers take neither warmup nor total step counts.
    if name in (SchedulerType.REDUCE_ON_PLATEAU, SchedulerType.GREEDY):
        return schedule_func(optimizer, **extra_kwargs)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    if name == SchedulerType.INVERSE_SQRT:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **extra_kwargs)

    # The wsd scheduler may run without `num_training_steps` (it can be given `num_stable_steps`
    # instead); every remaining scheduler requires it.
    if name != SchedulerType.WARMUP_STABLE_DECAY and num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    return schedule_func(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **extra_kwargs,
    )
1055
+
1056
+
1057
class Adafactor(Optimizer):
    """
    PyTorch implementation of AdaFactor that can be used as a drop-in replacement for Adam. Adapted from the
    original fairseq code: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://huggingface.co/papers/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    Raises:
        ValueError: If `lr` is given together with `relative_step=True`, or if `warmup_init=True` is combined
            with `relative_step=False`.

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

        - Training without LR warmup or clip_threshold is not recommended.

           - use scheduled LR warm-up to fixed LR
           - use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235)
        - Disable relative updates
        - Use scale_parameter=False
        - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the effective step size for one parameter, given its group options and optimizer state."""
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            # Time-dependent step: capped by 1/sqrt(step), and ramped up linearly when warming up.
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            # Scale by the parameter's RMS, floored by eps[1].
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return `(factored, use_first_moment)`: factor for >=2-D shapes, first moment when `beta1` is set."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Root mean square of all elements of `tensor`."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Combine the row and column statistics into the inverse-square-root scaling of the gradient."""
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by `closure`, or `None` when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # `step` itself runs under `no_grad`; re-enable autograd so the closure can call
            # `backward()` while re-evaluating the model.
            with torch.enable_grad():
                loss = closure()

        low_precision = {torch.float16, torch.bfloat16}

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in low_precision:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # One statistic per row and one per column instead of a full matrix.
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)

                    state["RMS"] = 0
                else:
                    # Keep existing state on the same device/dtype as the (possibly upcast) gradient.
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                # Half-precision parameters are updated through an fp32 copy and written back below;
                # fp32 parameters are updated in place.
                p_data_fp32 = p
                if p.dtype in low_precision:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip the update so that its RMS does not exceed `clip_threshold`.
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                if p.dtype in low_precision:
                    p.copy_(p_data_fp32)

        return loss
1295
+
1296
+
1297
class AdafactorSchedule(LambdaLR):
    """
    Proxy scheduler for [`~optimization.Adafactor`], which computes its own learning rates.

    Training loops that expect a scheduler (e.g. for logging) can use this object to read the learning rates the
    optimizer is currently applying. Before any gradient is available it reports `initial_lr`; afterwards it reports
    the values computed by the optimizer.
    """

    def __init__(self, optimizer, initial_lr=0.0):
        # Seed every group with `initial_lr` so the base class picks it up as the base learning rate,
        # then remove the temporary key again once the base class is initialised.
        for param_group in optimizer.param_groups:
            param_group["initial_lr"] = initial_lr
        super().__init__(optimizer, lambda _: initial_lr)
        for param_group in optimizer.param_groups:
            del param_group["initial_lr"]

    def get_lr(self):
        optimizer = self.optimizer
        current_lrs = []
        for param_group in optimizer.param_groups:
            # The first parameter of each group stands in for the whole group.
            lead_param = param_group["params"][0]
            if lead_param.grad is not None:
                current_lrs.append(optimizer._get_lr(param_group, optimizer.state[lead_param]))
        # No gradients yet (called before stepping): fall back to the base learning rates.
        return current_lrs if current_lrs else self.base_lrs
1325
+
1326
+
1327
def get_adafactor_schedule(optimizer, initial_lr=0.0):
    """
    Build the proxy schedule used alongside [`~optimization.Adafactor`].

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        initial_lr (`float`, *optional*, defaults to 0.0):
            Learning rate reported before the optimizer has taken a step.

    Return:
        An [`~optimization.AdafactorSchedule`] wrapping `optimizer`.
    """
    proxy_schedule = AdafactorSchedule(optimizer, initial_lr)
    return proxy_schedule
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/tokenization_python.py ADDED
@@ -0,0 +1,1420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
16
+ tokenization_utils_tokenizers.py
17
+ """
18
+
19
+ import bisect
20
+ import unicodedata
21
+ from collections import OrderedDict
22
+ from typing import Any, overload
23
+
24
+ from .tokenization_utils_base import (
25
+ INIT_TOKENIZER_DOCSTRING,
26
+ AddedToken,
27
+ BatchEncoding,
28
+ EncodedInput,
29
+ PreTokenizedInput,
30
+ PreTrainedTokenizerBase,
31
+ TextInput,
32
+ TruncationStrategy,
33
+ )
34
+ from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging
35
+
36
+
37
# Module-level logger, obtained from the project's `logging` helper.
logger = logging.get_logger(__name__)

# Slow tokenizers are saved in a vocabulary plus three separate files
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
43
+
44
+
45
class Trie:
    """
    Character trie built from a collection of words, used to split text on `added_tokens` in a single pass.
    Loose reference https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self, *args):
        self.data = {}
        self._tokens = set()
        self._termination_char = ""
        self.update(*args)

    def update(self, *args):
        """
        Add new tokens to the Trie.

        Args:
            *args: Forwarded to `tuple(...)`, i.e. at most one iterable of words; each element is added in turn.
        """
        words = tuple(*args)
        for word in words:
            self.add(word)

    def add(self, word: str):
        """
        Insert `word` into the internal `data` trie, one character per nesting level. The special key `""`
        (stored in `self._termination_char`) marks the end of a word.

        This function is idempotent, adding twice the same word will leave the trie unchanged. Empty words are
        ignored.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return

        self._tokens.add(word)
        node = self.data
        for char in word:
            # Descend into the child for `char`, creating it on first use.
            node = node.setdefault(char, {})
        node[self._termination_char] = 1

    def split(self, text: str) -> list[str]:
        """
        Look for the words stored in the trie inside `text` and return the original string split along the
        boundaries of the words found.

        The longest possible word is matched first.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]

        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # Offsets are counted to the left of each character: in "hello", 0 is left of "h",
        # 1 is between "h" and "e", and 5 is right of "o".

        # `active` maps every candidate start offset to the trie node reached so far from that
        # start (a partial match). Several starts must be tracked at once: with "blowing" and
        # "lower" in the trie, the text "blower" has to be split into ["b", "lower"].
        active = OrderedDict()

        # Every offset at which the text must be cut. Offset 0 is forced here and len(text) is
        # added later by `cut_text`.
        cuts = [0]

        # Set by the lookahead when a full match extends past the main loop position, so the
        # characters already consumed by that match are skipped.
        resume_at = 0
        text_len = len(text)

        # Main loop over the characters of `text`.
        for pos, char in enumerate(text):
            if resume_at and pos < resume_at:
                # Already consumed by a lookahead match; prevents matching twice
                # (like extra_id_100 and id_100).
                continue

            # Starts whose partial match failed on this character. For "lowball" against "lower",
            # start 0 matches "l", "o", "w" and is dropped on "b".
            dead = set()
            # Greedy algorithm: as soon as one match is found, every tracked start is dropped.
            matched = False

            for begin, node in active.items():
                if "" in node:
                    # A complete word ends here. Before storing it, look ahead for a longer match
                    # (extra_id_1 vs extra_id_100) and give priority to earlier partial matches
                    # ("[CLS]" must win even if "L" is itself a word).
                    for cand_begin, cand_node in active.items():
                        if cand_begin > begin:
                            # This partial match is later, we can stop looking
                            break
                        elif cand_begin < begin:
                            # Earlier partial match: its node was already advanced for the
                            # current character, so indices are shifted by one.
                            probe = pos + 1
                            stop = pos + 1
                        else:
                            # Same start as the match itself: its node was not advanced yet,
                            # so indices are the current ones.
                            probe = pos
                            stop = pos
                        upcoming = text[probe] if probe < text_len else None
                        if "" in cand_node:
                            begin = cand_begin
                            stop = probe
                            resume_at = probe

                        while upcoming in cand_node:
                            cand_node = cand_node[upcoming]
                            probe += 1
                            if "" in cand_node:
                                begin = cand_begin
                                stop = probe
                                resume_at = probe

                            if probe == text_len:
                                # End of string
                                break
                            upcoming = text[probe]
                        # End lookahead

                    # Storing and resetting
                    cuts.append(begin)
                    cuts.append(stop)
                    matched = True
                    break
                elif char in node:
                    # The partial match grows by one character; store the advanced node back.
                    active[begin] = node[char]
                else:
                    # No continuation for this start. It cannot be deleted while iterating,
                    # so remember it for removal afterwards.
                    dead.add(begin)

            # Either drop everything (a real match was found) or only the failed partial matches.
            if matched:
                active = {}
            else:
                for begin in dead:
                    del active[begin]

            # If this character can start a word, begin tracking a new partial match here.
            if pos >= resume_at and char in self.data:
                active[pos] = self.data[char]

        # A tracked start may complete exactly at the end of the text.
        for begin, node in active.items():
            if "" in node:
                cuts.append(begin)
                cuts.append(len(text))
                # The first complete item has the lowest start, hence the longest cut.
                break

        return self.cut_text(text, cuts)

    def cut_text(self, text, offsets):
        """Slice `text` at `offsets` (to which `len(text)` is appended) and return the non-empty pieces."""
        # All offsets are known; add the end of the text so the trailing part is emitted too.
        offsets.append(len(text))
        pieces = []
        previous = 0
        for boundary in offsets:
            if previous > boundary:
                logger.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            if previous == boundary:
                # Happens for a match at index 0; also avoids zero-width cuts between two
                # consecutive matches.
                continue
            pieces.append(text[previous:boundary])
            previous = boundary

        return pieces
274
+
275
+
276
class ExtensionsTrie(Trie):
    """A `Trie` that can additionally list every stored token starting with a given prefix."""

    def extensions(self, prefix: str):
        """
        Generates all extensions of a given prefix token in the Trie.

        Returns an empty list when no stored token starts with `prefix`. An empty prefix returns every
        stored token.

        Example:

        ```python
        >>> trie = ExtensionsTrie()
        >>> trie.add("apple")
        >>> trie.add("app")
        >>> trie.add("application")
        >>> trie.extensions("app")
        ['app', 'apple', 'application']
        >>> trie.extensions("b")
        []
        ```
        """
        prefix_node = self._get_node(prefix)
        suffixes = self._collect_tokens(prefix_node)
        return [prefix + suffix for suffix in suffixes]

    def _get_node(self, token: str) -> dict:
        """
        Retrieves the node corresponding to the given token in the Trie.

        Args:
            token (str): The token for which the corresponding node needs to be retrieved.

        Returns:
            dict: The node in the Trie corresponding to the given token, or a new empty dict when the token
            is not a prefix of anything stored in the Trie.
        """
        node = self.data
        for char in token:
            if char not in node:
                # The path does not exist. Returning the deepest node reached instead would make
                # `extensions` glue unrelated suffixes onto the prefix.
                return {}
            node = node[char]
        return node

    def _collect_tokens(self, node: dict) -> list:
        """
        Generates all tokens in the Trie starting from a given node.

        Args:
            node (dict): The node in the Trie from which tokens need to be generated.

        Returns:
            list: List of tokens generated from the given node.
        """
        # A termination marker on this node means the path so far is itself a complete token.
        tokens = [self._termination_char] if self._termination_char in node else []
        for token, subtrie_head in node.items():
            if token != self._termination_char:
                subtokens = self._collect_tokens(subtrie_head)
                tokens.extend([token + subtoken for subtoken in subtokens])
        return tokens
333
+
334
+
335
+ def _is_whitespace(char):
336
+ """Checks whether `char` is a whitespace character."""
337
+ # \t, \n, and \r are technically control characters but we treat them
338
+ # as whitespace since they are generally considered as such.
339
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
340
+ return True
341
+ cat = unicodedata.category(char)
342
+ if cat == "Zs":
343
+ return True
344
+ return False
345
+
346
+
347
+ def _is_control(char):
348
+ """Checks whether `char` is a control character."""
349
+ # These are technically control characters but we count them as whitespace
350
+ # characters.
351
+ if char == "\t" or char == "\n" or char == "\r":
352
+ return False
353
+ cat = unicodedata.category(char)
354
+ if cat.startswith("C"):
355
+ return True
356
+ return False
357
+
358
+
359
+ def _is_punctuation(char):
360
+ """Checks whether `char` is a punctuation character."""
361
+ cp = ord(char)
362
+ # We treat all non-letter/number ASCII as punctuation.
363
+ # Characters such as "^", "$", and "`" are not in the Unicode
364
+ # Punctuation class but we treat them as punctuation anyways, for
365
+ # consistency.
366
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
367
+ return True
368
+ cat = unicodedata.category(char)
369
+ if cat.startswith("P"):
370
+ return True
371
+ return False
372
+
373
+
374
def _is_end_of_word(text):
    """Tells whether `text` ends with a control, punctuation or whitespace character."""
    last_char = text[-1]
    return any(check(last_char) for check in (_is_control, _is_punctuation, _is_whitespace))
378
+
379
+
380
def _is_start_of_word(text):
    """Tells whether `text` starts with a control, punctuation or whitespace character."""
    first_char = text[0]
    return any(check(first_char) for check in (_is_control, _is_punctuation, _is_whitespace))
384
+
385
+
386
+ def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str):
387
+ """
388
+ Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
389
+ """
390
+ insertion_idx = bisect.bisect_left(token_list, new_token)
391
+ # Checks if new_token is already in the ordered token_list
392
+ if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
393
+ # new_token is in token_list, don't add
394
+ return
395
+ else:
396
+ token_list.insert(insertion_idx, new_token)
397
+
398
+
399
+ @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
400
+ class PythonBackend(PreTrainedTokenizerBase):
401
+ """
402
+ Base class for all slow tokenizers.
403
+
404
+ Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
405
+
406
+ Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading
407
+ pretrained tokenizers as well as adding tokens to the vocabulary.
408
+
409
+ This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the
410
+ specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
411
+ """
412
+
413
def __init__(self, **kwargs):
    """
    Sets up the added-token bookkeeping, reads the backend-level options from `kwargs` and forwards the
    remaining keyword arguments to the parent initializer.

    Keyword arguments consumed here: `added_tokens_decoder`, `token_type_ids_pattern`,
    `token_type_ids_include_special_tokens` and `special_tokens_pattern`. `backend` is set to `"custom"` when
    it is not provided.
    """
    # Trie holding the added tokens; `tokenize` uses it to split the input text.
    self.tokens_trie = Trie()

    # Initialize total_vocab_size early to avoid issues if get_vocab() is called early (custom tokenizers)
    self.total_vocab_size = 0

    # A child class may already have created `_added_tokens_decoder`; otherwise start from an empty mapping.
    if not hasattr(self, "_added_tokens_decoder"):
        self._added_tokens_decoder: dict[int, AddedToken] = {}

    # An `added_tokens_decoder` argument means we are loading from a saved tokenizer: its entries overwrite
    # the existing ones. The reverse (content -> index) mapping is then derived from the result.
    saved_decoder = kwargs.pop("added_tokens_decoder", {})
    self._added_tokens_decoder.update(saved_decoder)
    self._added_tokens_encoder: dict[str, int] = {
        added_token.content: index for index, added_token in self._added_tokens_decoder.items()
    }

    # Token type ID configuration for dynamic mask building ("all_zeros" or "bert_style").
    # These can be overridden by subclasses to avoid overriding create_token_type_ids_from_sequences
    self.token_type_ids_pattern = kwargs.pop("token_type_ids_pattern", "bert_style")
    self.token_type_ids_include_special_tokens = kwargs.pop("token_type_ids_include_special_tokens", True)

    # Special tokens mask configuration
    # Patterns: "none", "cls_sep", "eos", "bos", "bos_eos", "cls_double_sep", "prefix_suffix"
    self.special_tokens_pattern = kwargs.pop("special_tokens_pattern", None)

    # Set backend to "custom" if not already set (for direct PreTrainedTokenizer subclasses)
    kwargs.setdefault("backend", "custom")

    super().__init__(**kwargs)

    # If some of the special tokens are not part of the vocab, we add them, at the end.
    # V5: the order of addition follows self.SPECIAL_TOKENS_ATTRIBUTES, then extra special tokens
    # Note: _add_tokens will automatically skip tokens that are already in the base vocab
    missing_special_tokens = [token for token in self.all_special_tokens if token not in self._added_tokens_encoder]
    self._add_tokens(missing_special_tokens, special_tokens=True)
452
+
453
@property
def is_fast(self) -> bool:
    """Always `False` for this backend."""
    return False
456
+
457
@property
def added_tokens_encoder(self) -> dict[str, int]:
    """
    Returns the mapping from added-token string to index, ordered by increasing index.

    The mapping is rebuilt from `self._added_tokens_decoder` on every access; slow tokenizers keep a cached
    version in `self._added_tokens_encoder` for performance.
    """
    decoder = self._added_tokens_decoder
    return {decoder[index].content: index for index in sorted(decoder)}
464
+
465
@property
def added_tokens_decoder(self) -> dict[int, AddedToken]:
    """
    Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.

    Returns:
        `dict[int, AddedToken]`: A new dictionary with the added tokens, ordered by increasing index.
    """
    return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0]))

@added_tokens_decoder.setter
def added_tokens_decoder(self, value: dict[int, AddedToken | str]) -> None:
    """
    Merges `value` into the added tokens and refreshes the reverse mapping and the total vocabulary size.

    Args:
        value (`dict[int, AddedToken | str]`):
            Mapping from index to token. Plain strings are wrapped in an `AddedToken`.

    Raises:
        `TypeError`: If a key is not an `int` or a value is neither a `str` nor an `AddedToken`. In that case
        nothing is modified.
    """
    # Validate every entry before touching any state, so that an invalid entry cannot leave the decoder,
    # the encoder and the cached vocabulary size out of sync with each other.
    for index, token in value.items():
        if not isinstance(token, (str, AddedToken)) or not isinstance(index, int):
            raise TypeError(
                f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, AddedToken | str}"
            )

    for index, token in value.items():
        self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
        self._added_tokens_encoder[str(token)] = index
    self._update_total_vocab_size()
487
+
488
def get_added_vocab(self) -> dict[str, int]:
    """
    Returns the added tokens in the vocabulary as a dictionary of token to index.

    The returned object is the internal `self._added_tokens_encoder` dictionary itself, not a copy. Results
    might differ from the fast tokenizers, because for now tokens are always added even when they are already
    part of the vocabulary.

    Returns:
        `dict[str, int]`: The added tokens.
    """
    added_vocab = self._added_tokens_encoder
    return added_vocab
498
+
499
+ def __len__(self):
500
+ """
501
+ Size of the full vocabulary with the added tokens.
502
+ """
503
+ # Lazy evaluation: compute if not already set (e.g., during initialization)
504
+ if self.total_vocab_size == 0:
505
+ self._update_total_vocab_size()
506
+ return self.total_vocab_size
507
+
508
+ def _update_total_vocab_size(self):
509
+ """
510
+ Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
511
+ otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and
512
+ is only updated when adding tokens.
513
+ """
514
+ self.total_vocab_size = len(self.get_vocab())
515
+
516
def _add_tokens(self, new_tokens: list[str] | list[AddedToken], special_tokens: bool = False) -> int:
    """
    Add a list of new tokens to the tokenizer class. Tokens whose content is not in the vocabulary receive
    new indices starting from the length of the current vocabulary. Tokens whose content is already in the
    vocabulary keep their index and are only registered as `AddedToken`, which allows to control their
    stripping and normalization.

    Args:
        new_tokens (`list[str]`or `list[tokenizers.AddedToken]`):
            Token(s) to add in vocabulary. Empty tokens are ignored, and so are plain strings that are
            already registered as added tokens.
        special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the tokens should be added as special tokens.

    Returns:
        `int`: The number of tokens actually added to the vocabulary.

    Raises:
        `TypeError`: If an element of `new_tokens` is neither a `str` nor an `AddedToken`.

    Examples:

    ```python
    # Let's see how to increase the vocabulary of Bert model and tokenizer
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = BertModel.from_pretrained("google-bert/bert-base-uncased")

    num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
    print("We have added", num_added_toks, "tokens")
    # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
    model.resize_token_embeddings(len(tokenizer))
    ```"""
    num_added = 0
    if new_tokens is None:
        return num_added

    # TODO this is fairly slow to improve!
    vocab = self.get_vocab().copy()
    first_free_index = len(vocab)  # only computed once, len gives the last index + 1
    for token in new_tokens:
        if not isinstance(token, (str, AddedToken)):
            raise TypeError(f"Token {token} is not a string but a {type(token)}.")
        if str(token) == "":
            continue

        if isinstance(token, str):
            if token in self._added_tokens_encoder:
                continue
            # very important for fast and slow equivalence!
            is_special = token in self.all_special_tokens or special_tokens
            token = AddedToken(token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special)
        elif special_tokens:
            # doing token.special=True changes the normalization! will fix in rust
            # this is important and the only reason why the AddedTokens in each class are normalized by default
            token.__setstate__({"special": True, "normalized": token.normalized})

        if token in self._added_tokens_decoder:
            continue

        if not token.special and token.normalized and getattr(self, "do_lower_case", False):
            # Normalize if requested
            token.content = token.content.lower()

        if token.content in vocab:
            # Known content: reuse its index, nothing is counted as added.
            token_index = vocab[token.content]
        else:
            token_index = first_free_index + num_added
            vocab[token.content] = token_index
            num_added += 1

        if token.special and str(token) not in self.all_special_tokens:
            self._extra_special_tokens.append(token)
        # Both directions (index -> token and content -> index) are written here.
        self._added_tokens_decoder[token_index] = token
        self._added_tokens_encoder[token.content] = token_index
        if self.verbose:
            logger.info(f"Adding {token} to the vocabulary")

    self._update_trie()
    self._update_total_vocab_size()
    return num_added
593
+
594
+ def _update_trie(self, unique_no_split_tokens: list[str] | None = None):
595
+ for token in self._added_tokens_decoder.values():
596
+ if token.content not in self.tokens_trie._tokens:
597
+ self.tokens_trie.add(token.content)
598
+ for token in unique_no_split_tokens or []:
599
+ if token not in self.tokens_trie._tokens:
600
+ self.tokens_trie.add(token)
601
+
602
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    """
    Returns the number of added tokens when encoding a sequence with special tokens.

    <Tip>

    The count is obtained by building the inputs for empty dummy sequences, so it is not efficient. Do not
    put this inside your training loop.

    </Tip>

    Args:
        pair (`bool`, *optional*, defaults to `False`):
            Whether the number of added tokens should be computed in the case of a sequence pair or a single
            sequence.

    Returns:
        `int`: Number of special tokens added to sequences.
    """
    second_sequence = [] if pair else None
    with_special_tokens = self.build_inputs_with_special_tokens([], second_sequence)
    return len(with_special_tokens)
624
+
625
def tokenize(self, text: TextInput, **kwargs) -> list[str]:
    """
    Converts a string into a sequence of tokens, using the tokenizer.

    Unless `split_special_tokens` is set, the text is first split on the added tokens, which are kept as they
    are; the remaining pieces are tokenized with `_tokenize`.

    Args:
        text: The sequence to be encoded.
        **kwargs: Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
            `split_special_tokens` is read from here and defaults to `self.split_special_tokens`.

    Returns:
        The list of tokens.
    """
    split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
    text, kwargs = self.prepare_for_tokenization(text, **kwargs)

    if split_special_tokens:
        # Don't split on any tokens - just tokenize directly
        return self._tokenize(text)

    # Split on added tokens
    pieces = self.tokens_trie.split(text)
    added_vocab = self._added_tokens_encoder.keys()
    last_position = len(pieces) - 1

    # Handle added token properties (lstrip, rstrip, single_word)
    for i, piece in enumerate(pieces):
        if piece not in added_vocab:
            continue
        tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[piece])
        if not isinstance(tok_extended, AddedToken):
            continue

        # Neighbours as they are before this added token is processed.
        left = pieces[i - 1] if i > 0 else None
        right = pieces[i + 1] if i < last_position else None

        if tok_extended.rstrip and right:
            pieces[i + 1] = right.lstrip()
        if tok_extended.lstrip and left:
            pieces[i - 1] = left.rstrip()
        if tok_extended.single_word:
            # The token is glued to a neighbour it directly touches; its own slot is emptied and dropped
            # by the loop below.
            if left and left[-1] != " ":
                pieces[i - 1] += piece
                pieces[i] = ""
            elif right and right[0] != " ":
                pieces[i + 1] = piece + pieces[i + 1]
                pieces[i] = ""

    # Tokenize non-added tokens
    result = []
    special_tokens = set(self.all_special_tokens)
    for piece in pieces:
        if not piece:
            continue
        if piece in added_vocab or piece in special_tokens:
            result.append(piece)
        else:
            result.extend(self._tokenize(piece))

    return result
679
+
680
+ def _tokenize(self, text, **kwargs):
681
+ """
682
+ Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
683
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
684
+
685
+ Do NOT take care of added tokens.
686
+ """
687
+ raise NotImplementedError
688
+
689
+ def _convert_token_to_id_with_added_voc(self, token):
690
+ if token in self.added_tokens_encoder:
691
+ return self.added_tokens_encoder[token]
692
+ return self._convert_token_to_id(token)
693
+
694
+ def _convert_token_to_id(self, token):
695
+ raise NotImplementedError
696
+
697
def _encode_plus(
    self,
    text: TextInput | PreTokenizedInput | EncodedInput,
    text_pair: TextInput | PreTokenizedInput | EncodedInput | None = None,
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: int | None = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: int | None = None,
    padding_side: str | None = None,
    return_tensors: str | TensorType | None = None,
    return_token_type_ids: bool | None = None,
    return_attention_mask: bool | None = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """
    Encodes a single sequence (optionally with a pair) or a batch of sequences.

    A batch is encoded item by item through a recursive call without padding, attention mask or tensor
    conversion; the collected outputs are then padded together and wrapped in a `BatchEncoding`. A single
    input is converted to ids and handed to `prepare_for_model`.

    Raises:
        `ValueError`: If `text` is a batch and `text_pair` is not a batch of the same length, if a batch item
        holds neither one nor two sequences (non pre-tokenized input only), or if an input is not a string, a
        list of strings or a list of ints.
    """
    # Detect batched inputs (list of sequences)
    if not isinstance(text, (list, tuple)):
        is_batched = False
    elif not text:
        is_batched = not is_split_into_words
    else:
        # Pre-tokenized input is a batch only when its elements are themselves sequences.
        batch_item_types = (list, tuple) if is_split_into_words else (str, list, tuple)
        is_batched = isinstance(text[0], batch_item_types)

    if is_batched:
        if text_pair is None:
            pairs = [None] * len(text)
        elif isinstance(text_pair, (list, tuple)) and len(text_pair) == len(text):
            pairs = text_pair
        else:
            raise ValueError("If `text` is a batch, `text_pair` must also be a batch of the same length.")

        # Options shared by all per-item calls.
        per_item_options = {
            "add_special_tokens": add_special_tokens,
            "padding_strategy": PaddingStrategy.DO_NOT_PAD,  # we pad in batch afterward
            "truncation_strategy": truncation_strategy,
            "max_length": max_length,
            "stride": stride,
            "is_split_into_words": is_split_into_words,
            "pad_to_multiple_of": None,  # we pad in batch afterward
            "padding_side": None,  # we pad in batch afterward
            "return_tensors": None,  # We convert the whole batch to tensors at the end
            "return_token_type_ids": return_token_type_ids,
            "return_attention_mask": False,  # we pad in batch afterward
            "return_overflowing_tokens": return_overflowing_tokens,
            "return_special_tokens_mask": return_special_tokens_mask,
            "return_length": return_length,
            "verbose": verbose,
        }

        batch_outputs = {}
        for current_text, current_pair in zip(text, pairs):
            # Handle tuples/lists as sequence pairs like ("text1", "text2")
            # For is_split_into_words=True: only unpack if it's a tuple of exactly 2 sequences (pair)
            # Otherwise, treat the list as a single pretokenized sequence
            may_hold_sequences = (
                isinstance(current_text, (list, tuple))
                and current_text
                and not isinstance(current_text[0], int)
                and current_pair is None
            )
            if may_hold_sequences:
                # Looks like a pair: exactly two elements, each one a string or a list/tuple
                is_pair = len(current_text) == 2 and all(
                    isinstance(element, (str, list, tuple)) for element in current_text
                )
                if is_pair:
                    current_text, current_pair = current_text
                elif len(current_text) == 1:
                    current_text = current_text[0]
                elif not is_split_into_words:
                    # Only raise error for non-pretokenized input
                    raise ValueError(f"Expected a pair of sequences, got {len(current_text)} sequences.")

            current_output = self._encode_plus(
                text=current_text, text_pair=current_pair, **per_item_options, **kwargs
            )
            for key, value in current_output.items():
                batch_outputs.setdefault(key, []).append(value)

        # Remove overflow-related keys before tensor conversion if return_tensors is set
        # Slow tokenizers don't support returning these as tensors
        if return_tensors and return_overflowing_tokens:
            for overflow_key in ("overflowing_tokens", "num_truncated_tokens"):
                batch_outputs.pop(overflow_key, None)

        batch_outputs = self.pad(
            batch_outputs,
            padding=padding_strategy.value,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_attention_mask=return_attention_mask,
        )

        return BatchEncoding(batch_outputs, tensor_type=return_tensors)

    # Single sequence handling
    def get_input_ids(sequence):
        if isinstance(sequence, str):
            # Normal case: tokenize string
            return self.convert_tokens_to_ids(self.tokenize(sequence, **kwargs))
        if isinstance(sequence, (list, tuple)) and sequence:
            first_element = sequence[0]
            if isinstance(first_element, int):
                # Already ids: returned unchanged
                return sequence
            if isinstance(first_element, str):
                if not is_split_into_words:
                    # A list of tokens
                    return self.convert_tokens_to_ids(sequence)
                # Pre-tokenized strings: every word is tokenized on its own
                word_tokens = [tok for word in sequence for tok in self.tokenize(word, **kwargs)]
                return self.convert_tokens_to_ids(word_tokens)
        raise ValueError(f"Input must be a string, list of strings, or list of ints, got: {type(sequence)}")

    first_ids = get_input_ids(text)
    second_ids = None if text_pair is None else get_input_ids(text_pair)

    return self.prepare_for_model(
        first_ids,
        pair_ids=second_ids,
        add_special_tokens=add_special_tokens,
        padding=padding_strategy.value,
        truncation=truncation_strategy.value,
        max_length=max_length,
        stride=stride,
        pad_to_multiple_of=pad_to_multiple_of,
        padding_side=padding_side,
        return_tensors=return_tensors,
        prepend_batch_axis=True,
        return_attention_mask=return_attention_mask,
        return_token_type_ids=return_token_type_ids,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        verbose=verbose,
    )
835
+
836
def prepare_for_tokenization(
    self, text: str, is_split_into_words: bool = False, **kwargs
) -> tuple[str, dict[str, Any]]:
    """
    Performs any necessary transformations before tokenization.

    This version returns `text` and the received keyword arguments unchanged. An override should pop the
    arguments it uses from `kwargs` and return the remaining ones: `kwargs` is tested at the end of the
    encoding process to be sure all the arguments have been used.

    Args:
        text (`str`):
            The text to prepare.
        is_split_into_words (`bool`, *optional*, defaults to `False`):
            Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
            tokenizer assumes the input is already split into words (for instance, by splitting it on
            whitespace) which it will tokenize. This is useful for NER or token classification.
        kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to use for the tokenization.

    Returns:
        `tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs.
    """
    unused_kwargs = kwargs
    return text, unused_kwargs
859
+
860
+ def build_inputs_with_special_tokens(
861
+ self, token_ids_0: list[int], token_ids_1: list[int] | None = None
862
+ ) -> list[int]:
863
+ """
864
+ Build model inputs from a sequence or a pair of sequences by adding special tokens.
865
+
866
+ This method dynamically builds inputs based on the tokenizer's `special_tokens_pattern`:
867
+ - `"none"`: No special tokens
868
+ - `"cls_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
869
+ - `"eos"`: seq0 [EOS] or seq0 [EOS] seq1 [EOS]
870
+ - `"bos"`: [BOS] seq0 or [BOS] seq0 [BOS] seq1
871
+ - `"bos_eos"`: [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
872
+ - `"cls_double_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
873
+ - `"prefix_suffix"`: `<prefix_tokens> seq0 [seq1] <suffix_tokens>` (custom prefix/suffix stored on the tokenizer)
874
+
875
+ Args:
876
+ token_ids_0 (`list[int]`):
877
+ List of IDs to which the special tokens will be added.
878
+ token_ids_1 (`list[int]`, *optional*):
879
+ Optional second list of IDs for sequence pairs.
880
+
881
+ Returns:
882
+ `list[int]`: List of input IDs with the appropriate special tokens.
883
+ """
884
+ if self.special_tokens_pattern == "cls_sep":
885
+ # [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
886
+ if self.cls_token_id is None and self.sep_token_id is None:
887
+ raise ValueError(
888
+ "Cannot add special tokens following 'cls_sep' pattern because one or several special tokens "
889
+ f"are not defined (cls_token_id={self.cls_token_id}; sep_token_id={self.sep_token_id})"
890
+ "Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
891
+ )
892
+ if token_ids_1 is None:
893
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
894
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
895
+
896
+ elif self.special_tokens_pattern == "eos":
897
+ # seq0 [EOS] or seq0 [EOS] seq1 [EOS]
898
+ if self.eos_token_id is None:
899
+ raise ValueError(
900
+ "Cannot add special tokens following 'eos' pattern because eos token is not defined "
901
+ f"(eos_token_id={self.eos_token_id})."
902
+ "Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
903
+ )
904
+ if token_ids_1 is None:
905
+ return token_ids_0 + [self.eos_token_id]
906
+ return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
907
+
908
+ elif self.special_tokens_pattern == "bos":
909
+ # [BOS] seq0 or [BOS] seq0 [BOS] seq1
910
+ if self.bos_token_id is None:
911
+ raise ValueError(
912
+ "Cannot add special tokens following 'bos' pattern because bos token is not defined "
913
+ f"(bos_token_id={self.bos_token_id})."
914
+ "Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
915
+ )
916
+ if token_ids_1 is None:
917
+ return [self.bos_token_id] + token_ids_0
918
+ return [self.bos_token_id] + token_ids_0 + [self.bos_token_id] + token_ids_1
919
+
920
+ elif self.special_tokens_pattern == "bos_eos":
921
+ # [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
922
+ if self.bos_token_id is None and self.eos_token_id is None:
923
+ raise ValueError(
924
+ "Cannot add special tokens following 'bos_eos' pattern because one or several special tokens "
925
+ f"are not defined (bos_token_id={self.bos_token_id}; eos_token_id={self.eos_token_id})"
926
+ "Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
927
+ )
928
+ return token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
929
+
930
+ if token_ids_1 is None:
931
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
932
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
933
+
934
+ elif self.special_tokens_pattern == "cls_double_sep":
935
+ # [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
936
+ if self.cls_token_id is None and self.sep_token_id is None:
937
+ raise ValueError(
938
+ "Cannot add special tokens following 'cls_double_sep' pattern because one or several special tokens "
939
+ f"are not defined (cls_token_id={self.cls_token_id}; sep_token_id={self.sep_token_id})"
940
+ "Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
941
+ )
942
+ if token_ids_1 is None:
943
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
944
+ return (
945
+ [self.cls_token_id]
946
+ + token_ids_0
947
+ + [self.sep_token_id, self.sep_token_id]
948
+ + token_ids_1
949
+ + [self.sep_token_id]
950
+ )
951
+
952
+ elif self.special_tokens_pattern == "prefix_suffix":
953
+ prefix_tokens = getattr(self, "prefix_tokens", [])
954
+ suffix_tokens = getattr(self, "suffix_tokens", [])
955
+ if token_ids_1 is None:
956
+ return prefix_tokens + token_ids_0 + suffix_tokens
957
+ return prefix_tokens + token_ids_0 + token_ids_1 + suffix_tokens
958
+
959
+ else: # "none" or any other value
960
+ # No special tokens
961
+ if token_ids_1 is None:
962
+ return token_ids_0
963
+ return token_ids_0 + token_ids_1
964
+
965
+ def get_special_tokens_mask(
966
+ self, token_ids_0: list, token_ids_1: list | None = None, already_has_special_tokens: bool = False
967
+ ) -> list[int]:
968
+ """
969
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
970
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
971
+
972
+ This method dynamically builds the special tokens mask based on the tokenizer's `special_tokens_pattern`:
973
+ - `"none"`: No special tokens (default, returns all 0s)
974
+ - `"cls_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
975
+ - `"eos"`: seq0 [EOS] or seq0 [EOS] seq1 [EOS]
976
+ - `"bos"`: [BOS] seq0 or [BOS] seq0 [BOS] seq1
977
+ - `"bos_eos"`: [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
978
+ - `"cls_double_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
979
+ - `"prefix_suffix"`: `<prefix_tokens> seq0 [seq1] <suffix_tokens>`
980
+
981
+ Args:
982
+ token_ids_0 (`list[int]`):
983
+ List of ids of the first sequence.
984
+ token_ids_1 (`list[int]`, *optional*):
985
+ List of ids of the second sequence.
986
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
987
+ Whether or not the token list is already formatted with special tokens for the model.
988
+
989
+ Returns:
990
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
991
+ """
992
+ if already_has_special_tokens:
993
+ if token_ids_1 is not None:
994
+ raise ValueError(
995
+ "You should not supply a second sequence if the provided sequence of "
996
+ "ids is already formatted with special tokens for the model."
997
+ )
998
+
999
+ return super().get_special_tokens_mask(
1000
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
1001
+ )
1002
+
1003
+ if self.special_tokens_pattern == "cls_sep":
1004
+ # [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
1005
+ if token_ids_1 is None:
1006
+ return [1] + ([0] * len(token_ids_0)) + [1]
1007
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
1008
+
1009
+ elif self.special_tokens_pattern == "eos":
1010
+ # seq0 [EOS] or seq0 [EOS] seq1 [EOS]
1011
+ if token_ids_1 is None:
1012
+ return ([0] * len(token_ids_0)) + [1]
1013
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
1014
+
1015
+ elif self.special_tokens_pattern == "bos":
1016
+ # [BOS] seq0 or [BOS] seq0 [BOS] seq1
1017
+ if token_ids_1 is None:
1018
+ return [1] + ([0] * len(token_ids_0))
1019
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
1020
+
1021
+ elif self.special_tokens_pattern == "bos_eos":
1022
+ # [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
1023
+ if token_ids_1 is None:
1024
+ return [1] + ([0] * len(token_ids_0)) + [1]
1025
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
1026
+
1027
+ elif self.special_tokens_pattern == "cls_double_sep":
1028
+ # [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
1029
+ if token_ids_1 is None:
1030
+ return [1] + ([0] * len(token_ids_0)) + [1]
1031
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
1032
+
1033
+ elif self.special_tokens_pattern == "prefix_suffix":
1034
+ prefix_len = len(getattr(self, "prefix_tokens", []))
1035
+ suffix_len = len(getattr(self, "suffix_tokens", []))
1036
+ mask = [1] * prefix_len + ([0] * len(token_ids_0))
1037
+ if token_ids_1 is not None:
1038
+ mask += [0] * len(token_ids_1)
1039
+ mask += [1] * suffix_len
1040
+ return mask
1041
+
1042
+ else:
1043
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
1044
+
1045
+ @overload
1046
+ def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
1047
+
1048
+ @overload
1049
+ def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ...
1050
+
1051
+ def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]:
1052
+ """
1053
+ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
1054
+ added tokens.
1055
+
1056
+ Args:
1057
+ ids (`int` or `list[int]`):
1058
+ The token id (or token ids) to convert to tokens.
1059
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
1060
+ Whether or not to remove special tokens in the decoding.
1061
+
1062
+ Returns:
1063
+ `str` or `list[str]`: The decoded token(s).
1064
+ """
1065
+ if isinstance(ids, int):
1066
+ return (
1067
+ self._added_tokens_decoder[ids].content
1068
+ if ids in self._added_tokens_decoder
1069
+ else self._convert_id_to_token(ids)
1070
+ )
1071
+
1072
+ tokens = []
1073
+ # self.all_special_ids is an @property which may be slow, so only compute it once before the loop
1074
+ ids_to_skip = set(self.all_special_ids) if skip_special_tokens else set()
1075
+ for index in ids:
1076
+ index = int(index)
1077
+ if index in ids_to_skip:
1078
+ continue
1079
+ tokens.append(
1080
+ self._added_tokens_decoder[index].content
1081
+ if index in self._added_tokens_decoder
1082
+ else self._convert_id_to_token(index)
1083
+ )
1084
+ return tokens
1085
+
1086
+ def _convert_id_to_token(self, index: int) -> str:
1087
+ raise NotImplementedError
1088
+
1089
def convert_tokens_to_string(self, tokens: list[str]) -> str:
    """Join `tokens` into one string, separating consecutive tokens with a single space."""
    separator = " "
    return separator.join(tokens)
1091
+
1092
+ def _decode(
1093
+ self,
1094
+ token_ids: int | list[int],
1095
+ skip_special_tokens: bool = False,
1096
+ clean_up_tokenization_spaces: bool | None = None,
1097
+ **kwargs,
1098
+ ) -> str:
1099
+ """Decode token ids to string."""
1100
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
1101
+ if isinstance(filtered_tokens, str):
1102
+ filtered_tokens = [filtered_tokens]
1103
+
1104
+ text = self.convert_tokens_to_string(filtered_tokens)
1105
+
1106
+ # Apply tokenizer-specific cleanup if available and requested
1107
+ clean_up_tokenization_spaces = (
1108
+ clean_up_tokenization_spaces
1109
+ if clean_up_tokenization_spaces is not None
1110
+ else self.clean_up_tokenization_spaces
1111
+ )
1112
+ if clean_up_tokenization_spaces:
1113
+ text = self.clean_up_tokenization(text)
1114
+
1115
+ return text
1116
+
1117
def prepare_for_model(
    self,
    ids: list[int],
    pair_ids: list[int] | None = None,
    add_special_tokens: bool = True,
    padding: bool | str | PaddingStrategy = False,
    truncation: bool | str | TruncationStrategy = False,
    max_length: int | None = None,
    stride: int = 0,
    pad_to_multiple_of: int | None = None,
    padding_side: str | None = None,
    return_tensors: str | TensorType | None = None,
    return_token_type_ids: bool | None = None,
    return_attention_mask: bool | None = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    prepend_batch_axis: bool = False,
    **kwargs,
) -> BatchEncoding:
    """
    Prepares a sequence of input ids so it can be used by the model. Adds special tokens, truncates, and pads.

    The steps run in this order: resolve the padding/truncation strategies, truncate, add special tokens,
    assemble the output dict, warn about over-long sequences, pad, and finally wrap in a `BatchEncoding`.

    Args:
        ids: Tokenized input ids of the first sequence.
        pair_ids: Tokenized input ids of the second sequence (optional).
        add_special_tokens: Whether to build the sequence with `build_inputs_with_special_tokens`; when
            `False`, `ids` and `pair_ids` are simply concatenated.
        return_token_type_ids: Defaults to whether `"token_type_ids"` is in `self.model_input_names`.
        return_attention_mask: Defaults to whether `"attention_mask"` is in `self.model_input_names`.
        return_overflowing_tokens: Adds `overflowing_tokens` and `num_truncated_tokens`, but only when
            `return_tensors` is not set and truncation actually produced overflow.
        return_special_tokens_mask: Adds `special_tokens_mask`.
        return_length: Adds `length`, measured on `input_ids` after padding.

    Returns:
        `BatchEncoding`: The encoded inputs.

    Raises:
        ValueError: If `return_overflowing_tokens` is requested for a pair of sequences with the
            `longest_first` truncation strategy.
    """
    # Get padding/truncation strategies
    padding_strategy, truncation_strategy, max_length, _ = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )

    # Validation: `truncate_sequences` does not collect overflow for pairs under `longest_first`
    if (
        return_overflowing_tokens
        and truncation_strategy == TruncationStrategy.LONGEST_FIRST
        and pair_ids is not None
    ):
        raise ValueError(
            "Not possible to return overflowing tokens for pair of sequences with the "
            "`longest_first`. Please select another truncation strategy than `longest_first`, "
            "for instance `only_second` or `only_first`."
        )

    # Defaults: unspecified outputs follow what the model declares it consumes
    if return_token_type_ids is None:
        return_token_type_ids = "token_type_ids" in self.model_input_names
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names

    # Truncation: the length budget counts the special tokens that will be added afterwards
    pair = pair_ids is not None
    num_special = self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0
    total_len = len(ids) + len(pair_ids or []) + num_special

    overflowing_tokens = []
    # Skipped when `max_length` is falsy (`None` or 0) or the input already fits
    if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
        ids, pair_ids, overflowing_tokens = self.truncate_sequences(
            ids,
            pair_ids=pair_ids,
            num_tokens_to_remove=total_len - max_length,
            truncation_strategy=truncation_strategy,
            stride=stride,
        )

    # Add special tokens
    if add_special_tokens:
        sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
        token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
    else:
        sequence = ids + (pair_ids if pair_ids else [])
        token_type_ids = [0] * len(sequence)

    # Build output
    encoded_inputs = {"input_ids": sequence}
    if return_token_type_ids:
        encoded_inputs["token_type_ids"] = token_type_ids
    if return_special_tokens_mask:
        encoded_inputs["special_tokens_mask"] = (
            self.get_special_tokens_mask(ids, pair_ids) if add_special_tokens else [0] * len(sequence)
        )
    # Overflow is only reported for plain-list output (no `return_tensors`) and when non-empty
    if return_overflowing_tokens and not return_tensors and overflowing_tokens:
        encoded_inputs["overflowing_tokens"] = overflowing_tokens
        # `total_len` was computed before truncation, so this is the number of tokens removed
        encoded_inputs["num_truncated_tokens"] = total_len - max_length if max_length else 0

    # Check sequence length and warn if needed
    self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)

    # Pad: `pad` is also called without padding when an attention mask was requested
    if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
        encoded_inputs = self.pad(
            encoded_inputs,
            max_length=max_length,
            padding=padding_strategy.value,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_attention_mask=return_attention_mask,
        )

    if return_length:
        encoded_inputs["length"] = len(encoded_inputs["input_ids"])

    return BatchEncoding(encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis)
1226
+
1227
def truncate_sequences(
    self,
    ids: list[int],
    pair_ids: list[int] | None = None,
    num_tokens_to_remove: int = 0,
    truncation_strategy: str | TruncationStrategy = "longest_first",
    stride: int = 0,
) -> tuple[list[int], list[int], list[int]]:
    """
    Truncates sequences according to the specified strategy.

    Args:
        ids (`list[int]`):
            Token ids of the first sequence.
        pair_ids (`list[int]`, *optional*):
            Token ids of the second sequence.
        num_tokens_to_remove (`int`, *optional*, defaults to 0):
            Number of tokens to remove. Nothing is done when this is not positive.
        truncation_strategy (`str` or `TruncationStrategy`, *optional*, defaults to `"longest_first"`):
            Which sequence(s) to shorten. Tokens are dropped from the side given by `self.truncation_side`.
        stride (`int`, *optional*, defaults to 0):
            Extra tokens, beyond the removed ones, to include in the returned overflow window.

    Returns:
        `tuple[list[int], list[int], list[int]]`: The truncated `ids`, the truncated `pair_ids` and the
        overflowing tokens. The overflow is always empty for a pair truncated with `longest_first`.
    """
    if num_tokens_to_remove <= 0:
        return ids, pair_ids, []

    if not isinstance(truncation_strategy, TruncationStrategy):
        truncation_strategy = TruncationStrategy(truncation_strategy)

    overflowing_tokens = []

    # ONLY_FIRST or LONGEST_FIRST with single sequence
    if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
        truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
    ):
        window_len = min(len(ids), stride + num_tokens_to_remove)
        if self.truncation_side == "left":
            overflowing_tokens = ids[:window_len]
            ids = ids[num_tokens_to_remove:]
        else:
            overflowing_tokens = ids[-window_len:]
            ids = ids[:-num_tokens_to_remove]

    # LONGEST_FIRST with pair (`pair_ids` is never `None` here, see the branch above)
    elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
        logger.warning(
            "Be aware, overflowing tokens are not returned for the setting you have chosen,"
            f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
            "truncation strategy. So the returned list will always be empty even if some "
            "tokens have been removed."
        )
        len_ids, len_pair = len(ids), len(pair_ids) if pair_ids else 0
        # First shorten the longer sequence until both are equal, then split the remainder,
        # giving the odd token to the second sequence.
        first_remove = min(abs(len_pair - len_ids), num_tokens_to_remove)
        second_remove = num_tokens_to_remove - first_remove

        if len_ids > len_pair:
            ids_to_move = first_remove + second_remove // 2
            pair_ids_to_move = second_remove - second_remove // 2
        else:
            ids_to_move = second_remove // 2
            pair_ids_to_move = first_remove + second_remove - (second_remove // 2)

        if self.truncation_side == "right":
            # `seq[:-0]` would empty the list, hence the explicit `> 0` guards.
            ids = ids[:-ids_to_move] if ids_to_move > 0 else ids
            pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids and pair_ids_to_move > 0 else pair_ids
        else:
            ids = ids[ids_to_move:]
            # Slice unconditionally: an empty pair stays an empty list, as in the right-side branch,
            # so the caller still sees a pair rather than a single sequence.
            pair_ids = pair_ids[pair_ids_to_move:]

    # ONLY_SECOND (no-op when there is no, or an empty, second sequence)
    elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids:
        window_len = min(len(pair_ids), stride + num_tokens_to_remove)
        if self.truncation_side == "right":
            overflowing_tokens = pair_ids[-window_len:]
            pair_ids = pair_ids[:-num_tokens_to_remove]
        else:
            overflowing_tokens = pair_ids[:window_len]
            pair_ids = pair_ids[num_tokens_to_remove:]

    return ids, pair_ids, overflowing_tokens
1293
+
1294
+ def create_token_type_ids_from_sequences(
1295
+ self, token_ids_0: list[int], token_ids_1: list[int] | None = None
1296
+ ) -> list[int]:
1297
+ """
1298
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task.
1299
+
1300
+ This method dynamically builds the token type IDs based on the tokenizer's configuration attributes:
1301
+ - `token_type_ids_pattern`: Pattern to use ("all_zeros" or "bert_style")
1302
+ - `token_type_ids_include_special_tokens`: Whether to account for special tokens in length calculation
1303
+
1304
+ Args:
1305
+ token_ids_0 (`list[int]`):
1306
+ List of IDs.
1307
+ token_ids_1 (`list[int]`, *optional*):
1308
+ Optional second list of IDs for sequence pairs.
1309
+
1310
+ Returns:
1311
+ `list[int]`: Token type IDs according to the configured pattern.
1312
+
1313
+ Examples:
1314
+ ```python
1315
+ # All zeros pattern (default, used by RoBERTa, BART, etc.)
1316
+ tokenizer.token_type_ids_pattern = "all_zeros"
1317
+ # Returns: [0, 0, 0, ...] for both sequences
1318
+
1319
+ # BERT-style pattern (first sequence gets 0s, second gets 1s)
1320
+ tokenizer.token_type_ids_pattern = "bert_style"
1321
+ # Returns: [0, 0, 0, ..., 1, 1, 1, ...] for sequence pairs
1322
+ ```
1323
+ """
1324
+ # Calculate lengths - account for special tokens if configured
1325
+ if self.token_type_ids_include_special_tokens:
1326
+ # Build the full sequence to get accurate length
1327
+ if token_ids_1 is None:
1328
+ sequence = self.build_inputs_with_special_tokens(token_ids_0)
1329
+ seq0_len = len(sequence)
1330
+ seq1_len = 0
1331
+ else:
1332
+ full_sequence = self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
1333
+ # Approximate split - this works for most tokenizers
1334
+ # For more complex cases, subclasses should still override
1335
+ seq0_with_special = self.build_inputs_with_special_tokens(token_ids_0)
1336
+ seq0_len = len(seq0_with_special)
1337
+ seq1_len = len(full_sequence) - seq0_len
1338
+ else:
1339
+ # Use raw token lengths
1340
+ seq0_len = len(token_ids_0)
1341
+ seq1_len = len(token_ids_1) if token_ids_1 is not None else 0
1342
+
1343
+ # Build token type IDs based on pattern
1344
+ if self.special_tokens_pattern == "prefix_suffix":
1345
+ total_len = len(getattr(self, "prefix_tokens", [])) + len(token_ids_0)
1346
+ if token_ids_1 is not None:
1347
+ total_len += len(token_ids_1)
1348
+ total_len += len(getattr(self, "suffix_tokens", []))
1349
+ return [0] * total_len
1350
+
1351
+ if self.token_type_ids_pattern == "bert_style" and token_ids_1 is not None:
1352
+ # BERT-style: first sequence gets 0s, second sequence gets 1s
1353
+ return [0] * seq0_len + [1] * seq1_len
1354
+ else:
1355
+ # All zeros pattern (default): everything gets 0s
1356
+ return [0] * (seq0_len + seq1_len)
1357
+
1358
+ def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str, ...]:
1359
+ """
1360
+ Default implementation for common vocabulary saving patterns.
1361
+ Saves self.encoder/self.vocab as JSON, optionally with self.bpe_ranks as merges.
1362
+ Returns empty tuple if no vocabulary exists.
1363
+
1364
+ Override this method if your tokenizer needs custom saving logic (e.g., SentencePiece models,
1365
+ multiple vocabulary files, or special file formats).
1366
+
1367
+ Args:
1368
+ save_directory (`str`):
1369
+ The directory in which to save the vocabulary.
1370
+ filename_prefix (`str`, *optional*):
1371
+ An optional prefix to add to the named of the saved files.
1372
+
1373
+ Returns:
1374
+ `tuple[str, ...]`: Paths to the files saved, or empty tuple if no files saved.
1375
+ """
1376
+ import json
1377
+ import os
1378
+
1379
+ vocab_attr = getattr(self, "encoder", None) or getattr(self, "vocab", None)
1380
+ if vocab_attr is None:
1381
+ return ()
1382
+
1383
+ if not os.path.isdir(save_directory):
1384
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
1385
+ return ()
1386
+
1387
+ vocab_files_names = getattr(self, "vocab_files_names", {})
1388
+ prefix = f"{filename_prefix}-" if filename_prefix else ""
1389
+
1390
+ # Save vocabulary
1391
+ vocab_file = os.path.join(save_directory, prefix + vocab_files_names.get("vocab_file", "vocab.json"))
1392
+ with open(vocab_file, "w", encoding="utf-8") as f:
1393
+ f.write(json.dumps(vocab_attr, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
1394
+
1395
+ # Save BPE merges if present
1396
+ bpe_ranks = getattr(self, "bpe_ranks", None)
1397
+ if bpe_ranks is None:
1398
+ return (vocab_file,)
1399
+
1400
+ merge_file = os.path.join(save_directory, prefix + vocab_files_names.get("merges_file", "merges.txt"))
1401
+ with open(merge_file, "w", encoding="utf-8") as writer:
1402
+ if getattr(self, "add_bpe_version_header", False):
1403
+ writer.write("#version: 0.2\n")
1404
+
1405
+ index = 0
1406
+ for bpe_tokens, token_index in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
1407
+ if index != token_index:
1408
+ logger.warning(
1409
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
1410
+ " Please check that the tokenizer is not corrupted!"
1411
+ )
1412
+ index = token_index
1413
+ writer.write(" ".join(bpe_tokens) + "\n")
1414
+ index += 1
1415
+
1416
+ return (vocab_file, merge_file)
1417
+
1418
+
1419
+ # Backward compatibility alias
1420
+ PreTrainedTokenizer = PythonBackend
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/video_utils.py ADDED
@@ -0,0 +1,893 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import warnings
17
+ from collections.abc import Callable, Iterable, Mapping
18
+ from contextlib import redirect_stdout
19
+ from dataclasses import dataclass, fields
20
+ from io import BytesIO
21
+ from typing import NewType, Union
22
+ from urllib.parse import urlparse
23
+
24
+ import httpx
25
+ import numpy as np
26
+
27
+ from .image_transforms import PaddingMode, to_channel_dimension_format
28
+ from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
29
+ from .utils import (
30
+ is_av_available,
31
+ is_cv2_available,
32
+ is_decord_available,
33
+ is_numpy_array,
34
+ is_torch_available,
35
+ is_torch_tensor,
36
+ is_torchcodec_available,
37
+ is_torchvision_available,
38
+ is_vision_available,
39
+ is_yt_dlp_available,
40
+ logging,
41
+ requires_backends,
42
+ )
43
+
44
+
45
+ if is_vision_available():
46
+ import PIL.Image
47
+
48
+ if is_torchvision_available():
49
+ from torchvision import io as torchvision_io
50
+
51
+ if is_torch_available():
52
+ import torch
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ URL = NewType("URL", str)
58
+ Path = NewType("Path", str)
59
+
60
+ VideoInput = Union[
61
+ list["PIL.Image.Image"],
62
+ np.ndarray,
63
+ "torch.Tensor",
64
+ list[np.ndarray],
65
+ list["torch.Tensor"],
66
+ list[list["PIL.Image.Image"]],
67
+ list[list[np.ndarray]],
68
+ list[list["torch.Tensor"]],
69
+ URL,
70
+ list[URL],
71
+ list[list[URL]],
72
+ Path,
73
+ list[Path],
74
+ list[list[Path]],
75
+ ]
76
+
77
+
78
+ @dataclass
79
+ class VideoMetadata(Mapping):
80
+ total_num_frames: int
81
+ fps: float | None = None
82
+ width: int | None = None
83
+ height: int | None = None
84
+ duration: float | None = None
85
+ video_backend: str | None = None
86
+ frames_indices: list[int] | None = None
87
+
88
+ def __iter__(self):
89
+ return (f.name for f in fields(self))
90
+
91
+ def __len__(self):
92
+ return len(fields(self))
93
+
94
+ def __getitem__(self, item):
95
+ return getattr(self, item)
96
+
97
+ def __setitem__(self, key, value):
98
+ return setattr(self, key, value)
99
+
100
+ @property
101
+ def timestamps(self) -> list[float]:
102
+ "Timestamps of the sampled frames in seconds."
103
+ if self.fps is None or self.frames_indices is None:
104
+ raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
105
+ return [frame_idx / self.fps for frame_idx in self.frames_indices]
106
+
107
+ @property
108
+ def sampled_fps(self) -> float:
109
+ "FPS of the sampled video."
110
+ if self.frames_indices is None or self.total_num_frames is None or self.fps is None:
111
+ return self.fps or 24
112
+ return len(self.frames_indices) / self.total_num_frames * self.fps
113
+
114
+ def update(self, dictionary):
115
+ for key, value in dictionary.items():
116
+ if hasattr(self, key):
117
+ setattr(self, key, value)
118
+
119
+
120
+ VideoMetadataType = VideoMetadata | dict | list[dict | VideoMetadata] | list[list[dict | VideoMetadata]]
121
+
122
+
123
def is_valid_video_frame(frame):
    """
    Return whether `frame` is a single video frame: a `PIL.Image.Image`, or a numpy array / torch tensor
    with exactly 3 dimensions.
    """
    # `PIL` is only imported when the vision extra is available; guard the reference so that array
    # frames can still be validated (instead of raising `NameError`) when Pillow is missing.
    if is_vision_available() and isinstance(frame, PIL.Image.Image):
        return True
    return (is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
127
+
128
+
129
def is_valid_video(video):
    """
    Return whether `video` is a single video.

    A non-list input must be a numpy array or torch tensor with 4 dimensions. A list/tuple must be
    non-empty and contain only valid frames (see `is_valid_video_frame`).

    Returns:
        `bool`: `False` for an empty list/tuple (previously the empty container itself was returned).
    """
    if not isinstance(video, (list, tuple)):
        return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
    # `bool(...)` keeps the return type a real boolean for empty inputs.
    return bool(video) and all(is_valid_video_frame(frame) for frame in video)
133
+
134
+
135
def valid_videos(videos):
    """
    Return whether `videos` is an acceptable video input.

    A list/tuple is accepted when every element is either a valid video or a valid frame (so it may be one
    video given as frames, or a batch of videos). Any other input is accepted only if it is a valid video.
    """
    if not isinstance(videos, (list, tuple)):
        # NOTE(review): `is_valid_video` already requires `ndim == 4` for arrays, so the `ndim == 5` clause
        # looks unreachable and 5D batched tensors appear to be rejected here -- confirm this is intended.
        if not is_valid_video(videos) or videos.ndim == 5:
            return False
        return True

    # If we have a list of videos, it could be either one video as list of frames or a batch
    return all(is_valid_video(item) or is_valid_video_frame(item) for item in videos)
145
+
146
+
147
def is_batched_video(videos):
    """
    Return whether `videos` is a batch of videos.

    A list/tuple counts as batched when its first element is itself a valid video; an array or tensor counts
    as batched when it has 5 dimensions.

    Returns:
        `bool`: `False` for an empty list/tuple (previously this raised `IndexError`).
    """
    if isinstance(videos, (list, tuple)):
        # Guard the `videos[0]` lookup so an empty container is simply "not batched".
        return bool(videos) and bool(is_valid_video(videos[0]))
    elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
        return True
    return False
153
+
154
+
155
def is_scaled_video(video: np.ndarray) -> bool:
    """
    Report whether every pixel value of `video` lies in the closed range [0, 1], i.e. whether the video
    looks already rescaled.
    """
    # The dtype is deliberately not inspected: a floating-point video may still hold values in [0, 255].
    lowest_is_non_negative = np.min(video) >= 0
    return lowest_is_non_negative and np.max(video) <= 1
161
+
162
+
163
def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union[np.ndarray, "torch.Tensor"]]:
    """
    Given a batch of videos, converts each video given as a list of frames to a single stacked array.

    The decision is based on the first video only: if it is a list/tuple whose first element is a valid
    image, every video is converted with `np.array` per frame followed by `np.stack`; otherwise `videos` is
    returned unchanged. All inputs are therefore assumed to share the format of the first element.

    Args:
        videos (`VideoInput`):
            Video inputs to turn into a list of videos.

    Returns:
        The list of converted videos, or `videos` itself when no conversion applies (including when the
        batch is empty).
    """
    # An empty batch has no first element to inspect (this used to raise `IndexError`).
    if not videos:
        return videos

    if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])):
        return videos

    return [np.stack([np.array(frame) for frame in video]) for video in videos]
182
+
183
+
184
def make_batched_videos(videos) -> list[Union[np.ndarray, "torch.Tensor", "URL", "Path"]]:
    """
    Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
    If the input is a batch of videos, it is returned as a flat list. Videos passed as lists of image frames
    are converted to stacked arrays by `convert_pil_frames_to_video`.

    We assume that all inputs in the list are in the same format, based on the type of the first element.

    Args:
        videos (`VideoInput`):
            Video inputs to turn into a list of videos.

    Returns:
        A flat list of videos (arrays/tensors) or of string inputs, which are passed through untouched.

    Raises:
        ValueError: If `videos` is not a list and is neither a string, a valid video nor a valid image.
    """
    # Early exit for deeply nested list of image frame paths. We shouldn't flatten them
    try:
        if isinstance(videos[0][0], list) and isinstance(videos[0][0][0], str):
            return [image_paths for sublist in videos for image_paths in sublist]
    except (IndexError, TypeError):
        pass

    if is_batched_video(videos):
        return convert_pil_frames_to_video(list(videos))
    elif isinstance(videos, str) or is_valid_video(videos):
        return convert_pil_frames_to_video([videos])
    # only one frame passed, thus we unsqueeze time dim
    elif is_valid_image(videos):
        # `PIL` is only imported when vision is available; guard the reference so an array/tensor frame
        # does not raise `NameError` when Pillow is missing.
        if is_vision_available() and isinstance(videos, PIL.Image.Image):
            videos = np.array(videos)
        return [videos[None, ...]]
    elif not isinstance(videos, list):
        raise ValueError(
            f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
            f" type {type(videos)}."
        )

    # Recursively flatten any nested structure. Items that are neither a string, a valid video nor a
    # non-empty list are dropped.
    flat_videos_list = []
    for item in videos:
        if isinstance(item, str) or is_valid_video(item):
            flat_videos_list.append(item)
        elif isinstance(item, list) and item:
            flat_videos_list.extend(make_batched_videos(item))

    flat_videos_list = convert_pil_frames_to_video(flat_videos_list)
    return flat_videos_list
228
+
229
+
230
def make_batched_metadata(videos: VideoInput, video_metadata: VideoMetadataType) -> list[VideoMetadata]:
    """
    Normalize `video_metadata` into a flat list of metadata entries, one per video.

    Args:
        videos (`VideoInput`):
            The videos the metadata belongs to. Only used when `video_metadata` is `None`, in which case each
            entry is inferred from `len(video)` and, for valid videos, from `get_video_size`.
        video_metadata (`VideoMetadataType`):
            `None`, a single dict / `VideoMetadata`, a list of them, or a nested list of them.

    Returns:
        `list[VideoMetadata]`: Dicts are wrapped in `VideoMetadata` and nested lists are flattened. A flat
        list whose first element is not a dict is returned as given. An empty list yields an empty list
        (previously this raised `IndexError`).
    """
    if video_metadata is None:
        # Create default metadata and fill attributes we can infer from given video
        video_metadata = []
        for video in videos:
            # Resolve the size once per video instead of once per dimension.
            height, width = get_video_size(video) if is_valid_video(video) else (None, None)
            video_metadata.append(
                {
                    "total_num_frames": len(video),
                    "fps": None,
                    "duration": None,
                    "frames_indices": list(range(len(video))),
                    "height": height,
                    "width": width,
                }
            )

    if isinstance(video_metadata, list):
        # Nothing to inspect or wrap for an empty batch
        if not video_metadata:
            return []
        # Flatten if nested list
        if isinstance(video_metadata[0], list):
            video_metadata = [
                VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list
            ]
        # Simply wrap in VideoMetadata if simple dict
        elif isinstance(video_metadata[0], dict):
            video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata]
    else:
        # Create a batched list from single object
        video_metadata = [VideoMetadata(**video_metadata)]
    return video_metadata
258
+
259
+
260
def get_video_size(video: np.ndarray, channel_dim: ChannelDimension | None = None) -> tuple[int, int]:
    """
    Return the `(height, width)` of a video, read from its trailing dimensions.

    Args:
        video (`np.ndarray`):
            The video to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Where the channel dimension sits. If `None`, it is inferred with
            `infer_channel_dimension_format`, allowing 1, 3 or 4 channels.

    Returns:
        A tuple of the video's height and width.

    Raises:
        ValueError: If the channel dimension format is neither `FIRST` nor `LAST`.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4))

    # Pick the axes holding height and width, counted from the end of the shape.
    if channel_dim == ChannelDimension.FIRST:
        height_axis, width_axis = -2, -1
    elif channel_dim == ChannelDimension.LAST:
        height_axis, width_axis = -3, -2
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")

    dims = video.shape
    return dims[height_axis], dims[width_axis]
282
+
283
+
284
+ def get_uniform_frame_indices(total_num_frames: int, num_frames: int | None = None):
285
+ """
286
+ Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
287
+ when loading a video.
288
+
289
+ Args:
290
+ total_num_frames (`int`):
291
+ Total number of frames that a video has.
292
+ num_frames (`int`, *optional*):
293
+ Number of frames to sample uniformly. If not specified, all frames are sampled.
294
+
295
+ Returns:
296
+ np.ndarray: np array of frame indices that will be sampled.
297
+ """
298
+ if num_frames is not None:
299
+ indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
300
+ else:
301
+ indices = np.arange(0, total_num_frames).astype(int)
302
+ return indices
303
+
304
+
305
def default_sample_indices_fn(metadata: "VideoMetadata", num_frames=None, fps=None, **kwargs):
    """
    A default sampling function that replicates the logic used in get_uniform_frame_indices,
    while optionally handling `fps` if `num_frames` is not provided.

    Args:
        metadata (`VideoMetadata`):
            `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. When given, `fps` is ignored.
        fps (`int` or `float`, *optional*):
            Desired frames per second. Only used when `num_frames` is `None`, in which case the number of
            frames is derived from `metadata.total_num_frames`, `metadata.fps` and `fps`.

    Returns:
        `np.ndarray`: Array of frame indices to sample.

    Raises:
        ValueError: If the number of frames derived from `fps` exceeds `metadata.total_num_frames`.
    """
    total_num_frames = metadata.total_num_frames
    video_fps = metadata.fps

    # If num_frames is not given but fps is, calculate num_frames from fps
    if num_frames is None and fps is not None:
        num_frames = int(total_num_frames / video_fps * fps)
        if num_frames > total_num_frames:
            raise ValueError(
                f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
                f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
            )

    if num_frames is None:
        return np.arange(0, total_num_frames, dtype=int)

    step = total_num_frames / num_frames
    # Build the grid in floating point and truncate afterwards. Passing `dtype=int` straight to
    # `np.arange` truncates the *step* itself (2.5 -> 2, 0.5 -> 0), which samples only the beginning of
    # the video instead of spreading the indices uniformly. The slice caps a possible extra element
    # caused by floating-point rounding in `np.arange`'s length computation.
    return np.arange(0, total_num_frames, step).astype(int)[:num_frames]
338
+
339
+
340
def read_video_opencv(
    video_path: Union["URL", "Path"],
    sample_indices_fn: Callable,
    **kwargs,
) -> tuple[np.ndarray, VideoMetadata]:
    """
    Decode a video using the OpenCV backend.

    Frames are read sequentially and only those whose position is among the sampled indices are kept.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that returns the indices at which the video should be sampled. It is called as
            `sample_indices_fn(metadata=metadata, **kwargs)`.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.ndarray`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object, with `frames_indices` set to the sampled indices.
    """
    # Lazy import cv2
    requires_backends(read_video_opencv, ["cv2"])
    import cv2

    video = cv2.VideoCapture(video_path)
    try:
        total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = video.get(cv2.CAP_PROP_FPS)
        duration = total_num_frames / video_fps if video_fps else 0
        metadata = VideoMetadata(
            total_num_frames=int(total_num_frames),
            fps=float(video_fps),
            duration=float(duration),
            video_backend="opencv",
            height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
        )

        indices = sample_indices_fn(metadata=metadata, **kwargs)
        # Hash-based membership instead of scanning the index array once per decoded frame.
        wanted = set(np.asarray(indices).ravel().tolist())

        index = 0
        frames = []
        while video.isOpened():
            success, frame = video.read()
            if not success:
                break
            if index in wanted:
                # OpenCV decodes to BGR; callers expect RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            index += 1
            if index >= total_num_frames:
                break
    finally:
        # Always free the capture handle, including when sampling or decoding raises.
        video.release()

    metadata.frames_indices = indices
    return np.stack(frames), metadata
400
+
401
+
402
def read_video_decord(
    video_path: Union["URL", "Path"],
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode a video using the Decord backend.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that returns the indices at which the video should be sampled. It is called as
            `sample_indices_fn(metadata=metadata, **kwargs)`.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array holding the sampled frames, as returned by decord's `get_batch(...).asnumpy()`.
            - `VideoMetadata` object, completed with the sampled indices and the frame height/width.
    """
    # Lazy import from decord
    requires_backends(read_video_decord, ["decord"])
    from decord import VideoReader, cpu

    reader = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
    average_fps = reader.get_avg_fps()
    frame_count = len(reader)
    # Guard against a zero frame rate
    if average_fps:
        length_in_seconds = frame_count / average_fps
    else:
        length_in_seconds = 0

    metadata = VideoMetadata(
        total_num_frames=int(frame_count),
        fps=float(average_fps),
        duration=float(length_in_seconds),
        video_backend="decord",
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)
    video = reader.get_batch(indices).asnumpy()

    # Height and width are read from the decoded batch (axes 1 and 2).
    decoded_info = {"frames_indices": indices, "height": video.shape[1], "width": video.shape[2]}
    metadata.update(decoded_info)
    return video, metadata
452
+
453
+
454
def read_video_pyav(
    video_path: Union["URL", "Path"],
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode the video with PyAV decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that returns the indices at which the video should be sampled. It is invoked as
            `sample_indices_fn(metadata=metadata, **kwargs)`. The last returned index is used as the point at
            which decoding stops, so indices are expected in ascending order.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
            - `VideoMetadata` object.
    """
    # Lazy import av
    requires_backends(read_video_pyav, ["av"])
    import av

    # The container holds the underlying file/demuxer; the context manager closes it even if decoding fails
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        total_num_frames = stream.frames
        video_fps = stream.average_rate  # should we better use `av_guess_frame_rate`?
        duration = total_num_frames / video_fps if video_fps else 0
        metadata = VideoMetadata(
            total_num_frames=int(total_num_frames),
            fps=float(video_fps),
            duration=float(duration),
            video_backend="pyav",
            height=stream.height,
            width=stream.width,
        )

        indices = sample_indices_fn(metadata=metadata, **kwargs)
        end_index = indices[-1]
        # Build the lookup once so that membership is not a linear scan for every decoded frame
        wanted = set(np.asarray(indices).ravel().tolist())

        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i in wanted:
                # Convert while the container is still open
                frames.append(frame.to_ndarray(format="rgb24"))

    video = np.stack(frames)
    metadata.frames_indices = indices
    return video, metadata
508
+
509
+
510
def read_video_torchvision(
    video_path: Union["URL", "Path"],
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode the video with torchvision decoder.

    The whole video is decoded first and the requested frames are selected afterwards.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that returns the indices at which the video should be sampled. It is invoked as
            `sample_indices_fn(metadata=metadata, **kwargs)`.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
            - Torch tensor of the sampled frames, decoded with `output_format="TCHW"`
              (shape: [num_frames, num_channels, height, width]).
            - `VideoMetadata` object, updated with the sampled indices and the frame height/width.
    """
    warnings.warn(
        "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
        "Please use `torchcodec` instead."
    )
    all_frames, _, info = torchvision_io.read_video(
        video_path,
        start_pts=0.0,
        end_pts=None,
        pts_unit="sec",
        output_format="TCHW",
    )
    frame_count = all_frames.size(0)
    video_fps = info["video_fps"]
    # A falsy fps would make the division meaningless, so the duration falls back to 0
    if video_fps:
        duration = frame_count / video_fps
    else:
        duration = 0

    metadata = VideoMetadata(
        total_num_frames=int(frame_count),
        fps=float(video_fps),
        duration=float(duration),
        video_backend="torchvision",
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)
    sampled = all_frames[indices].contiguous()
    # In TCHW layout, axes 2 and 3 are height and width
    decoded_info = {
        "frames_indices": indices,
        "height": sampled.shape[2],
        "width": sampled.shape[3],
    }
    metadata.update(decoded_info)
    return sampled, metadata
565
+
566
+
567
def read_video_torchcodec(
    video_path: Union["URL", "Path"],
    sample_indices_fn: Callable,
    **kwargs,
):
    """
    Decode the video with torchcodec decoder.

    Args:
        video_path (`str`):
            Path to the video file.
        sample_indices_fn (`Callable`):
            A callable that returns the indices at which the video should be sampled. It is invoked as
            `sample_indices_fn(metadata=metadata, **kwargs)`.
            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
        kwargs:
            Forwarded to `sample_indices_fn`. The optional `device` entry is additionally used as the decoding
            device and defaults to `"cpu"` when the key is absent.

    Returns:
        Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
            - Torch tensor with the sampled frames, made contiguous.
              NOTE(review): the axis layout follows the decoder's default dimension order - confirm before
              relying on channels-last vs channels-first.
            - `VideoMetadata` object.
    """
    # Lazy import torchcodec
    requires_backends(read_video_torchcodec, ["torchcodec"])
    from torchcodec.decoders import VideoDecoder

    # Falls back to "cpu" when the caller supplied no `device` kwarg
    decoding_device = kwargs.get("device", "cpu")
    decoder = VideoDecoder(
        video_path,
        # Interestingly `exact` mode takes less than approximate when we load the whole video
        seek_mode="exact",
        # Allow FFmpeg decide on the number of threads for efficiency
        num_ffmpeg_threads=0,
        device=decoding_device,
    )

    metadata = VideoMetadata(
        total_num_frames=decoder.metadata.num_frames,
        fps=decoder.metadata.average_fps,
        duration=decoder.metadata.duration_seconds,
        video_backend="torchcodec",
        height=decoder.metadata.height,
        width=decoder.metadata.width,
    )

    indices = sample_indices_fn(metadata=metadata, **kwargs)
    sampled = decoder.get_frames_at(indices=indices).data.contiguous()
    metadata.frames_indices = indices
    return sampled, metadata
620
+
621
+
622
# Maps a `backend` name to the function that decodes a video with that backend
VIDEO_DECODERS = dict(
    decord=read_video_decord,
    opencv=read_video_opencv,
    pyav=read_video_pyav,
    torchvision=read_video_torchvision,
    torchcodec=read_video_torchcodec,
)
629
+
630
+
631
def load_video(
    video: VideoInput,
    num_frames: int | None = None,
    fps: int | float | None = None,
    backend: str = "pyav",
    sample_indices_fn: Callable | None = None,
    **kwargs,
) -> tuple[Union[np.ndarray, "torch.Tensor"], Union["VideoMetadata", list[None]]]:
    """
    Loads `video` and returns the decoded frames together with their metadata.

    Args:
        video (`VideoInput`):
            The video to load. Can be a link to a video or a local path. Anything that is not a `str` is assumed
            to be already decoded and is returned unchanged.
        num_frames (`int`, *optional*):
            Number of frames to sample uniformly. If not passed, the whole video is loaded.
        fps (`int` or `float`, *optional*):
            Number of frames to sample per second. Should be passed only when `num_frames=None`.
            If not specified and `num_frames==None`, all frames are sampled.
        backend (`str`, *optional*, defaults to `"pyav"`):
            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
        sample_indices_fn (`Callable`, *optional*):
            A callable that returns the indices at which the video should be sampled. Provide it when the video
            has to be sampled with a different technique than the one offered by `num_frames` or `fps`.
            If not provided, the default sampling driven by `num_frames`/`fps` is performed, otherwise
            `sample_indices_fn` has priority over the other args.
            The function receives the video metadata along with all `kwargs` passed to `load_video` and should
            output valid indices at which the video should be sampled. For example:

            Example:
            def sample_indices_fn(metadata, **kwargs):
                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

    Returns:
        tuple: A tuple containing:
            - The decoded frames, as returned by the selected backend (or `video` itself if it was not a `str`).
            - The metadata returned by the selected backend (or a list of `None`, one per element of `video`,
              if it was not a `str`).

    Raises:
        ValueError: If both `num_frames` and `fps` are given without a custom `sample_indices_fn`, or if the
            `opencv` backend is requested for a URL.
        ImportError: If the library required by `backend` (or `yt_dlp` for YouTube links) is not installed.
        KeyError: If `backend` is not one of the registered decoders.
        TypeError: If `video` is a string that is neither a URL nor an existing local file.
    """

    # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
    if fps is not None and num_frames is not None and sample_indices_fn is None:
        raise ValueError(
            "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
        )

    # If user didn't pass a sampling function, create one on the fly with default logic
    if sample_indices_fn is None:

        def sample_indices_fn_func(metadata, **fn_kwargs):
            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)

        sample_indices_fn = sample_indices_fn_func

    # Early exit if provided an array or `PIL` frames
    if not isinstance(video, str):
        metadata = [None] * len(video)
        return video, metadata

    # Validate the backend BEFORE fetching anything, so that a bad configuration does not cost a full download
    video_is_url = video.startswith(("http://", "https://"))
    if video_is_url and backend == "opencv":
        raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")

    if (
        (not is_decord_available() and backend == "decord")
        or (not is_av_available() and backend == "pyav")
        or (not is_cv2_available() and backend == "opencv")
        or (not is_torchvision_available() and backend == "torchvision")
        or (not is_torchcodec_available() and backend == "torchcodec")
    ):
        raise ImportError(
            f"You chose backend={backend} for loading the video but the required library is not found in your environment "
            f"Make sure to install {backend} before loading the video."
        )

    video_decoder = VIDEO_DECODERS[backend]

    if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
        if not is_yt_dlp_available():
            raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
        # Lazy import from yt_dlp
        requires_backends(load_video, ["yt_dlp"])
        from yt_dlp import YoutubeDL

        buffer = BytesIO()
        with redirect_stdout(buffer), YoutubeDL() as f:
            f.download([video])
        bytes_obj = buffer.getvalue()
        file_obj = BytesIO(bytes_obj)
    elif video_is_url:
        response = httpx.get(video, follow_redirects=True)
        # Surface HTTP errors here instead of handing an error page to the decoder as if it were video bytes
        response.raise_for_status()
        file_obj = BytesIO(response.content)
    elif os.path.isfile(video):
        file_obj = video
    else:
        raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")

    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
    return video, metadata
728
+
729
+
730
def convert_to_rgb(
    video: np.ndarray,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.

    The video is always converted to channels-first layout before being processed, and is returned in that
    layout.

    Args:
        video (`np.ndarray`):
            The video to convert.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input video. If unset, will use the inferred format from the input.

    Returns:
        `np.ndarray`: The channels-first video. A grayscale input is repeated over 3 channels, and an input
        with transparency is blended onto a white background, giving a 3-channel floating point array.

    Raises:
        TypeError: If `video` is not a numpy array.
    """
    if not isinstance(video, np.ndarray):
        raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")

    # np.array usually comes with ChannelDimension.LAST so let's convert it
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(video)
    video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)

    # 3 channels for RGB already
    if video.shape[-3] == 3:
        return video

    # Grayscale video so we repeat it 3 times for each channel
    if video.shape[-3] == 1:
        return video.repeat(3, -3)

    # NOTE(review): a fully opaque video is returned as is, i.e. still carrying its alpha channel
    if not (video[..., 3, :, :] < 255).any():
        return video

    # There is a transparency layer, blend it with a white background.
    # Keep the channel axis on alpha (slice `3:4`) so that it broadcasts against the three colour planes.
    alpha = video[..., 3:4, :, :] / 255.0
    # Blend the RGB planes (`:3`), not the alpha plane itself, onto white (255)
    video = (1 - alpha) * 255 + alpha * video[..., :3, :, :]
    return video
767
+
768
+
769
def pad(
    video: np.ndarray,
    padding: int | tuple[int, int] | Iterable[tuple[int, int]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: float | Iterable[float] = 0.0,
    data_format: str | ChannelDimension | None = None,
    input_data_format: str | ChannelDimension | None = None,
) -> np.ndarray:
    """
    Pads the `video` with the specified (height, width) `padding` and `mode`.

    Only the height and width axes are padded; the frame, channel and (if present) batch axes are left untouched.

    Args:
        video (`np.ndarray`):
            The video to pad.
        padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of three formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` or `(before, after)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`. Accepts the same formats as `padding`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output video. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
            If unset, will use same as the input video.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input video. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
            If unset, will use the inferred format of the input video.

    Returns:
        `np.ndarray`: The padded video.

    Raises:
        ValueError: If `padding`/`constant_values` is in an unsupported format or `mode` is not a known padding mode.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(video)

    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        elif isinstance(values, tuple) and len(values) == 1:
            single = values[0]
            # `((before, after),)` carries one pair shared by both axes, while `(pad,)` carries one scalar
            values = (single, single) if isinstance(single, tuple) else ((single, single), (single, single))
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], (int, float)):
            values = (values, values)
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            pass
        else:
            raise ValueError(f"Unsupported format: {values}")

        # add 0 for the frame and channel dimensions
        values = (
            ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
        )

        # Add additional padding if there's a batch dimension. np.pad needs a (before, after) pair per axis.
        values = ((0, 0), *values) if video.ndim == 5 else values
        return values

    padding_map = {
        PaddingMode.CONSTANT: "constant",
        PaddingMode.REFLECT: "reflect",
        # numpy calls the "repeat the edge value" mode `edge`; it has no mode named `replicate`
        PaddingMode.REPLICATE: "edge",
        PaddingMode.SYMMETRIC: "symmetric",
    }
    padding = _expand_for_data_format(padding)

    pad_kwargs = {}
    if mode not in padding_map:
        raise ValueError(f"Invalid padding mode: {mode}")
    elif mode == PaddingMode.CONSTANT:
        pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)

    video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
    video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
    return video
856
+
857
+
858
def group_videos_by_shape(
    videos: list["torch.Tensor"],
) -> tuple[dict[tuple[int, int], "torch.Tensor"], dict[int, tuple[tuple[int, int], int]]]:
    """
    Groups videos that share the same number of frames, height and width.

    Returns:
        A tuple of two dictionaries:
            - one mapping each `(num_frames, height, width)` key to the videos with that shape, stacked along a
              new leading dimension;
            - one mapping the position of every video in `videos` to its `(key, position_inside_the_stack)`.
    """
    buckets = {}
    placement = {}
    for position, clip in enumerate(videos):
        # assumes a (..., T, C, H, W) layout: axis -4 is the frame count, the last two are height and width
        key = (clip.shape[-4], *clip.shape[-2:])
        bucket = buckets.setdefault(key, [])
        # The clip is about to be appended, so its slot is the current length of the bucket
        placement[position] = (key, len(bucket))
        bucket.append(clip)

    # stack videos with the same size and number of frames
    stacked = {key: torch.stack(bucket, dim=0) for key, bucket in buckets.items()}
    return stacked, placement
881
+
882
+
883
def reorder_videos(
    processed_videos: dict[tuple[int, int], "torch.Tensor"],
    grouped_videos_index: dict[int, tuple[tuple[int, int], int]],
) -> list["torch.Tensor"]:
    """
    Rebuilds the list of videos in their original order.

    Args:
        processed_videos:
            Mapping from a shape key to the collection of videos grouped under that key.
        grouped_videos_index:
            Mapping from an original position (`0 .. len - 1`) to `(shape key, position inside the group)`.

    Returns:
        The videos picked out of `processed_videos`, ordered by original position.
    """
    restored = []
    for position in range(len(grouped_videos_index)):
        location = grouped_videos_index[position]
        group = processed_videos[location[0]]
        restored.append(group[location[1]])
    return restored