Add files using upload-large-folder tool
Browse files- LTA_openwebtext_dualt/logs/debug_2k_stream1024_fc_mask1_4gpu/debug_2k_stream1024_fc_mask1_4gpu_now_20260517_125945.log +147 -0
- LTA_openwebtext_dualt/logs/infer/lta_owt_lm1bclassic_fullvocab_bert_c1024_len1024_elfLdim_d1280_l32_h16_ff5120_lr3e-4_gbs512_2node8gpu_1m_save10k_t-20260522071024-s2ss5_latest_step0030000_shard01_gpu1_b16.log +18 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/activations.py +369 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cache_utils.py +1623 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/configuration_utils.py +1365 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/file_utils.py +105 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/fusion_mapping.py +270 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_backends.py +689 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_utils.py +688 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_utils.py +1069 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/masking_utils.py +1514 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py +503 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/blt/configuration_blt.py +286 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py +27 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py +830 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/__init__.py +29 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/image_processing_pil_vitmatte.py +159 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/optimization.py +1342 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/tokenization_python.py +1420 -0
- LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/video_utils.py +893 -0
LTA_openwebtext_dualt/logs/debug_2k_stream1024_fc_mask1_4gpu/debug_2k_stream1024_fc_mask1_4gpu_now_20260517_125945.log
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NCCL version 2.25.1+cuda12.8
|
| 2 |
+
{
|
| 3 |
+
"device": "cuda:0",
|
| 4 |
+
"rank": 0,
|
| 5 |
+
"world_size": 4,
|
| 6 |
+
"samples": "tokenized_hf:13425484:pad=0",
|
| 7 |
+
"vocab_size": 2048,
|
| 8 |
+
"tokenizer_vocab_size": 2048,
|
| 9 |
+
"save_dir": "runs/debug_2k_stream1024_fc_mask1_4gpu_now_20260517_125945",
|
| 10 |
+
"batch_size": 32,
|
| 11 |
+
"grad_accum": 2,
|
| 12 |
+
"effective_batch_size": 256,
|
| 13 |
+
"global_batch_size": 256,
|
| 14 |
+
"lr_schedule": "cosine",
|
| 15 |
+
"optimizer": "adamw",
|
| 16 |
+
"epochs": 0.0,
|
| 17 |
+
"steps_per_epoch": 52443,
|
| 18 |
+
"total_steps": 1,
|
| 19 |
+
"warmup_steps": 26222,
|
| 20 |
+
"warmup_epochs": 0.5,
|
| 21 |
+
"min_lr": 6e-05,
|
| 22 |
+
"weight_decay": 0.1,
|
| 23 |
+
"output_weight_decay": -1.0,
|
| 24 |
+
"adamw_param_groups": "nanogpt",
|
| 25 |
+
"adam_beta1": 0.9,
|
| 26 |
+
"adam_beta2": 0.95,
|
| 27 |
+
"adam_eps": 1e-08,
|
| 28 |
+
"muon_impl": "legacy",
|
| 29 |
+
"muon_momentum": 0.95,
|
| 30 |
+
"muon_ns_steps": 5,
|
| 31 |
+
"muon_update_scale": 1.0,
|
| 32 |
+
"muon_nesterov": false,
|
| 33 |
+
"muon_width_scale": false,
|
| 34 |
+
"muon_grouping": "",
|
| 35 |
+
"muon_param_count": 0,
|
| 36 |
+
"muon_adam_param_count": 0,
|
| 37 |
+
"muon_param_names": [],
|
| 38 |
+
"muon_adam_param_names": [],
|
| 39 |
+
"muon_effective_nesterov": false,
|
| 40 |
+
"muon_effective_width_scale": false,
|
| 41 |
+
"muon_effective_weight_decay": 0.1,
|
| 42 |
+
"muon_adam_fallback_nesterov": false,
|
| 43 |
+
"muon_adam_fallback_weight_decay": 0.1,
|
| 44 |
+
"ema_decay": 0.0,
|
| 45 |
+
"ema_start_step": 0,
|
| 46 |
+
"model_type": "ddit",
|
| 47 |
+
"ddit_mlp_type": "gelu",
|
| 48 |
+
"elf_num_time_tokens": 4,
|
| 49 |
+
"elf_num_model_mode_tokens": 0,
|
| 50 |
+
"qk_norm": true,
|
| 51 |
+
"output_bias": false,
|
| 52 |
+
"output_init_std": -1.0,
|
| 53 |
+
"norm_type": "rmsnorm",
|
| 54 |
+
"target_loss": "hard_ce",
|
| 55 |
+
"linear_soft_target_power": 1.0,
|
| 56 |
+
"linear_soft_target_min_conf": 0.0,
|
| 57 |
+
"linear_soft_target_max_conf": 1.0,
|
| 58 |
+
"t_sampling_mode": "logit_normal",
|
| 59 |
+
"t_sampling_power": 1.0,
|
| 60 |
+
"t_sampling_eps": 0.0001,
|
| 61 |
+
"t_sampling_logit_mean": -1.5,
|
| 62 |
+
"t_sampling_logit_std": 0.8,
|
| 63 |
+
"dual_t": true,
|
| 64 |
+
"corrupt_t_mode": "same",
|
| 65 |
+
"corrupt_min_t": 0.0,
|
| 66 |
+
"corrupt_max_t": 1.0,
|
| 67 |
+
"prefix_block_prob": 0.0,
|
| 68 |
+
"prefix_block_len": 128,
|
| 69 |
+
"mask_ratio_floor_schedule": "none",
|
| 70 |
+
"dirichlet_endpoint_mode": "categorical_dual_t",
|
| 71 |
+
"dirichlet_semantic_t_mode": "same",
|
| 72 |
+
"dirichlet_semantic_t_value": 0.0,
|
| 73 |
+
"dirichlet_semantic_t_curve": "linear",
|
| 74 |
+
"dirichlet_semantic_t_power": 1.0,
|
| 75 |
+
"endpoint_sequence_random_prob_alpha": 0.0,
|
| 76 |
+
"categorical_wrong_from_full_vocab": true,
|
| 77 |
+
"categorical_wrong_from_batch_valid_tokens": false,
|
| 78 |
+
"categorical_wrong_basin_token_ids": "",
|
| 79 |
+
"categorical_wrong_basin_prob": 0.0,
|
| 80 |
+
"categorical_wrong_unigram_prob": 0.0,
|
| 81 |
+
"categorical_wrong_uniform_prob": 0.0,
|
| 82 |
+
"categorical_wrong_corpus_unigram_path": "",
|
| 83 |
+
"categorical_wrong_corpus_unigram_alpha": 1.0,
|
| 84 |
+
"categorical_wrong_basin_shared_prob": 0.0,
|
| 85 |
+
"categorical_wrong_unigram_shared_prob": 0.0,
|
| 86 |
+
"mask_mixture_original_prob": 0.0,
|
| 87 |
+
"mask_mixture_lowk_prob": 0.0,
|
| 88 |
+
"mask_mixture_lowcorrupt_prob": 0.0,
|
| 89 |
+
"mask_mixture_block_prob": 0.0,
|
| 90 |
+
"mask_mixture_all_prob": 0.0,
|
| 91 |
+
"mask_mixture_lowk_clean_tokens": "1,2,4,8,16,32,64",
|
| 92 |
+
"mask_mixture_lowcorrupt_tokens": "1,2,4,8,16,32,64",
|
| 93 |
+
"mask_mixture_block_tokens": "64,128",
|
| 94 |
+
"simplex_bridge_sampler": "dirichlet",
|
| 95 |
+
"logistic_normal_sigma_min": 0.18,
|
| 96 |
+
"logistic_normal_sigma_max": 2.2,
|
| 97 |
+
"logistic_normal_tau_min": 0.65,
|
| 98 |
+
"logistic_normal_tau_max": 1.15,
|
| 99 |
+
"torch_compile": false,
|
| 100 |
+
"compile_mode": "max-autotune",
|
| 101 |
+
"state_format": "prob",
|
| 102 |
+
"meanflow_weight": 0.0,
|
| 103 |
+
"rollout_train_prob": 0.0,
|
| 104 |
+
"rollout_train_steps": 1,
|
| 105 |
+
"rollout_train_infer_steps": 64,
|
| 106 |
+
"rollout_train_temp": 1.45,
|
| 107 |
+
"rollout_train_max_gamma": 1.0,
|
| 108 |
+
"rollout_train_corrupt_only": true,
|
| 109 |
+
"rollout_train_samplewise": false,
|
| 110 |
+
"rollout_train_compute_always": false,
|
| 111 |
+
"bridge_noise_init": "logistic_normal",
|
| 112 |
+
"noise_sigma": -1.0,
|
| 113 |
+
"allow_tf32": true,
|
| 114 |
+
"activation_checkpointing": false,
|
| 115 |
+
"activation_checkpoint_interval": 1,
|
| 116 |
+
"activation_checkpoint_scope": "block",
|
| 117 |
+
"ddp_static_graph": false,
|
| 118 |
+
"ddp_gradient_as_bucket_view": true,
|
| 119 |
+
"blocking_data_transfer": false,
|
| 120 |
+
"dataloader_prefetch_factor": 2,
|
| 121 |
+
"full_train_stats": false,
|
| 122 |
+
"tokenized_hf": true,
|
| 123 |
+
"tokenized_pad_token": "pad",
|
| 124 |
+
"elf_conditional_hf": false,
|
| 125 |
+
"record_pad_truncate": false,
|
| 126 |
+
"record_add_eos": false,
|
| 127 |
+
"record_add_special_tokens": false,
|
| 128 |
+
"record_pad_token": "pad",
|
| 129 |
+
"record_shuffle_buffer": 10000,
|
| 130 |
+
"wrap": false,
|
| 131 |
+
"wrap_mode": "stream",
|
| 132 |
+
"wrap_record_buffer_size": 200,
|
| 133 |
+
"owt_cached_chunks": false,
|
| 134 |
+
"owt_chunk_cache_dir": "",
|
| 135 |
+
"owt_chunk_cache_rebuild": false,
|
| 136 |
+
"owt_chunk_cache_write_batch": 4096,
|
| 137 |
+
"owt_exact_repeat_per_chunk": 0,
|
| 138 |
+
"online_chunk_shuffle": false,
|
| 139 |
+
"online_chunk_shuffle_buffer": 10000,
|
| 140 |
+
"openwebtext_split": "all",
|
| 141 |
+
"detokenizer": "auto",
|
| 142 |
+
"resolved_detokenizer": null,
|
| 143 |
+
"num_workers": 2,
|
| 144 |
+
"latest_every": 1000,
|
| 145 |
+
"resume_path": ""
|
| 146 |
+
}
|
| 147 |
+
step=1 epoch=1/1 epoch_step=1/52443 micro_steps=2 elapsed=2.6s lr=4.576310e-08 loss=7.6246 loss_recon=7.6246 loss_meanflow=0.0000 mean_model_t=0.2163 mean_corrupt_t=0.2163 mean_loss_t_weight=1.0000 linear_soft_target_mean_conf=0.0000 prior_center_loss_beta=0.0000 rollout_train_applied=0.0000 grad_enabled_before_rollout=1.0000 grad_enabled_after_rollout=1.0000 logits_requires_grad=1.0000 raw_loss_requires_grad=1.0000 acc_all=0.0000 corrupt_frac=1.0000 acc_corrupt=0.0000 loss_corrupt=7.6246 wrong_frac=0.7833 init_acc_corrupt=0.1268 acc_corrupt_t_0p0_0p2=0.0000 corrupt_frac_t_0p0_0p2=0.5469 acc_corrupt_t_0p2_0p4=0.0000 corrupt_frac_t_0p2_0p4=0.3750 acc_corrupt_t_0p4_0p6=0.0000 corrupt_frac_t_0p4_0p6=0.0625 acc_corrupt_t_0p6_0p8=0.0000 corrupt_frac_t_0p6_0p8=0.0312 out_w_norm=0.0000 out_g_norm=0.2113 loss_all=7.6246 init_gold_top10=0.1954 init_gold_top100=0.2858
|
LTA_openwebtext_dualt/logs/infer/lta_owt_lm1bclassic_fullvocab_bert_c1024_len1024_elfLdim_d1280_l32_h16_ff5120_lr3e-4_gbs512_2node8gpu_1m_save10k_t-20260522071024-s2ss5_latest_step0030000_shard01_gpu1_b16.log
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[ckpt] runs/lta_owt_lm1bclassic_fullvocab_bert_c1024_len1024_elfLdim_d1280_l32_h16_ff5120_lr3e-4_gbs512_2node8gpu_1m_save10k_t-20260522071024-s2ss5/latest.pt step=30000
|
| 2 |
+
[decode] steps128_c1024_t1p45 generated 16/256
|
| 3 |
+
[decode] steps128_c1024_t1p45 generated 32/256
|
| 4 |
+
[decode] steps128_c1024_t1p45 generated 48/256
|
| 5 |
+
[decode] steps128_c1024_t1p45 generated 64/256
|
| 6 |
+
[decode] steps128_c1024_t1p45 generated 80/256
|
| 7 |
+
[decode] steps128_c1024_t1p45 generated 96/256
|
| 8 |
+
[decode] steps128_c1024_t1p45 generated 112/256
|
| 9 |
+
[decode] steps128_c1024_t1p45 generated 128/256
|
| 10 |
+
[decode] steps128_c1024_t1p45 generated 144/256
|
| 11 |
+
[decode] steps128_c1024_t1p45 generated 160/256
|
| 12 |
+
[decode] steps128_c1024_t1p45 generated 176/256
|
| 13 |
+
[decode] steps128_c1024_t1p45 generated 192/256
|
| 14 |
+
[decode] steps128_c1024_t1p45 generated 208/256
|
| 15 |
+
[decode] steps128_c1024_t1p45 generated 224/256
|
| 16 |
+
[decode] steps128_c1024_t1p45 generated 240/256
|
| 17 |
+
[decode] steps128_c1024_t1p45 generated 256/256
|
| 18 |
+
[summary] {"name": "steps128_c1024_t1p45", "step": 30000, "decode_steps": 128, "concentration_max": 1024.0, "raw_genppl": 15.432598193356394, "stripped_genppl": 15.314571366003578, "sample_entropy": 3.169319289650098, "distinct_1": 0.005100250244140625, "distinct_2": 0.10675937805474096, "top_token_mass": 0.2738838195800781, "raw_kept": 256, "stripped_kept": 256}
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/activations.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import functools
|
| 16 |
+
import math
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
from .integrations.hub_kernels import use_kernel_forward_from_hub
|
| 23 |
+
from .utils import logging
|
| 24 |
+
from .utils.import_utils import is_torchdynamo_compiling
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@use_kernel_forward_from_hub("GeluTanh")
|
| 31 |
+
class GELUTanh(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
A fast C implementation of the tanh approximation of the GeLU activation function. See
|
| 34 |
+
https://huggingface.co/papers/1606.08415.
|
| 35 |
+
|
| 36 |
+
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
|
| 37 |
+
match due to rounding errors.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, use_gelu_tanh_python: bool = False):
|
| 41 |
+
super().__init__()
|
| 42 |
+
if use_gelu_tanh_python:
|
| 43 |
+
self.act = self._gelu_tanh_python
|
| 44 |
+
else:
|
| 45 |
+
self.act = functools.partial(nn.functional.gelu, approximate="tanh")
|
| 46 |
+
|
| 47 |
+
def _gelu_tanh_python(self, input: Tensor) -> Tensor:
|
| 48 |
+
return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
| 49 |
+
|
| 50 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 51 |
+
return self.act(input)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Added for compatibility with autoawq which is archived now and imports PytorchGELUTanh from activations.py
|
| 55 |
+
PytorchGELUTanh = GELUTanh
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@use_kernel_forward_from_hub("NewGELU")
|
| 59 |
+
class NewGELUActivation(nn.Module):
|
| 60 |
+
"""
|
| 61 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
| 62 |
+
the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 66 |
+
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@use_kernel_forward_from_hub("GeLU")
|
| 70 |
+
class GELUActivation(nn.Module):
|
| 71 |
+
"""
|
| 72 |
+
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
| 73 |
+
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
| 74 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
| 75 |
+
Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, use_gelu_python: bool = False):
|
| 79 |
+
super().__init__()
|
| 80 |
+
if use_gelu_python:
|
| 81 |
+
self.act = self._gelu_python
|
| 82 |
+
else:
|
| 83 |
+
self.act = nn.functional.gelu
|
| 84 |
+
|
| 85 |
+
def _gelu_python(self, input: Tensor) -> Tensor:
|
| 86 |
+
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
|
| 87 |
+
|
| 88 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 89 |
+
return self.act(input)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@use_kernel_forward_from_hub("SiLU")
|
| 93 |
+
class SiLUActivation(nn.Module):
|
| 94 |
+
"""
|
| 95 |
+
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
|
| 96 |
+
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
|
| 97 |
+
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
|
| 98 |
+
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
|
| 99 |
+
later.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 103 |
+
return nn.functional.silu(input)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@use_kernel_forward_from_hub("FastGELU")
|
| 107 |
+
class FastGELUActivation(nn.Module):
|
| 108 |
+
"""
|
| 109 |
+
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 113 |
+
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@use_kernel_forward_from_hub("QuickGELU")
|
| 117 |
+
class QuickGELUActivation(nn.Module):
|
| 118 |
+
"""
|
| 119 |
+
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 123 |
+
return input * torch.sigmoid(1.702 * input)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ClippedGELUActivation(nn.Module):
|
| 127 |
+
"""
|
| 128 |
+
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
|
| 129 |
+
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
|
| 130 |
+
https://huggingface.co/papers/2004.09602.
|
| 131 |
+
|
| 132 |
+
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
|
| 133 |
+
initially created.
|
| 134 |
+
|
| 135 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
| 136 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self, min: float, max: float):
|
| 140 |
+
if min > max:
|
| 141 |
+
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
|
| 142 |
+
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.min = min
|
| 145 |
+
self.max = max
|
| 146 |
+
|
| 147 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 148 |
+
return torch.clip(gelu(x), self.min, self.max)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class AccurateGELUActivation(nn.Module):
|
| 152 |
+
"""
|
| 153 |
+
Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
|
| 154 |
+
https://github.com/hendrycks/GELUs
|
| 155 |
+
|
| 156 |
+
Implemented along with MEGA (Moving Average Equipped Gated Attention)
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def __init__(self):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.precomputed_constant = math.sqrt(2 / math.pi)
|
| 162 |
+
|
| 163 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 164 |
+
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class MishActivation(nn.Module):
|
| 168 |
+
"""
|
| 169 |
+
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also
|
| 170 |
+
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
def __init__(self):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.act = nn.functional.mish
|
| 176 |
+
|
| 177 |
+
def _mish_python(self, input: Tensor) -> Tensor:
|
| 178 |
+
return input * torch.tanh(nn.functional.softplus(input))
|
| 179 |
+
|
| 180 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 181 |
+
return self.act(input)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class LinearActivation(nn.Module):
|
| 185 |
+
"""
|
| 186 |
+
Applies the linear activation function, i.e. forwarding input directly to output.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 190 |
+
return input
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class LaplaceActivation(nn.Module):
|
| 194 |
+
"""
|
| 195 |
+
Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
|
| 196 |
+
https://huggingface.co/papers/2209.10655
|
| 197 |
+
|
| 198 |
+
Inspired by squared relu, but with bounded range and gradient for better stability
|
| 199 |
+
"""
|
| 200 |
+
|
| 201 |
+
def forward(self, input, mu=0.707107, sigma=0.282095):
|
| 202 |
+
input = (input - mu).div(sigma * math.sqrt(2.0))
|
| 203 |
+
return 0.5 * (1.0 + torch.erf(input))
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class ReLUSquaredActivation(nn.Module):
|
| 207 |
+
"""
|
| 208 |
+
Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
def forward(self, input):
|
| 212 |
+
relu_applied = nn.functional.relu(input)
|
| 213 |
+
squared = torch.square(relu_applied)
|
| 214 |
+
return squared
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class SqrtSoftplusActivation(nn.Module):
|
| 218 |
+
"""sqrt(softplus(x)) — the router scoring function used by DeepSeek V4."""
|
| 219 |
+
|
| 220 |
+
def forward(self, input):
|
| 221 |
+
return nn.functional.softplus(input).sqrt()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class ClassInstantier(OrderedDict):
|
| 225 |
+
def __getitem__(self, key):
|
| 226 |
+
content = super().__getitem__(key)
|
| 227 |
+
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
| 228 |
+
return cls(**kwargs)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class XIELUActivation(nn.Module):
|
| 232 |
+
"""
|
| 233 |
+
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
|
| 234 |
+
|
| 235 |
+
If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA
|
| 236 |
+
Otherwise, we emit a single warning and use xIELU Python
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
def __init__(
|
| 240 |
+
self,
|
| 241 |
+
alpha_p_init=0.8,
|
| 242 |
+
alpha_n_init=0.8,
|
| 243 |
+
beta=0.5,
|
| 244 |
+
eps=-1e-6,
|
| 245 |
+
dtype=torch.bfloat16,
|
| 246 |
+
with_vector_loads=False,
|
| 247 |
+
):
|
| 248 |
+
super().__init__()
|
| 249 |
+
self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0))
|
| 250 |
+
self.alpha_n = nn.Parameter(
|
| 251 |
+
torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0)
|
| 252 |
+
)
|
| 253 |
+
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
|
| 254 |
+
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
|
| 255 |
+
self.with_vector_loads = with_vector_loads
|
| 256 |
+
# Temporary until xIELU CUDA fully implemented
|
| 257 |
+
self._beta_scalar = float(beta)
|
| 258 |
+
self._eps_scalar = float(eps)
|
| 259 |
+
|
| 260 |
+
self._xielu_cuda_obj = None
|
| 261 |
+
try:
|
| 262 |
+
import xielu.ops # noqa: F401
|
| 263 |
+
|
| 264 |
+
self._xielu_cuda_obj = torch.classes.xielu.XIELU()
|
| 265 |
+
msg = "Using experimental xIELU CUDA."
|
| 266 |
+
try:
|
| 267 |
+
from torch.compiler import allow_in_graph
|
| 268 |
+
|
| 269 |
+
self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
|
| 270 |
+
msg += " Enabled torch._dynamo for xIELU CUDA."
|
| 271 |
+
except Exception as err:
|
| 272 |
+
msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
|
| 273 |
+
self._xielu_cuda_fn = self._xielu_cuda
|
| 274 |
+
logger.warning_once(msg)
|
| 275 |
+
except Exception as err:
|
| 276 |
+
logger.warning_once(
|
| 277 |
+
f"CUDA-fused xIELU not available ({err}) – falling back to a Python version.\n"
|
| 278 |
+
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`"
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def _xielu_python(self, x: Tensor) -> Tensor:
|
| 282 |
+
alpha_p = nn.functional.softplus(self.alpha_p)
|
| 283 |
+
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
|
| 284 |
+
return torch.where(
|
| 285 |
+
x > 0,
|
| 286 |
+
alpha_p * x * x + self.beta * x,
|
| 287 |
+
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
def _xielu_cuda(self, x: Tensor) -> Tensor:
|
| 291 |
+
"""Firewall function to prevent torch.compile from seeing .item() calls"""
|
| 292 |
+
original_shape = x.shape
|
| 293 |
+
# CUDA kernel expects 3D tensors, reshape if needed
|
| 294 |
+
while x.dim() < 3:
|
| 295 |
+
x = x.unsqueeze(0)
|
| 296 |
+
if x.dim() > 3:
|
| 297 |
+
x = x.view(-1, 1, x.size(-1))
|
| 298 |
+
if original_shape != x.shape:
|
| 299 |
+
logger.warning_once(
|
| 300 |
+
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
|
| 301 |
+
original_shape,
|
| 302 |
+
x.shape,
|
| 303 |
+
)
|
| 304 |
+
result = self._xielu_cuda_obj.forward(
|
| 305 |
+
x,
|
| 306 |
+
self.alpha_p.to(x.dtype),
|
| 307 |
+
self.alpha_n.to(x.dtype),
|
| 308 |
+
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
|
| 309 |
+
self._beta_scalar,
|
| 310 |
+
self._eps_scalar,
|
| 311 |
+
self.with_vector_loads,
|
| 312 |
+
)
|
| 313 |
+
return result.view(original_shape)
|
| 314 |
+
|
| 315 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 316 |
+
if self._xielu_cuda_obj is not None and input.is_cuda:
|
| 317 |
+
if not is_torchdynamo_compiling():
|
| 318 |
+
return self._xielu_cuda_fn(input)
|
| 319 |
+
else:
|
| 320 |
+
logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
|
| 321 |
+
return self._xielu_python(input)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
ACT2CLS = {
|
| 325 |
+
"gelu": GELUActivation,
|
| 326 |
+
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
|
| 327 |
+
"gelu_fast": FastGELUActivation,
|
| 328 |
+
"gelu_new": NewGELUActivation,
|
| 329 |
+
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
|
| 330 |
+
"gelu_pytorch_tanh": GELUTanh,
|
| 331 |
+
"gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
|
| 332 |
+
"gelu_accurate": AccurateGELUActivation,
|
| 333 |
+
"hardswish": nn.Hardswish,
|
| 334 |
+
"laplace": LaplaceActivation,
|
| 335 |
+
"leaky_relu": nn.LeakyReLU,
|
| 336 |
+
"linear": LinearActivation,
|
| 337 |
+
"mish": MishActivation,
|
| 338 |
+
"quick_gelu": QuickGELUActivation,
|
| 339 |
+
"relu": nn.ReLU,
|
| 340 |
+
"relu2": ReLUSquaredActivation,
|
| 341 |
+
"relu6": nn.ReLU6,
|
| 342 |
+
"sigmoid": nn.Sigmoid,
|
| 343 |
+
"silu": SiLUActivation,
|
| 344 |
+
"sqrtsoftplus": SqrtSoftplusActivation,
|
| 345 |
+
"swish": nn.SiLU,
|
| 346 |
+
"tanh": nn.Tanh,
|
| 347 |
+
"prelu": nn.PReLU,
|
| 348 |
+
"xielu": XIELUActivation,
|
| 349 |
+
}
|
| 350 |
+
ACT2FN = ClassInstantier(ACT2CLS)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def get_activation(activation_string):
|
| 354 |
+
if activation_string in ACT2FN:
|
| 355 |
+
return ACT2FN[activation_string]
|
| 356 |
+
else:
|
| 357 |
+
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# For backwards compatibility with: from activations import gelu_python
|
| 361 |
+
gelu_python = get_activation("gelu_python")
|
| 362 |
+
gelu_new = get_activation("gelu_new")
|
| 363 |
+
gelu = get_activation("gelu")
|
| 364 |
+
gelu_fast = get_activation("gelu_fast")
|
| 365 |
+
gelu_pytorch_tanh = get_activation("gelu_pytorch_tanh")
|
| 366 |
+
quick_gelu = get_activation("quick_gelu")
|
| 367 |
+
silu = get_activation("silu")
|
| 368 |
+
mish = get_activation("mish")
|
| 369 |
+
linear_act = get_activation("linear")
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/cache_utils.py
ADDED
|
@@ -0,0 +1,1623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from collections.abc import Iterable
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .configuration_utils import PreTrainedConfig
|
| 7 |
+
from .utils import (
|
| 8 |
+
is_hqq_available,
|
| 9 |
+
is_optimum_quanto_available,
|
| 10 |
+
is_quanto_greater,
|
| 11 |
+
is_torch_greater_or_equal,
|
| 12 |
+
is_torchdynamo_compiling,
|
| 13 |
+
logging,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if is_hqq_available():
|
| 18 |
+
from hqq.core.quantize import Quantizer as HQQQuantizer
|
| 19 |
+
|
| 20 |
+
_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for
|
| 27 |
+
# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided
|
| 28 |
+
# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
|
| 29 |
+
# cache-layer subclass and stop needing a model-specific ``Cache`` subclass.
|
| 30 |
+
#
|
| 31 |
+
# A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via
|
| 32 |
+
# ``CacheLayerMixin.__init_subclass__``. Each registered class must accept a
|
| 33 |
+
# ``PreTrainedConfig`` (the decoder text config) as the only positional argument.
|
| 34 |
+
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CacheLayerMixin(ABC):
|
| 38 |
+
"""Base, abstract class for a single layer's cache."""
|
| 39 |
+
|
| 40 |
+
is_compileable = False
|
| 41 |
+
# Subclasses can set ``layer_type`` to auto-register themselves in
|
| 42 |
+
# ``LAYER_TYPE_CACHE_MAPPING`` at import time (used by ``DynamicCache`` to dispatch
|
| 43 |
+
# per-layer cache classes from ``config.layer_types``).
|
| 44 |
+
layer_type: str | None = None
|
| 45 |
+
|
| 46 |
+
def __init_subclass__(cls, **kwargs):
|
| 47 |
+
super().__init_subclass__(**kwargs)
|
| 48 |
+
layer_type = cls.__dict__.get("layer_type", None)
|
| 49 |
+
if layer_type is not None:
|
| 50 |
+
LAYER_TYPE_CACHE_MAPPING[layer_type] = cls
|
| 51 |
+
|
| 52 |
+
def __init__(self):
|
| 53 |
+
self.keys: torch.Tensor | None = None
|
| 54 |
+
self.values: torch.Tensor | None = None
|
| 55 |
+
self.is_initialized = False
|
| 56 |
+
|
| 57 |
+
def __repr__(self):
|
| 58 |
+
return f"{self.__class__.__name__}"
|
| 59 |
+
|
| 60 |
+
@abstractmethod
|
| 61 |
+
def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...
|
| 62 |
+
|
| 63 |
+
@abstractmethod
|
| 64 |
+
def update(
|
| 65 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
|
| 66 |
+
) -> tuple[torch.Tensor, torch.Tensor]: ...
|
| 67 |
+
|
| 68 |
+
@abstractmethod
|
| 69 |
+
def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...
|
| 70 |
+
|
| 71 |
+
@abstractmethod
|
| 72 |
+
def get_seq_length(self) -> int: ...
|
| 73 |
+
|
| 74 |
+
@abstractmethod
|
| 75 |
+
def get_max_cache_shape(self) -> int: ...
|
| 76 |
+
|
| 77 |
+
def offload(self):
|
| 78 |
+
"""Offload this layer's data to CPU device."""
|
| 79 |
+
if self.is_initialized:
|
| 80 |
+
self.keys = self.keys.to("cpu", non_blocking=True)
|
| 81 |
+
self.values = self.values.to("cpu", non_blocking=True)
|
| 82 |
+
|
| 83 |
+
def prefetch(self):
|
| 84 |
+
"""In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
|
| 85 |
+
if self.is_initialized and self.keys.device != self.device:
|
| 86 |
+
self.keys = self.keys.to(self.device, non_blocking=True)
|
| 87 |
+
self.values = self.values.to(self.device, non_blocking=True)
|
| 88 |
+
|
| 89 |
+
def reset(self) -> None:
|
| 90 |
+
"""Resets the cache values while preserving the objects"""
|
| 91 |
+
if self.is_initialized:
|
| 92 |
+
self.keys.zero_()
|
| 93 |
+
self.values.zero_()
|
| 94 |
+
# This attribute is set on several Layers
|
| 95 |
+
if hasattr(self, "cumulative_length"):
|
| 96 |
+
# It can either be an int for dynamic layers, or a tensor for static layers
|
| 97 |
+
if isinstance(self.cumulative_length, int):
|
| 98 |
+
self.cumulative_length = 0
|
| 99 |
+
else:
|
| 100 |
+
self.cumulative_length.zero_()
|
| 101 |
+
|
| 102 |
+
def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
|
| 103 |
+
"""Reorders this layer's cache for beam search."""
|
| 104 |
+
if self.get_seq_length() > 0:
|
| 105 |
+
self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
|
| 106 |
+
self.values = self.values.index_select(0, beam_idx.to(self.values.device))
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class DynamicLayer(CacheLayerMixin):
|
| 110 |
+
"""
|
| 111 |
+
A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
|
| 112 |
+
It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
is_sliding = False
|
| 116 |
+
|
| 117 |
+
def __init__(self, config: PreTrainedConfig | None = None):
|
| 118 |
+
super().__init__()
|
| 119 |
+
|
| 120 |
+
def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
| 121 |
+
self.dtype, self.device = key_states.dtype, key_states.device
|
| 122 |
+
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
|
| 123 |
+
self.values = torch.tensor([], dtype=self.dtype, device=self.device)
|
| 124 |
+
self.is_initialized = True
|
| 125 |
+
|
| 126 |
+
def update(
|
| 127 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
|
| 128 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 129 |
+
"""
|
| 130 |
+
Update the key and value caches in-place, and return the necessary keys and value states.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
key_states (`torch.Tensor`): The new key states to cache.
|
| 134 |
+
value_states (`torch.Tensor`): The new value states to cache.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
| 138 |
+
"""
|
| 139 |
+
# Lazy initialization
|
| 140 |
+
if not self.is_initialized:
|
| 141 |
+
self.lazy_initialization(key_states, value_states)
|
| 142 |
+
|
| 143 |
+
self.keys = torch.cat([self.keys, key_states], dim=-2)
|
| 144 |
+
self.values = torch.cat([self.values, value_states], dim=-2)
|
| 145 |
+
return self.keys, self.values
|
| 146 |
+
|
| 147 |
+
def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
|
| 148 |
+
"""Return the length and offset of the cache, used to generate the mask"""
|
| 149 |
+
kv_offset = 0
|
| 150 |
+
kv_length = self.get_seq_length() + query_length
|
| 151 |
+
return kv_length, kv_offset
|
| 152 |
+
|
| 153 |
+
def get_seq_length(self) -> int:
|
| 154 |
+
"""Returns the sequence length of the cached states."""
|
| 155 |
+
if not self.is_initialized or self.keys.numel() == 0:
|
| 156 |
+
return 0
|
| 157 |
+
return self.keys.shape[-2]
|
| 158 |
+
|
| 159 |
+
def get_max_cache_shape(self) -> int:
|
| 160 |
+
"""Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
|
| 161 |
+
return -1
|
| 162 |
+
|
| 163 |
+
def crop(self, max_length: int) -> None:
|
| 164 |
+
"""
|
| 165 |
+
Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
|
| 166 |
+
to remove `max_length` tokens.
|
| 167 |
+
"""
|
| 168 |
+
if max_length < 0:
|
| 169 |
+
max_length = self.get_seq_length() - abs(max_length)
|
| 170 |
+
|
| 171 |
+
if self.get_seq_length() <= max_length:
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
self.keys = self.keys[..., :max_length, :]
|
| 175 |
+
self.values = self.values[..., :max_length, :]
|
| 176 |
+
|
| 177 |
+
def batch_repeat_interleave(self, repeats: int) -> None:
|
| 178 |
+
"""Repeat the cache `repeats` times in the batch dimension."""
|
| 179 |
+
if self.get_seq_length() > 0:
|
| 180 |
+
self.keys = self.keys.repeat_interleave(repeats, dim=0)
|
| 181 |
+
self.values = self.values.repeat_interleave(repeats, dim=0)
|
| 182 |
+
|
| 183 |
+
def batch_select_indices(self, indices: torch.Tensor) -> None:
|
| 184 |
+
"""Only keep the `indices` in the batch dimension of the cache."""
|
| 185 |
+
if self.get_seq_length() > 0:
|
| 186 |
+
self.keys = self.keys[indices, ...]
|
| 187 |
+
self.values = self.values[indices, ...]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class DynamicSlidingWindowLayer(DynamicLayer):
|
| 191 |
+
"""
|
| 192 |
+
A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
|
| 193 |
+
It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
is_sliding = True
|
| 197 |
+
|
| 198 |
+
def __init__(self, config: PreTrainedConfig | None = None, sliding_window: int | None = None):
|
| 199 |
+
super().__init__()
|
| 200 |
+
# Accept either a config (registry-style construction via LAYER_TYPE_CACHE_MAPPING)
|
| 201 |
+
# or a raw ``sliding_window`` int (legacy callers).
|
| 202 |
+
if sliding_window is None:
|
| 203 |
+
if config is None:
|
| 204 |
+
raise ValueError("Either `config` or `sliding_window` must be provided.")
|
| 205 |
+
sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
|
| 206 |
+
self.sliding_window = sliding_window
|
| 207 |
+
self.cumulative_length = 0
|
| 208 |
+
self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long)
|
| 209 |
+
|
| 210 |
+
def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
| 211 |
+
super().lazy_initialization(key_states, value_states)
|
| 212 |
+
self._sliding_window_tensor = self._sliding_window_tensor.to(self.device)
|
| 213 |
+
|
| 214 |
+
def update(
|
| 215 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
|
| 216 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 217 |
+
"""
|
| 218 |
+
Update the key and value caches in-place, and return the necessary keys and value states.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
key_states (`torch.Tensor`): The new key states to cache.
|
| 222 |
+
value_states (`torch.Tensor`): The new value states to cache.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
| 226 |
+
"""
|
| 227 |
+
# Lazy initialization
|
| 228 |
+
if not self.is_initialized:
|
| 229 |
+
self.lazy_initialization(key_states, value_states)
|
| 230 |
+
|
| 231 |
+
self.cumulative_length += key_states.shape[-2]
|
| 232 |
+
|
| 233 |
+
# Compute the full states
|
| 234 |
+
full_key_states = torch.cat([self.keys, key_states], dim=-2)
|
| 235 |
+
full_value_states = torch.cat([self.values, value_states], dim=-2)
|
| 236 |
+
# Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that)
|
| 237 |
+
self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
|
| 238 |
+
self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
|
| 239 |
+
|
| 240 |
+
# Return the full states
|
| 241 |
+
return full_key_states, full_value_states
|
| 242 |
+
|
| 243 |
+
def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
|
| 244 |
+
"""Return the length and offset of the cache, used to generate the attention mask"""
|
| 245 |
+
is_full = self.cumulative_length >= self.sliding_window
|
| 246 |
+
|
| 247 |
+
kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
|
| 248 |
+
if is_full:
|
| 249 |
+
kv_length = self.sliding_window - 1 + query_length
|
| 250 |
+
else:
|
| 251 |
+
kv_length = self.cumulative_length + query_length
|
| 252 |
+
|
| 253 |
+
return kv_length, kv_offset
|
| 254 |
+
|
| 255 |
+
def get_seq_length(self) -> int:
|
| 256 |
+
"""Returns the sequence length of the cached states."""
|
| 257 |
+
return self.cumulative_length
|
| 258 |
+
|
| 259 |
+
def get_max_cache_shape(self) -> int:
|
| 260 |
+
"""Return the maximum cache shape of the cache"""
|
| 261 |
+
return self.sliding_window
|
| 262 |
+
|
| 263 |
+
def crop(self, max_length: int) -> None:
|
| 264 |
+
"""
|
| 265 |
+
Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
|
| 266 |
+
negative to remove `max_length` tokens.
|
| 267 |
+
"""
|
| 268 |
+
if self.get_seq_length() >= self.sliding_window:
|
| 269 |
+
raise ValueError(
|
| 270 |
+
"Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its"
|
| 271 |
+
"sliding window (otherwise some states are lost)"
|
| 272 |
+
)
|
| 273 |
+
super().crop(max_length)
|
| 274 |
+
self.cumulative_length = self.keys.shape[-2]
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class StaticLayer(CacheLayerMixin):
|
| 278 |
+
"""
|
| 279 |
+
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len), head_dim]`.
|
| 280 |
+
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
max_cache_len (`int`):
|
| 284 |
+
Maximum number of tokens that can be stored, used for tensor preallocation.
|
| 285 |
+
"""
|
| 286 |
+
|
| 287 |
+
is_compileable = True
|
| 288 |
+
is_sliding = False
|
| 289 |
+
|
| 290 |
+
def __init__(self, max_cache_len: int):
|
| 291 |
+
super().__init__()
|
| 292 |
+
self.max_cache_len = max_cache_len
|
| 293 |
+
# Very important that it's a tensor here, to avoid recompiling when we update it and use it to create positions
|
| 294 |
+
self.cumulative_length = torch.tensor([0], dtype=int)
|
| 295 |
+
|
| 296 |
+
def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
|
| 297 |
+
"""
|
| 298 |
+
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
|
| 299 |
+
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
|
| 300 |
+
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
|
| 301 |
+
|
| 302 |
+
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
|
| 303 |
+
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
|
| 304 |
+
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
|
| 305 |
+
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
|
| 306 |
+
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
|
| 307 |
+
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
|
| 308 |
+
not be compiled anyway for performances!
|
| 309 |
+
"""
|
| 310 |
+
self.dtype, self.device = key_states.dtype, key_states.device
|
| 311 |
+
self.max_batch_size, self.num_heads = key_states.shape[:2]
|
| 312 |
+
self.v_head_dim = value_states.shape[-1]
|
| 313 |
+
self.k_head_dim = key_states.shape[-1]
|
| 314 |
+
|
| 315 |
+
self.keys = torch.zeros(
|
| 316 |
+
(self.max_batch_size, self.num_heads, self.max_cache_len, self.k_head_dim),
|
| 317 |
+
dtype=self.dtype,
|
| 318 |
+
device=self.device,
|
| 319 |
+
)
|
| 320 |
+
self.values = torch.zeros(
|
| 321 |
+
(self.max_batch_size, self.num_heads, self.max_cache_len, self.v_head_dim),
|
| 322 |
+
dtype=self.dtype,
|
| 323 |
+
device=self.device,
|
| 324 |
+
)
|
| 325 |
+
self.cumulative_length = self.cumulative_length.to(self.device)
|
| 326 |
+
# Note: `mark_static_address` is used to tag the tensors as a fixed data pointer, preventing compiled graph
|
| 327 |
+
# breaks or cudagraph skips due to inplace mutations when updating the cache. However, it is not supported when
|
| 328 |
+
# tracing the graph, so we skip it in this case. As prefill should never be compiled, this is not an issue and it
|
| 329 |
+
# will still be run (except when users compile prefill explicitly, but this should be avoided!)
|
| 330 |
+
# Without this, we cannot use cudagraphs!!
|
| 331 |
+
if not is_torchdynamo_compiling():
|
| 332 |
+
torch._dynamo.mark_static_address(self.keys)
|
| 333 |
+
torch._dynamo.mark_static_address(self.values)
|
| 334 |
+
torch._dynamo.mark_static_address(self.cumulative_length)
|
| 335 |
+
|
| 336 |
+
self.is_initialized = True
|
| 337 |
+
|
| 338 |
+
def update(
|
| 339 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
|
| 340 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 341 |
+
"""
|
| 342 |
+
Update the key and value caches in-place, and return the necessary keys and value states.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
key_states (`torch.Tensor`): The new key states to cache.
|
| 346 |
+
value_states (`torch.Tensor`): The new value states to cache.
|
| 347 |
+
|
| 348 |
+
Returns:
|
| 349 |
+
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
| 350 |
+
"""
|
| 351 |
+
# Lazy initialization
|
| 352 |
+
if not self.is_initialized:
|
| 353 |
+
self.lazy_initialization(key_states, value_states)
|
| 354 |
+
|
| 355 |
+
# Create a tensor to slice the static kv at the correct indices
|
| 356 |
+
kv_length = key_states.shape[-2]
|
| 357 |
+
cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
|
| 358 |
+
# Note that has to be performed in-place, as we have a static address that we need to keep
|
| 359 |
+
self.cumulative_length.add_(kv_length)
|
| 360 |
+
|
| 361 |
+
# Update the cache
|
| 362 |
+
try:
|
| 363 |
+
self.keys.index_copy_(2, cache_position, key_states)
|
| 364 |
+
self.values.index_copy_(2, cache_position, value_states)
|
| 365 |
+
except NotImplementedError:
|
| 366 |
+
# Fallback for devices like MPS where index_copy_ might not be supported.
|
| 367 |
+
self.keys[:, :, cache_position] = key_states
|
| 368 |
+
self.values[:, :, cache_position] = value_states
|
| 369 |
+
|
| 370 |
+
return self.keys, self.values
|
| 371 |
+
|
| 372 |
+
def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
|
| 373 |
+
"""Return the length and offset of the cache, used to generate the attention mask"""
|
| 374 |
+
kv_offset = 0
|
| 375 |
+
kv_length = self.max_cache_len
|
| 376 |
+
return kv_length, kv_offset
|
| 377 |
+
|
| 378 |
+
def get_seq_length(self) -> int:
|
| 379 |
+
"""Returns the sequence length of the cached states."""
|
| 380 |
+
return self.cumulative_length if self.is_initialized else 0
|
| 381 |
+
|
| 382 |
+
def get_max_cache_shape(self) -> int:
|
| 383 |
+
"""Return the maximum cache shape of the cache"""
|
| 384 |
+
return self.max_cache_len
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class StaticSlidingWindowLayer(StaticLayer):
|
| 388 |
+
"""
|
| 389 |
+
A static cache layer that stores the key and value states as static tensors of shape
|
| 390 |
+
`[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
|
| 391 |
+
tensors, and then mutates them in-place. Built for `torch.compile` support.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
max_cache_len (`int`):
|
| 395 |
+
Maximum number of tokens that can be stored, used for tensor preallocation.
|
| 396 |
+
sliding_window (`int`):
|
| 397 |
+
The size of the sliding window.
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
is_sliding = True
|
| 401 |
+
|
| 402 |
+
def __init__(self, max_cache_len: int, sliding_window: int):
|
| 403 |
+
effective_max_cache_len = min(sliding_window, max_cache_len)
|
| 404 |
+
super().__init__(max_cache_len=effective_max_cache_len)
|
| 405 |
+
# Here, to avoid data-dependent control flows, we also need to use a python int to keep track of the cumulative length
|
| 406 |
+
self.cumulative_length_int = 0
|
| 407 |
+
|
| 408 |
+
def update(
|
| 409 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
|
| 410 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 411 |
+
"""
|
| 412 |
+
Update the key and value caches in-place, and return the necessary keys and value states.
|
| 413 |
+
|
| 414 |
+
Args:
|
| 415 |
+
key_states (`torch.Tensor`): The new key states to cache.
|
| 416 |
+
value_states (`torch.Tensor`): The new value states to cache.
|
| 417 |
+
|
| 418 |
+
Returns:
|
| 419 |
+
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
| 420 |
+
"""
|
| 421 |
+
# Lazy initialization
|
| 422 |
+
if not self.is_initialized:
|
| 423 |
+
self.lazy_initialization(key_states, value_states)
|
| 424 |
+
|
| 425 |
+
kv_length = key_states.shape[-2]
|
| 426 |
+
current_length = self.cumulative_length_int
|
| 427 |
+
is_full = current_length >= self.max_cache_len
|
| 428 |
+
# Update it now that we saved the value above
|
| 429 |
+
self.cumulative_length_int += kv_length
|
| 430 |
+
|
| 431 |
+
if is_full:
|
| 432 |
+
# In general, we should use a much simpler `cat` here as well, independently of the states size. However,
|
| 433 |
+
# dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
|
| 434 |
+
if key_states.shape[-2] == 1:
|
| 435 |
+
# Roll all values to the left by 1 position
|
| 436 |
+
new_keys = self.keys.roll(-1, dims=-2)
|
| 437 |
+
new_values = self.values.roll(-1, dims=-2)
|
| 438 |
+
# Overwrite the last position with new states
|
| 439 |
+
# (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
|
| 440 |
+
index = torch.tensor([-1], dtype=int, device=self.device)
|
| 441 |
+
new_keys[:, :, index] = key_states
|
| 442 |
+
new_values[:, :, index] = value_states
|
| 443 |
+
|
| 444 |
+
# Copy back into `self` (do not just assign again) in order to keep the static dynamo address
|
| 445 |
+
self.keys.copy_(new_keys)
|
| 446 |
+
self.values.copy_(new_values)
|
| 447 |
+
|
| 448 |
+
# Very important to return the `self` tensors here, as they have the static dynamo address
|
| 449 |
+
return self.keys, self.values
|
| 450 |
+
# Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
|
| 451 |
+
else:
|
| 452 |
+
full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
|
| 453 |
+
full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
|
| 454 |
+
# Not yet full, but becoming full on this update
|
| 455 |
+
elif current_length + kv_length > self.max_cache_len:
|
| 456 |
+
# Fast prefill path, no need to cat() in this case, as the cache is currently empty
|
| 457 |
+
if current_length == 0:
|
| 458 |
+
full_key_states = key_states
|
| 459 |
+
full_value_states = value_states
|
| 460 |
+
else:
|
| 461 |
+
full_key_states = torch.cat((self.keys[:, :, :current_length, :], key_states), dim=-2)
|
| 462 |
+
full_value_states = torch.cat((self.values[:, :, :current_length, :], value_states), dim=-2)
|
| 463 |
+
else:
|
| 464 |
+
# Note: very important to use the tensor version of the cumulative length here, as otherwise cudagraphs
|
| 465 |
+
# (triggered by mode="reduced_overhead") will lead to random crashes, as the int would be overwritten
|
| 466 |
+
cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
|
| 467 |
+
try:
|
| 468 |
+
self.keys.index_copy_(2, cache_position, key_states)
|
| 469 |
+
self.values.index_copy_(2, cache_position, value_states)
|
| 470 |
+
except NotImplementedError:
|
| 471 |
+
self.keys[:, :, cache_position] = key_states
|
| 472 |
+
self.values[:, :, cache_position] = value_states
|
| 473 |
+
|
| 474 |
+
# Update the tensor version of the length in-place (we don't need to update it if we are already outside
|
| 475 |
+
# of this branch, as we don't need the tensor anymore)
|
| 476 |
+
self.cumulative_length.add_(kv_length)
|
| 477 |
+
|
| 478 |
+
# Very important to return the `self` tensors here, as they have the static dynamo address
|
| 479 |
+
return self.keys, self.values
|
| 480 |
+
|
| 481 |
+
# We only cache the last `sliding_window` tokens
|
| 482 |
+
self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
|
| 483 |
+
self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
|
| 484 |
+
# we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
|
| 485 |
+
return full_key_states, full_value_states
|
| 486 |
+
|
| 487 |
+
def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
|
| 488 |
+
"""Return the length and offset of the cache, used to generate the attention mask"""
|
| 489 |
+
sliding_window = self.max_cache_len
|
| 490 |
+
is_full = self.cumulative_length_int >= self.max_cache_len
|
| 491 |
+
|
| 492 |
+
kv_offset = max(self.cumulative_length_int - sliding_window + 1, 0)
|
| 493 |
+
# The cache is already full
|
| 494 |
+
if is_full:
|
| 495 |
+
kv_length = sliding_window + query_length - 1
|
| 496 |
+
# Not yet full, but becoming full on this update
|
| 497 |
+
elif self.cumulative_length_int + query_length > sliding_window:
|
| 498 |
+
kv_length = self.cumulative_length_int + query_length
|
| 499 |
+
# Here the Cache is still smaller than the local size, but we return the local size as it's static
|
| 500 |
+
else:
|
| 501 |
+
kv_length = sliding_window
|
| 502 |
+
|
| 503 |
+
return kv_length, kv_offset
|
| 504 |
+
|
| 505 |
+
def get_seq_length(self) -> int:
|
| 506 |
+
"""Returns the sequence length of the cached states."""
|
| 507 |
+
return self.cumulative_length_int
|
| 508 |
+
|
| 509 |
+
def reset(self):
|
| 510 |
+
super().reset()
|
| 511 |
+
self.cumulative_length_int = 0
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class QuantizedLayer(DynamicLayer):
|
| 515 |
+
"""
|
| 516 |
+
A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
|
| 517 |
+
It allows the model to generate longer sequence length without allocating too much memory for the key and value caches by
|
| 518 |
+
applying quantization.
|
| 519 |
+
|
| 520 |
+
The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length`
|
| 521 |
+
is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original
|
| 522 |
+
precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size`
|
| 523 |
+
for both Keys and Values, in contrast to what was described in the paper.
|
| 524 |
+
"""
|
| 525 |
+
|
| 526 |
+
def __init__(
|
| 527 |
+
self,
|
| 528 |
+
nbits: int = 4,
|
| 529 |
+
axis_key: int = 0,
|
| 530 |
+
axis_value: int = 0,
|
| 531 |
+
q_group_size: int = 64,
|
| 532 |
+
residual_length: int = 128,
|
| 533 |
+
):
|
| 534 |
+
super().__init__()
|
| 535 |
+
self.nbits = nbits
|
| 536 |
+
self.axis_key = axis_key
|
| 537 |
+
self.axis_value = axis_value
|
| 538 |
+
self.q_group_size = q_group_size
|
| 539 |
+
self.residual_length = residual_length
|
| 540 |
+
self.cumulative_length = 0
|
| 541 |
+
|
| 542 |
+
def update(
|
| 543 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
|
| 544 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 545 |
+
"""
|
| 546 |
+
Update the key and value caches in-place, and return the necessary keys and value states.
|
| 547 |
+
|
| 548 |
+
Args:
|
| 549 |
+
key_states (`torch.Tensor`): The new key states to cache.
|
| 550 |
+
value_states (`torch.Tensor`): The new value states to cache.
|
| 551 |
+
|
| 552 |
+
Returns:
|
| 553 |
+
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
|
| 554 |
+
"""
|
| 555 |
+
self.cumulative_length += key_states.shape[-2]
|
| 556 |
+
|
| 557 |
+
# Lazy initialization
|
| 558 |
+
if not self.is_initialized:
|
| 559 |
+
self.lazy_initialization(key_states, value_states)
|
| 560 |
+
self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
|
| 561 |
+
self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
|
| 562 |
+
return key_states, value_states
|
| 563 |
+
|
| 564 |
+
dequant_keys = self._dequantize(self._quantized_keys)
|
| 565 |
+
dequant_values = self._dequantize(self._quantized_values)
|
| 566 |
+
keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
|
| 567 |
+
values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)
|
| 568 |
+
if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
|
| 569 |
+
self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
|
| 570 |
+
self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
|
| 571 |
+
self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
|
| 572 |
+
self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
|
| 573 |
+
else:
|
| 574 |
+
self.keys = torch.cat([self.keys, key_states], dim=-2)
|
| 575 |
+
self.values = torch.cat([self.values, value_states], dim=-2)
|
| 576 |
+
|
| 577 |
+
return keys_to_return, values_to_return
|
| 578 |
+
|
| 579 |
+
@abstractmethod
|
| 580 |
+
def _quantize(self, tensor, axis): ...
|
| 581 |
+
|
| 582 |
+
@abstractmethod
|
| 583 |
+
def _dequantize(self, q_tensor): ...
|
| 584 |
+
|
| 585 |
+
def get_seq_length(self) -> int:
|
| 586 |
+
"""Returns the sequence length of the cached states."""
|
| 587 |
+
return self.cumulative_length
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class QuantoQuantizedLayer(QuantizedLayer):
|
| 591 |
+
def __init__(
|
| 592 |
+
self,
|
| 593 |
+
nbits: int = 4,
|
| 594 |
+
axis_key: int = 0,
|
| 595 |
+
axis_value: int = 0,
|
| 596 |
+
q_group_size: int = 64,
|
| 597 |
+
residual_length: int = 128,
|
| 598 |
+
):
|
| 599 |
+
super().__init__(
|
| 600 |
+
nbits=nbits,
|
| 601 |
+
axis_key=axis_key,
|
| 602 |
+
axis_value=axis_value,
|
| 603 |
+
q_group_size=q_group_size,
|
| 604 |
+
residual_length=residual_length,
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
# We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
|
| 608 |
+
if not is_optimum_quanto_available():
|
| 609 |
+
raise ImportError(
|
| 610 |
+
"You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
|
| 611 |
+
"backend. Please install it via with `pip install optimum-quanto`"
|
| 612 |
+
)
|
| 613 |
+
elif is_quanto_greater("0.2.5", accept_dev=True):
|
| 614 |
+
from optimum.quanto import MaxOptimizer, qint2, qint4
|
| 615 |
+
else:
|
| 616 |
+
raise ImportError(
|
| 617 |
+
"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedLayer`. "
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
if self.nbits not in [2, 4]:
|
| 621 |
+
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
|
| 622 |
+
|
| 623 |
+
if self.axis_key not in [0, -1]:
|
| 624 |
+
raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
|
| 625 |
+
|
| 626 |
+
if self.axis_value not in [0, -1]:
|
| 627 |
+
raise ValueError(
|
| 628 |
+
f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
self.qtype = qint4 if self.nbits == 4 else qint2
|
| 632 |
+
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
|
| 633 |
+
|
| 634 |
+
def _quantize(self, tensor, axis):
|
| 635 |
+
from optimum.quanto import quantize_weight
|
| 636 |
+
|
| 637 |
+
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
|
| 638 |
+
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
|
| 639 |
+
return qtensor
|
| 640 |
+
|
| 641 |
+
def _dequantize(self, qtensor):
|
| 642 |
+
return qtensor.dequantize()
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class HQQQuantizedLayer(QuantizedLayer):
|
| 646 |
+
def __init__(
|
| 647 |
+
self,
|
| 648 |
+
nbits: int = 4,
|
| 649 |
+
axis_key: int = 0,
|
| 650 |
+
axis_value: int = 0,
|
| 651 |
+
q_group_size: int = 64,
|
| 652 |
+
residual_length: int = 128,
|
| 653 |
+
):
|
| 654 |
+
super().__init__(
|
| 655 |
+
nbits=nbits,
|
| 656 |
+
axis_key=axis_key,
|
| 657 |
+
axis_value=axis_value,
|
| 658 |
+
q_group_size=q_group_size,
|
| 659 |
+
residual_length=residual_length,
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
if not is_hqq_available():
|
| 663 |
+
raise ImportError(
|
| 664 |
+
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
|
| 665 |
+
"Please install it via with `pip install hqq`"
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
if self.nbits not in [1, 2, 3, 4, 8]:
|
| 669 |
+
raise ValueError(
|
| 670 |
+
f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
if self.axis_key not in [0, 1]:
|
| 674 |
+
raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
|
| 675 |
+
|
| 676 |
+
if self.axis_value not in [0, 1]:
|
| 677 |
+
raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
|
| 678 |
+
|
| 679 |
+
self.quantizer = HQQQuantizer
|
| 680 |
+
|
| 681 |
+
def _quantize(self, tensor, axis):
|
| 682 |
+
qtensor, meta = self.quantizer.quantize(
|
| 683 |
+
tensor,
|
| 684 |
+
axis=axis,
|
| 685 |
+
device=self.keys.device,
|
| 686 |
+
compute_dtype=self.keys.dtype,
|
| 687 |
+
nbits=self.nbits,
|
| 688 |
+
group_size=self.q_group_size,
|
| 689 |
+
)
|
| 690 |
+
meta["compute_dtype"] = self.keys.dtype
|
| 691 |
+
self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device) # Move to device and cast to dtype
|
| 692 |
+
meta["scale"] = meta["scale"].to(qtensor.device)
|
| 693 |
+
meta["zero"] = meta["zero"].to(qtensor.device)
|
| 694 |
+
return qtensor, meta
|
| 695 |
+
|
| 696 |
+
def _dequantize(self, qtensor):
|
| 697 |
+
quant_tensor, meta = qtensor
|
| 698 |
+
tensor = self.quantizer.dequantize(quant_tensor, meta)
|
| 699 |
+
return tensor
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
class LinearAttentionCacheLayerMixin(ABC):
|
| 703 |
+
"""Base, abstract class for a linear attention single layer's cache."""
|
| 704 |
+
|
| 705 |
+
# All shapes are static by essence in a LinearAttention layer, so it is compileable
|
| 706 |
+
is_compileable = True
|
| 707 |
+
|
| 708 |
+
def __init__(self):
|
| 709 |
+
self.conv_states: torch.Tensor | None = None
|
| 710 |
+
self.recurrent_states: torch.Tensor | None = None
|
| 711 |
+
self.is_conv_states_initialized = False
|
| 712 |
+
self.is_recurrent_states_initialized = False
|
| 713 |
+
self.has_previous_state = False
|
| 714 |
+
|
| 715 |
+
def __repr__(self):
|
| 716 |
+
return f"{self.__class__.__name__}"
|
| 717 |
+
|
| 718 |
+
@abstractmethod
|
| 719 |
+
def lazy_initialization(
|
| 720 |
+
self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
|
| 721 |
+
) -> None: ...
|
| 722 |
+
|
| 723 |
+
@abstractmethod
|
| 724 |
+
def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ...
|
| 725 |
+
|
| 726 |
+
@abstractmethod
|
| 727 |
+
def update_recurrent_state(self, recurrent_states: torch.Tensor) -> torch.Tensor: ...
|
| 728 |
+
|
| 729 |
+
def offload(self):
|
| 730 |
+
"""Offload this layer's data to CPU device."""
|
| 731 |
+
if self.is_conv_states_initialized:
|
| 732 |
+
self.conv_states = self.conv_states.to("cpu", non_blocking=True)
|
| 733 |
+
if self.is_recurrent_states_initialized:
|
| 734 |
+
self.recurrent_states = self.recurrent_states.to("cpu", non_blocking=True)
|
| 735 |
+
|
| 736 |
+
def prefetch(self):
|
| 737 |
+
"""In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
|
| 738 |
+
if self.is_conv_states_initialized and self.conv_states.device != self.device:
|
| 739 |
+
self.conv_states = self.conv_states.to(self.device, non_blocking=True)
|
| 740 |
+
if self.is_recurrent_states_initialized and self.recurrent_states.device != self.device:
|
| 741 |
+
self.recurrent_states = self.recurrent_states.to(self.device, non_blocking=True)
|
| 742 |
+
|
| 743 |
+
def reset(self) -> None:
|
| 744 |
+
"""Resets the cache values while preserving the objects"""
|
| 745 |
+
if self.is_conv_states_initialized:
|
| 746 |
+
self.conv_states.zero_()
|
| 747 |
+
if self.is_recurrent_states_initialized:
|
| 748 |
+
self.recurrent_states.zero_()
|
| 749 |
+
self.has_previous_state = False
|
| 750 |
+
|
| 751 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 752 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 753 |
+
if self.is_conv_states_initialized:
|
| 754 |
+
self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device))
|
| 755 |
+
# recurrent_states can stay empty sometimes, see e.g. lfm2 which only uses the conv_states
|
| 756 |
+
if self.is_recurrent_states_initialized:
|
| 757 |
+
self.recurrent_states = self.recurrent_states.index_select(0, beam_idx.to(self.device))
|
| 758 |
+
|
| 759 |
+
def crop(self, max_length: int):
|
| 760 |
+
# We don't crop the linear attention cache, so simply do nothing here
|
| 761 |
+
pass
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
|
| 765 |
+
def __init__(self, config: PreTrainedConfig | None = None):
|
| 766 |
+
super().__init__()
|
| 767 |
+
|
| 768 |
+
def lazy_initialization(
|
| 769 |
+
self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
|
| 770 |
+
) -> None:
|
| 771 |
+
# Here, we will lazy init both states separately, each in their own update function
|
| 772 |
+
if conv_states is not None:
|
| 773 |
+
self.dtype, self.device = conv_states.dtype, conv_states.device
|
| 774 |
+
# Even if prefill is larfer/shorter than the conv_size, the tensor is always either padded or truncated
|
| 775 |
+
self.max_batch_size, self.conv_kernel_size = conv_states.shape[0], conv_states.shape[-1]
|
| 776 |
+
# The shape is always static, so we init as such
|
| 777 |
+
self.conv_states = torch.zeros_like(conv_states, dtype=self.dtype, device=self.device)
|
| 778 |
+
# Mark as static address to be able to use cudagraphs
|
| 779 |
+
if not is_torchdynamo_compiling():
|
| 780 |
+
torch._dynamo.mark_static_address(self.conv_states)
|
| 781 |
+
self.is_conv_states_initialized = True
|
| 782 |
+
if recurrent_states is not None:
|
| 783 |
+
# The shape is always static, so we init as such
|
| 784 |
+
self.recurrent_states = torch.zeros_like(recurrent_states, dtype=self.dtype, device=self.device)
|
| 785 |
+
# Mark as static address to be able to use cudagraphs
|
| 786 |
+
if not is_torchdynamo_compiling():
|
| 787 |
+
torch._dynamo.mark_static_address(self.recurrent_states)
|
| 788 |
+
self.is_recurrent_states_initialized = True
|
| 789 |
+
|
| 790 |
+
def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 791 |
+
"""
|
| 792 |
+
Update the linear attention cache in-place, and return the necessary conv states.
|
| 793 |
+
|
| 794 |
+
Args:
|
| 795 |
+
conv_states (`torch.Tensor`): The new conv states to cache.
|
| 796 |
+
|
| 797 |
+
Returns:
|
| 798 |
+
`torch.Tensor`: The updated conv states.
|
| 799 |
+
"""
|
| 800 |
+
# Lazy initialization
|
| 801 |
+
if not self.is_conv_states_initialized:
|
| 802 |
+
self.lazy_initialization(conv_states=conv_states)
|
| 803 |
+
|
| 804 |
+
if not self.has_previous_state:
|
| 805 |
+
# Note that we copy instead of assigning, to preserve the static address for cudagraphs
|
| 806 |
+
self.conv_states.copy_(conv_states)
|
| 807 |
+
self.has_previous_state = True
|
| 808 |
+
# Technically, this update is not logically correct if the prefill is smaller than `conv_kernel_size`,
|
| 809 |
+
# as it will `roll` anyway in the first decoding step, even though it should `roll` ONLY if the cache is already full.
|
| 810 |
+
# But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now
|
| 811 |
+
else:
|
| 812 |
+
# Note that we copy instead of assigning, to preserve the static address for cudagraphs
|
| 813 |
+
num_new_tokens = conv_states.shape[-1]
|
| 814 |
+
if num_new_tokens >= self.conv_kernel_size:
|
| 815 |
+
self.conv_states.copy_(conv_states[..., -self.conv_kernel_size :])
|
| 816 |
+
else:
|
| 817 |
+
new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1)
|
| 818 |
+
new_conv_states[:, :, -num_new_tokens:] = conv_states
|
| 819 |
+
self.conv_states.copy_(new_conv_states)
|
| 820 |
+
|
| 821 |
+
return self.conv_states
|
| 822 |
+
|
| 823 |
+
def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 824 |
+
"""
|
| 825 |
+
Update the linear attention cache in-place, and return the necessary ssm states.
|
| 826 |
+
|
| 827 |
+
Args:
|
| 828 |
+
smm_states (`torch.Tensor`): The new ssm states to cache.
|
| 829 |
+
|
| 830 |
+
Returns:
|
| 831 |
+
`torch.Tensor`: The updated ssm states.
|
| 832 |
+
"""
|
| 833 |
+
if not self.is_recurrent_states_initialized:
|
| 834 |
+
self.lazy_initialization(recurrent_states=recurrent_states)
|
| 835 |
+
# Note that we copy instead of assigning, to preserve the static address for cudagraphs
|
| 836 |
+
self.recurrent_states.copy_(recurrent_states)
|
| 837 |
+
return self.recurrent_states
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
|
| 841 |
+
# The dynamic Attention part makes it non-compileable
|
| 842 |
+
is_compileable = False
|
| 843 |
+
|
| 844 |
+
def __init__(self, config: PreTrainedConfig | None = None):
|
| 845 |
+
DynamicLayer.__init__(self)
|
| 846 |
+
LinearAttentionLayer.__init__(self)
|
| 847 |
+
|
| 848 |
+
def lazy_initialization(self, *args, **kwargs) -> None:
|
| 849 |
+
# When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args
|
| 850 |
+
if len(args) == 2 and len(kwargs) == 0:
|
| 851 |
+
DynamicLayer.lazy_initialization(self, *args)
|
| 852 |
+
# Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`, it's
|
| 853 |
+
# always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states)
|
| 854 |
+
if len(args) == 0 and len(kwargs) == 1:
|
| 855 |
+
LinearAttentionLayer.lazy_initialization(self, **kwargs)
|
| 856 |
+
|
| 857 |
+
def reset(self) -> None:
|
| 858 |
+
LinearAttentionLayer.reset(self)
|
| 859 |
+
DynamicLayer.reset(self)
|
| 860 |
+
|
| 861 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 862 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 863 |
+
LinearAttentionLayer.reorder_cache(self, beam_idx)
|
| 864 |
+
DynamicLayer.reorder_cache(self, beam_idx)
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
# Pre-register the standard layer types (some classes are shared between multiple types,
|
| 868 |
+
# e.g. ``DynamicSlidingWindowLayer`` covers both ``"sliding_attention"`` and
|
| 869 |
+
# ``"chunked_attention"`` — those need an explicit map entry rather than the
|
| 870 |
+
# auto-registration via ``CacheLayerMixin.__init_subclass__``).
|
| 871 |
+
LAYER_TYPE_CACHE_MAPPING.update(
|
| 872 |
+
{
|
| 873 |
+
"full_attention": DynamicLayer,
|
| 874 |
+
# From a cache point of view, sliding and chunked are the same in how they should behave;
|
| 875 |
+
# only the mask differs.
|
| 876 |
+
"sliding_attention": DynamicSlidingWindowLayer,
|
| 877 |
+
"chunked_attention": DynamicSlidingWindowLayer,
|
| 878 |
+
# Linear-attention-shaped layers (mamba / conv / pure linear-attention / moe placeholders)
|
| 879 |
+
# don't grow per-token KV; they're tracked just so position bookkeeping stays consistent.
|
| 880 |
+
"mamba": LinearAttentionLayer,
|
| 881 |
+
"conv": LinearAttentionLayer,
|
| 882 |
+
"linear_attention": LinearAttentionLayer,
|
| 883 |
+
"moe": LinearAttentionLayer,
|
| 884 |
+
# Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state.
|
| 885 |
+
"hybrid": LinearAttentionAndFullAttentionLayer,
|
| 886 |
+
}
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
class Cache:
|
| 891 |
+
"""
|
| 892 |
+
A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
|
| 893 |
+
the Cache of each layer.
|
| 894 |
+
|
| 895 |
+
Args:
|
| 896 |
+
layers (`Optional`, *optional*):
|
| 897 |
+
A list of pre-created `CacheLayerMixin` or `LinearAttentionCacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate`
|
| 898 |
+
will be used.
|
| 899 |
+
layer_class_to_replicate (`type[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*):
|
| 900 |
+
Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
|
| 901 |
+
and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
|
| 902 |
+
list of layers.
|
| 903 |
+
offloading (`bool`, *optional*, defaults to `False`):
|
| 904 |
+
Whether to perform offloading of the layers to `cpu`, to save GPU memory.
|
| 905 |
+
offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
|
| 906 |
+
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
|
| 907 |
+
usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
|
| 908 |
+
"""
|
| 909 |
+
|
| 910 |
+
def __init__(
|
| 911 |
+
self,
|
| 912 |
+
layers: list[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
|
| 913 |
+
layer_class_to_replicate: type[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
|
| 914 |
+
offloading: bool = False,
|
| 915 |
+
offload_only_non_sliding: bool = True,
|
| 916 |
+
):
|
| 917 |
+
if layers is not None and layer_class_to_replicate is not None:
|
| 918 |
+
raise ValueError(
|
| 919 |
+
"You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
|
| 920 |
+
"`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
|
| 921 |
+
"`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
|
| 922 |
+
)
|
| 923 |
+
if layers is None and layer_class_to_replicate is None:
|
| 924 |
+
raise ValueError(
|
| 925 |
+
"You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
|
| 926 |
+
)
|
| 927 |
+
self.layers = layers if layers is not None else []
|
| 928 |
+
self.layer_class_to_replicate = layer_class_to_replicate
|
| 929 |
+
self.offloading = offloading
|
| 930 |
+
if self.offloading:
|
| 931 |
+
self.only_non_sliding = offload_only_non_sliding
|
| 932 |
+
self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
|
| 933 |
+
|
| 934 |
+
def __repr__(self):
|
| 935 |
+
return f"{self.__class__.__name__}(layers={self.layers})"
|
| 936 |
+
|
| 937 |
+
def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
|
| 938 |
+
"""
|
| 939 |
+
Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
|
| 940 |
+
which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
|
| 941 |
+
Note that we use a non-default stream for this, to avoid blocking.
|
| 942 |
+
"""
|
| 943 |
+
if only_non_sliding:
|
| 944 |
+
# Try to find next non-sliding, starting at `layer_idx`
|
| 945 |
+
try:
|
| 946 |
+
layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
|
| 947 |
+
# In this case, we need to circle back to the beginning
|
| 948 |
+
except ValueError:
|
| 949 |
+
layer_idx = self.is_sliding.index(False)
|
| 950 |
+
else:
|
| 951 |
+
layer_idx = layer_idx if layer_idx < len(self.layers) else 0
|
| 952 |
+
|
| 953 |
+
# Prefetch
|
| 954 |
+
with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
|
| 955 |
+
self.layers[layer_idx].prefetch()
|
| 956 |
+
|
| 957 |
+
def offload(self, layer_idx: int, only_non_sliding: bool = True):
|
| 958 |
+
"""
|
| 959 |
+
Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
|
| 960 |
+
non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
|
| 961 |
+
computation in the layer's `update` methods are finished.
|
| 962 |
+
"""
|
| 963 |
+
if not (only_non_sliding and self.is_sliding[layer_idx]):
|
| 964 |
+
self.layers[layer_idx].offload()
|
| 965 |
+
|
| 966 |
+
def update(
|
| 967 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args, **kwargs
|
| 968 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 969 |
+
"""
|
| 970 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
| 971 |
+
|
| 972 |
+
Parameters:
|
| 973 |
+
key_states (`torch.Tensor`):
|
| 974 |
+
The new key states to cache.
|
| 975 |
+
value_states (`torch.Tensor`):
|
| 976 |
+
The new value states to cache.
|
| 977 |
+
layer_idx (`int`):
|
| 978 |
+
The index of the layer to cache the states for.
|
| 979 |
+
|
| 980 |
+
Return:
|
| 981 |
+
A tuple containing the updated key and value states.
|
| 982 |
+
"""
|
| 983 |
+
# In this case, the `layers` were not provided, and we must append as much as `layer_idx`
|
| 984 |
+
if self.layer_class_to_replicate is not None:
|
| 985 |
+
while len(self.layers) <= layer_idx:
|
| 986 |
+
self.layers.append(self.layer_class_to_replicate())
|
| 987 |
+
|
| 988 |
+
if self.offloading:
|
| 989 |
+
# Wait for the stream to finish if needed, and start prefetching the next layer
|
| 990 |
+
torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
|
| 991 |
+
self.prefetch(layer_idx + 1, self.only_non_sliding)
|
| 992 |
+
|
| 993 |
+
keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)
|
| 994 |
+
|
| 995 |
+
if self.offloading:
|
| 996 |
+
self.offload(layer_idx, self.only_non_sliding)
|
| 997 |
+
|
| 998 |
+
return keys, values
|
| 999 |
+
|
| 1000 |
+
def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
|
| 1001 |
+
"""
|
| 1002 |
+
Updates the cache with the new `conv_states` for the layer `layer_idx`.
|
| 1003 |
+
|
| 1004 |
+
Parameters:
|
| 1005 |
+
conv_states (`torch.Tensor`):
|
| 1006 |
+
The new conv states to cache.
|
| 1007 |
+
layer_idx (`int`):
|
| 1008 |
+
The index of the layer to cache the states for.
|
| 1009 |
+
|
| 1010 |
+
Return:
|
| 1011 |
+
`torch.Tensor`: The updated conv states.
|
| 1012 |
+
"""
|
| 1013 |
+
# NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
|
| 1014 |
+
# out of the box
|
| 1015 |
+
if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
|
| 1016 |
+
raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
|
| 1017 |
+
conv_states = self.layers[layer_idx].update_conv_state(conv_states, **kwargs)
|
| 1018 |
+
return conv_states
|
| 1019 |
+
|
| 1020 |
+
def update_recurrent_state(self, recurrent_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
|
| 1021 |
+
"""
|
| 1022 |
+
Updates the cache with the new `recurrent_states` for the layer `layer_idx`.
|
| 1023 |
+
|
| 1024 |
+
Parameters:
|
| 1025 |
+
smm_states (`torch.Tensor`):
|
| 1026 |
+
The new ssm states to cache.
|
| 1027 |
+
layer_idx (`int`):
|
| 1028 |
+
The index of the layer to cache the states for.
|
| 1029 |
+
|
| 1030 |
+
Return:
|
| 1031 |
+
`torch.Tensor`: The updated ssm states.
|
| 1032 |
+
"""
|
| 1033 |
+
# NOTE: if we slightly break `update` arg order, we could combine this with it, and allow offloading support
|
| 1034 |
+
# out of the box
|
| 1035 |
+
if not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
|
| 1036 |
+
raise ValueError("Cannot call `update_conv_state` on a non-LinearAttention layer!")
|
| 1037 |
+
recurrent_states = self.layers[layer_idx].update_recurrent_state(recurrent_states, **kwargs)
|
| 1038 |
+
return recurrent_states
|
| 1039 |
+
|
| 1040 |
+
def early_initialization(
|
| 1041 |
+
self,
|
| 1042 |
+
batch_size: int,
|
| 1043 |
+
num_heads: int | list[int],
|
| 1044 |
+
head_dim: int | list[int],
|
| 1045 |
+
dtype: torch.dtype,
|
| 1046 |
+
device: torch.device,
|
| 1047 |
+
):
|
| 1048 |
+
"""
|
| 1049 |
+
Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
|
| 1050 |
+
This is useful for our `export` recipes, as `export` needs everything in advance.
|
| 1051 |
+
"""
|
| 1052 |
+
# To allow different num_heads and head_dim depending on layers, we accept lists
|
| 1053 |
+
if isinstance(num_heads, int):
|
| 1054 |
+
num_heads = [num_heads] * len(self)
|
| 1055 |
+
if isinstance(head_dim, int):
|
| 1056 |
+
head_dim = [head_dim] * len(self)
|
| 1057 |
+
|
| 1058 |
+
if len(num_heads) != len(self.layers):
|
| 1059 |
+
raise ValueError(
|
| 1060 |
+
f"`num_head` was provided as a list of length {len(num_heads)}, but the Cache currently has {len(self.layers)} layers"
|
| 1061 |
+
)
|
| 1062 |
+
if len(head_dim) != len(self.layers):
|
| 1063 |
+
raise ValueError(
|
| 1064 |
+
f"`head_dim` was provided as a list of length {len(num_heads)}, but the Cache currently has {len(self.layers)} layers"
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
for layer, layer_num_heads, layer_head_dim in zip(self.layers, num_heads, head_dim):
|
| 1068 |
+
# Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
|
| 1069 |
+
# this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
|
| 1070 |
+
# creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
|
| 1071 |
+
fake_kv_tensor = torch.zeros((batch_size, layer_num_heads, 0, layer_head_dim), dtype=dtype, device=device)
|
| 1072 |
+
# Init the layer
|
| 1073 |
+
layer.lazy_initialization(fake_kv_tensor, fake_kv_tensor)
|
| 1074 |
+
|
| 1075 |
+
def get_seq_length(self, layer_idx: int = 0) -> int:
|
| 1076 |
+
"""Returns the sequence length of the cache for the given layer."""
|
| 1077 |
+
if layer_idx >= len(self.layers):
|
| 1078 |
+
return 0
|
| 1079 |
+
|
| 1080 |
+
# For alternating attention/linear attention caches, `get_seq_length` needs to use attention layer idx when called with default layer_idx
|
| 1081 |
+
if not isinstance(self.layers[layer_idx], CacheLayerMixin):
|
| 1082 |
+
# If this is called with non-default arg, raise
|
| 1083 |
+
if layer_idx != 0:
|
| 1084 |
+
raise ValueError(
|
| 1085 |
+
f"You called `get_seq_length` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
|
| 1086 |
+
"does not track sequence length."
|
| 1087 |
+
)
|
| 1088 |
+
try:
|
| 1089 |
+
# Use the first attention layer
|
| 1090 |
+
layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin))
|
| 1091 |
+
except StopIteration:
|
| 1092 |
+
raise ValueError(
|
| 1093 |
+
"`get_seq_length` can only be called on Attention layers, and the current Cache seem to only contain "
|
| 1094 |
+
"LinearAttention layers."
|
| 1095 |
+
)
|
| 1096 |
+
|
| 1097 |
+
return self.layers[layer_idx].get_seq_length()
|
| 1098 |
+
|
| 1099 |
+
def has_previous_state(self, layer_idx: int | None = None) -> bool:
|
| 1100 |
+
"""Returns whether the LinearAttention layer at index `layer_idx` has previous state or not."""
|
| 1101 |
+
if layer_idx is not None and layer_idx >= len(self.layers):
|
| 1102 |
+
return False
|
| 1103 |
+
|
| 1104 |
+
# In this case, use last LinearAttention layer
|
| 1105 |
+
if layer_idx is None:
|
| 1106 |
+
try:
|
| 1107 |
+
layer_idx = next(
|
| 1108 |
+
idx
|
| 1109 |
+
for idx in range(len(self) - 1, -1, -1)
|
| 1110 |
+
if isinstance(self.layers[idx], LinearAttentionCacheLayerMixin)
|
| 1111 |
+
)
|
| 1112 |
+
except StopIteration:
|
| 1113 |
+
raise ValueError(
|
| 1114 |
+
"`has_previous_state` can only be called on LinearAttention layers, and the current Cache seem to "
|
| 1115 |
+
"only contain Attention layers."
|
| 1116 |
+
)
|
| 1117 |
+
elif not isinstance(self.layers[layer_idx], LinearAttentionCacheLayerMixin):
|
| 1118 |
+
raise ValueError(
|
| 1119 |
+
f"You called `has_previous_state` on layer index {layer_idx}, but this layer is an Attention layer, which "
|
| 1120 |
+
"does not support calling it."
|
| 1121 |
+
)
|
| 1122 |
+
|
| 1123 |
+
return self.layers[layer_idx].has_previous_state
|
| 1124 |
+
|
| 1125 |
+
def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
|
| 1126 |
+
"""
|
| 1127 |
+
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
| 1128 |
+
the given layer at `layer_idx`.
|
| 1129 |
+
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
|
| 1130 |
+
"""
|
| 1131 |
+
# For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
|
| 1132 |
+
# simply the query_length
|
| 1133 |
+
if layer_idx >= len(self.layers):
|
| 1134 |
+
return query_length, 0
|
| 1135 |
+
|
| 1136 |
+
# For alternating attention/linear attention caches, `get_mask_sizes` needs to use attention layer idx when called with default layer_idx
|
| 1137 |
+
if not isinstance(self.layers[layer_idx], CacheLayerMixin):
|
| 1138 |
+
# If this is called with non-default arg, raise
|
| 1139 |
+
if layer_idx != 0:
|
| 1140 |
+
raise ValueError(
|
| 1141 |
+
f"You called `get_mask_sizes` on layer index {layer_idx}, but this layer is a LinearAttention layer, which "
|
| 1142 |
+
"does not track sequence length."
|
| 1143 |
+
)
|
| 1144 |
+
try:
|
| 1145 |
+
# Use the first attention layer
|
| 1146 |
+
layer_idx = next(idx for idx in range(len(self)) if isinstance(self.layers[idx], CacheLayerMixin))
|
| 1147 |
+
except StopIteration:
|
| 1148 |
+
raise ValueError(
|
| 1149 |
+
"`get_mask_sizes` can only be called on Attention layers, and the current Cache seem to only contain "
|
| 1150 |
+
"LinearAttention layers."
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
return self.layers[layer_idx].get_mask_sizes(query_length)
|
| 1154 |
+
|
| 1155 |
+
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
|
| 1156 |
+
"""Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
|
| 1157 |
+
# For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
|
| 1158 |
+
# as DynamicLayer does
|
| 1159 |
+
if layer_idx >= len(self.layers):
|
| 1160 |
+
return -1
|
| 1161 |
+
return self.layers[layer_idx].get_max_cache_shape()
|
| 1162 |
+
|
| 1163 |
+
def reset(self):
|
| 1164 |
+
"""Recursively reset all layers tensors"""
|
| 1165 |
+
for layer_idx in range(len(self.layers)):
|
| 1166 |
+
self.layers[layer_idx].reset()
|
| 1167 |
+
|
| 1168 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 1169 |
+
"""Reorder the cache for beam search"""
|
| 1170 |
+
for layer_idx in range(len(self.layers)):
|
| 1171 |
+
self.layers[layer_idx].reorder_cache(beam_idx)
|
| 1172 |
+
|
| 1173 |
+
def crop(self, max_length: int):
|
| 1174 |
+
"""Crop the cache to the given length"""
|
| 1175 |
+
for layer_idx in range(len(self.layers)):
|
| 1176 |
+
self.layers[layer_idx].crop(max_length)
|
| 1177 |
+
|
| 1178 |
+
def batch_repeat_interleave(self, repeats: int):
|
| 1179 |
+
"""Repeat and interleave the cache"""
|
| 1180 |
+
for layer_idx in range(len(self.layers)):
|
| 1181 |
+
self.layers[layer_idx].batch_repeat_interleave(repeats)
|
| 1182 |
+
|
| 1183 |
+
def batch_select_indices(self, indices: torch.Tensor):
|
| 1184 |
+
"""Select indices from the cache"""
|
| 1185 |
+
for layer_idx in range(len(self.layers)):
|
| 1186 |
+
self.layers[layer_idx].batch_select_indices(indices)
|
| 1187 |
+
|
| 1188 |
+
@property
|
| 1189 |
+
def max_batch_size(self) -> int:
|
| 1190 |
+
"""Return the maximum batch size of the cache"""
|
| 1191 |
+
values = [layer.max_batch_size for layer in self.layers]
|
| 1192 |
+
if len(set(values)) > 1:
|
| 1193 |
+
raise ValueError(f"Max batch size is not consistent across layers: {values}")
|
| 1194 |
+
return values[0]
|
| 1195 |
+
|
| 1196 |
+
@property
|
| 1197 |
+
def max_cache_len(self) -> int:
|
| 1198 |
+
"""Return the maximum cache length of the cache"""
|
| 1199 |
+
values = [layer.max_cache_len for layer in self.layers]
|
| 1200 |
+
return max(values)
|
| 1201 |
+
|
| 1202 |
+
@property
|
| 1203 |
+
def is_compileable(self) -> bool:
|
| 1204 |
+
"""Return whether the cache is compilable"""
|
| 1205 |
+
# For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
|
| 1206 |
+
if len(self.layers) == 0:
|
| 1207 |
+
return False
|
| 1208 |
+
return all(layer.is_compileable for layer in self.layers)
|
| 1209 |
+
|
| 1210 |
+
@property
|
| 1211 |
+
def is_initialized(self) -> bool:
|
| 1212 |
+
"""Return whether the cache data is initialized"""
|
| 1213 |
+
return len(self.layers) > 0 and all(layer.is_initialized for layer in self.layers)
|
| 1214 |
+
|
| 1215 |
+
@property
|
| 1216 |
+
def is_sliding(self) -> list[bool]:
|
| 1217 |
+
"""Return whether the layers of the cache are sliding window"""
|
| 1218 |
+
return [getattr(layer, "is_sliding", False) for layer in self.layers]
|
| 1219 |
+
|
| 1220 |
+
def __len__(self):
|
| 1221 |
+
"""
|
| 1222 |
+
This value corresponds to the number of layers in the model.
|
| 1223 |
+
"""
|
| 1224 |
+
# Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
|
| 1225 |
+
# forward through all the layers
|
| 1226 |
+
return len(self.layers)
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
+
class DynamicCache(Cache):
|
| 1230 |
+
"""
|
| 1231 |
+
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
| 1232 |
+
It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
|
| 1233 |
+
in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
|
| 1234 |
+
If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
|
| 1235 |
+
memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
|
| 1236 |
+
|
| 1237 |
+
See `Cache` for details on common methods that are implemented by all cache classes.
|
| 1238 |
+
|
| 1239 |
+
Args:
|
| 1240 |
+
ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*):
|
| 1241 |
+
It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
|
| 1242 |
+
`map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
|
| 1243 |
+
for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
|
| 1244 |
+
Note: it needs to be the 1st arg as well to work correctly
|
| 1245 |
+
config (`PreTrainedConfig`, *optional*):
|
| 1246 |
+
The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
|
| 1247 |
+
or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
|
| 1248 |
+
`[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
|
| 1249 |
+
offloading (`bool`, *optional*, defaults to `False`):
|
| 1250 |
+
Whether to perform offloading of the layers to `cpu`, to save GPU memory.
|
| 1251 |
+
offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
|
| 1252 |
+
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
|
| 1253 |
+
usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
|
| 1254 |
+
|
| 1255 |
+
Example:
|
| 1256 |
+
|
| 1257 |
+
```python
|
| 1258 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
|
| 1259 |
+
|
| 1260 |
+
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
| 1261 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
| 1262 |
+
|
| 1263 |
+
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
|
| 1264 |
+
|
| 1265 |
+
>>> # Prepare a cache class and pass it to model's forward
|
| 1266 |
+
>>> past_key_values = DynamicCache(config=model.config)
|
| 1267 |
+
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
| 1268 |
+
>>> outputs.past_key_values # access cache filled with key/values from generation
|
| 1269 |
+
```
|
| 1270 |
+
"""
|
| 1271 |
+
|
| 1272 |
+
def __init__(
|
| 1273 |
+
self,
|
| 1274 |
+
ddp_cache_data: Iterable[tuple[torch.Tensor | None, ...]] | None = None,
|
| 1275 |
+
config: PreTrainedConfig | None = None,
|
| 1276 |
+
offloading: bool = False,
|
| 1277 |
+
offload_only_non_sliding: bool = False,
|
| 1278 |
+
):
|
| 1279 |
+
layers = []
|
| 1280 |
+
# If a config is passed, use it to infer the layer types and initialize accordingly
|
| 1281 |
+
if config is not None:
|
| 1282 |
+
decoder_config = config.get_text_config(decoder=True)
|
| 1283 |
+
sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
|
| 1284 |
+
decoder_config, "attention_chunk_size", None
|
| 1285 |
+
)
|
| 1286 |
+
layer_types = getattr(decoder_config, "layer_types", None)
|
| 1287 |
+
if layer_types is None:
|
| 1288 |
+
layer_types = []
|
| 1289 |
+
for _ in range(decoder_config.num_hidden_layers):
|
| 1290 |
+
if sliding_window is not None:
|
| 1291 |
+
layer_types.append("sliding_attention")
|
| 1292 |
+
else:
|
| 1293 |
+
layer_types.append("full_attention")
|
| 1294 |
+
# Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
|
| 1295 |
+
if hasattr(decoder_config, "num_kv_shared_layers"):
|
| 1296 |
+
layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
|
| 1297 |
+
|
| 1298 |
+
for layer_type in layer_types:
|
| 1299 |
+
cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
|
| 1300 |
+
layers.append(cache_cls(decoder_config))
|
| 1301 |
+
|
| 1302 |
+
# In this case, use the passed data to already fill in the Cache
|
| 1303 |
+
if ddp_cache_data is not None:
|
| 1304 |
+
# Init all the layers with the data
|
| 1305 |
+
for layer_idx, kv_and_optional_sliding in enumerate(ddp_cache_data):
|
| 1306 |
+
# If the config was not passed above, initialize a new cache layer for each entry of the ddp_data
|
| 1307 |
+
if config is None:
|
| 1308 |
+
# kv_and_optional_sliding contains at least two elements: the key and value states. It can also
|
| 1309 |
+
# contain a third element, which is an optional sliding window tensor.
|
| 1310 |
+
sliding_window_tensor = kv_and_optional_sliding[2] if len(kv_and_optional_sliding) == 3 else None
|
| 1311 |
+
# If there is a sliding window tensor, use it to initialize the layer
|
| 1312 |
+
if sliding_window_tensor is not None:
|
| 1313 |
+
# Since the same layer is dispatched across replicas, sliding_window is the same for all
|
| 1314 |
+
sliding_window = sliding_window_tensor[0].item()
|
| 1315 |
+
layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
|
| 1316 |
+
else:
|
| 1317 |
+
layers.append(DynamicLayer())
|
| 1318 |
+
# Update the layer with the data
|
| 1319 |
+
_, _ = layers[layer_idx].update(kv_and_optional_sliding[0], kv_and_optional_sliding[1])
|
| 1320 |
+
|
| 1321 |
+
# If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
|
| 1322 |
+
if len(layers) == 0:
|
| 1323 |
+
super().__init__(
|
| 1324 |
+
layer_class_to_replicate=DynamicLayer,
|
| 1325 |
+
offloading=offloading,
|
| 1326 |
+
offload_only_non_sliding=offload_only_non_sliding,
|
| 1327 |
+
)
|
| 1328 |
+
else:
|
| 1329 |
+
super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
|
| 1330 |
+
|
| 1331 |
+
def __iter__(self):
|
| 1332 |
+
for layer in self.layers:
|
| 1333 |
+
yield layer.keys, layer.values, getattr(layer, "_sliding_window_tensor", None)
|
| 1334 |
+
|
| 1335 |
+
|
| 1336 |
+
class StaticCache(Cache):
|
| 1337 |
+
"""
|
| 1338 |
+
Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
|
| 1339 |
+
for potential hybrid cache structure, and initialize each layer accordingly.
|
| 1340 |
+
|
| 1341 |
+
See `Cache` for details on common methods that are implemented by all cache classes.
|
| 1342 |
+
|
| 1343 |
+
Args:
|
| 1344 |
+
config (`PreTrainedConfig`):
|
| 1345 |
+
The config of the model for which this Cache will be used. It will be used to check for sliding
|
| 1346 |
+
or hybrid layer structure, and initialize each layer accordingly.
|
| 1347 |
+
max_cache_len (`int`):
|
| 1348 |
+
The maximum number of tokens that this Cache should hold.
|
| 1349 |
+
offloading (`bool`, *optional*, defaults to `False`):
|
| 1350 |
+
Whether to perform offloading of the layers to `cpu`, to save GPU memory.
|
| 1351 |
+
offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
|
| 1352 |
+
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
|
| 1353 |
+
usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
|
| 1354 |
+
|
| 1355 |
+
Example:
|
| 1356 |
+
|
| 1357 |
+
```python
|
| 1358 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
|
| 1359 |
+
|
| 1360 |
+
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
| 1361 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
| 1362 |
+
|
| 1363 |
+
>>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
|
| 1364 |
+
|
| 1365 |
+
>>> # Prepare a cache class and pass it to model's forward
|
| 1366 |
+
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
|
| 1367 |
+
>>> max_generated_length = inputs.input_ids.shape[1] + 10
|
| 1368 |
+
>>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
|
| 1369 |
+
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
| 1370 |
+
>>> outputs.past_key_values # access cache filled with key/values from generation
|
| 1371 |
+
StaticCache()
|
| 1372 |
+
```
|
| 1373 |
+
"""
|
| 1374 |
+
|
| 1375 |
+
# Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
|
| 1376 |
+
def __init__(
|
| 1377 |
+
self,
|
| 1378 |
+
config: PreTrainedConfig,
|
| 1379 |
+
max_cache_len: int,
|
| 1380 |
+
offloading: bool = False,
|
| 1381 |
+
offload_only_non_sliding: bool = True,
|
| 1382 |
+
**kwargs,
|
| 1383 |
+
):
|
| 1384 |
+
config = config.get_text_config(decoder=True)
|
| 1385 |
+
layer_types = getattr(config, "layer_types", None)
|
| 1386 |
+
# If `layer_types` is not explicitly provided, infer if the model is fully sliding
|
| 1387 |
+
if layer_types is None:
|
| 1388 |
+
if getattr(config, "sliding_window", None) is not None:
|
| 1389 |
+
layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
|
| 1390 |
+
elif getattr(config, "attention_chunk_size", None) is not None:
|
| 1391 |
+
layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
|
| 1392 |
+
else:
|
| 1393 |
+
layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
|
| 1394 |
+
# Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
|
| 1395 |
+
if hasattr(config, "num_kv_shared_layers"):
|
| 1396 |
+
layer_types = layer_types[: -config.num_kv_shared_layers]
|
| 1397 |
+
|
| 1398 |
+
sliding_layer_types = {
|
| 1399 |
+
name
|
| 1400 |
+
for name, cls in LAYER_TYPE_CACHE_MAPPING.items()
|
| 1401 |
+
if isinstance(cls, type) and issubclass(cls, DynamicSlidingWindowLayer) and name != "chunked_attention"
|
| 1402 |
+
}
|
| 1403 |
+
layers = []
|
| 1404 |
+
for layer_type in layer_types:
|
| 1405 |
+
if layer_type == "chunked_attention":
|
| 1406 |
+
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
|
| 1407 |
+
# states they should return - only the mask changes to make them different at the end!
|
| 1408 |
+
layer = StaticSlidingWindowLayer(
|
| 1409 |
+
max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
|
| 1410 |
+
)
|
| 1411 |
+
elif layer_type in sliding_layer_types:
|
| 1412 |
+
layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
|
| 1413 |
+
# LinearAttention layers are static by essence - using `"moe"` as well is a trick, see the comment about it on DynamicCache
|
| 1414 |
+
elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
|
| 1415 |
+
layer = LinearAttentionLayer()
|
| 1416 |
+
else:
|
| 1417 |
+
layer = StaticLayer(max_cache_len=max_cache_len)
|
| 1418 |
+
layers.append(layer)
|
| 1419 |
+
|
| 1420 |
+
super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
|
| 1421 |
+
|
| 1422 |
+
|
| 1423 |
+
class QuantizedCache(Cache):
|
| 1424 |
+
"""
|
| 1425 |
+
A quantizer cache similar to what is described in the
|
| 1426 |
+
[KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
|
| 1427 |
+
It allows the model to generate longer sequence length without allocating too much memory for keys and values
|
| 1428 |
+
by applying quantization.
|
| 1429 |
+
The cache has two types of storage, one for original precision and one for the
|
| 1430 |
+
quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
|
| 1431 |
+
length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
|
| 1432 |
+
The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
|
| 1433 |
+
described in the paper.
|
| 1434 |
+
|
| 1435 |
+
See `Cache` for details on common methods that are implemented by all cache classes.
|
| 1436 |
+
|
| 1437 |
+
Args:
|
| 1438 |
+
backend (`str`):
|
| 1439 |
+
The quantization backend to use. One of `("quanto", "hqq").
|
| 1440 |
+
config (`PreTrainedConfig`):
|
| 1441 |
+
The config of the model for which this Cache will be used.
|
| 1442 |
+
nbits (`int`, *optional*, defaults to 4):
|
| 1443 |
+
The number of bits for quantization.
|
| 1444 |
+
axis_key (`int`, *optional*, defaults to 0):
|
| 1445 |
+
The axis on which to quantize the keys.
|
| 1446 |
+
axis_value (`int`, *optional*, defaults to 0):
|
| 1447 |
+
The axis on which to quantize the values.
|
| 1448 |
+
q_group_size (`int`, *optional*, defaults to 64):
|
| 1449 |
+
Quantization is done per-channel according to a set `q_group_size` for both keys and values.
|
| 1450 |
+
residual_length (`int`, *optional*, defaults to 128):
|
| 1451 |
+
Maximum capacity for the original precision cache
|
| 1452 |
+
"""
|
| 1453 |
+
|
| 1454 |
+
def __init__(
|
| 1455 |
+
self,
|
| 1456 |
+
backend: str,
|
| 1457 |
+
config: PreTrainedConfig,
|
| 1458 |
+
nbits: int = 4,
|
| 1459 |
+
axis_key: int = 0,
|
| 1460 |
+
axis_value: int = 0,
|
| 1461 |
+
q_group_size: int = 64,
|
| 1462 |
+
residual_length: int = 128,
|
| 1463 |
+
):
|
| 1464 |
+
if backend == "quanto":
|
| 1465 |
+
layer_class = QuantoQuantizedLayer
|
| 1466 |
+
elif backend == "hqq":
|
| 1467 |
+
layer_class = HQQQuantizedLayer
|
| 1468 |
+
else:
|
| 1469 |
+
raise ValueError(f"Unknown quantization backend `{backend}`")
|
| 1470 |
+
|
| 1471 |
+
config = config.get_text_config(decoder=True)
|
| 1472 |
+
layers = [
|
| 1473 |
+
layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
|
| 1474 |
+
for _ in range(config.num_hidden_layers)
|
| 1475 |
+
]
|
| 1476 |
+
super().__init__(layers=layers)
|
| 1477 |
+
|
| 1478 |
+
|
| 1479 |
+
class EncoderDecoderCache(Cache):
|
| 1480 |
+
"""
|
| 1481 |
+
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
|
| 1482 |
+
cross-attention caches.
|
| 1483 |
+
|
| 1484 |
+
See `Cache` for details on common methods that are implemented by all cache classes.
|
| 1485 |
+
|
| 1486 |
+
Args:
|
| 1487 |
+
caches (`Iterable`):
|
| 1488 |
+
Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the
|
| 1489 |
+
second one for cross-attention. Can optionally also be an iterable of length 1, containing a
|
| 1490 |
+
`tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp).
|
| 1491 |
+
|
| 1492 |
+
Example:
|
| 1493 |
+
|
| 1494 |
+
```python
|
| 1495 |
+
>>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
|
| 1496 |
+
|
| 1497 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
|
| 1498 |
+
>>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
|
| 1499 |
+
|
| 1500 |
+
>>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
|
| 1501 |
+
|
| 1502 |
+
>>> # Prepare cache classes for encoder and decoder and pass it to model's forward
|
| 1503 |
+
>>> self_attention_cache = DynamicCache(config=self.config)
|
| 1504 |
+
>>> cross_attention_cache = DynamicCache(config=self.config)
|
| 1505 |
+
>>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
|
| 1506 |
+
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
| 1507 |
+
>>> outputs.past_key_values # access cache filled with key/values from generation
|
| 1508 |
+
EncoderDecoderCache()
|
| 1509 |
+
```
|
| 1510 |
+
"""
|
| 1511 |
+
|
| 1512 |
+
def __init__(self, *caches) -> None:
|
| 1513 |
+
# For dp and ddp support, if only one argument is passed, it should be an iterable of DynamicCache ddp data
|
| 1514 |
+
if len(caches) == 1:
|
| 1515 |
+
self_attention_cache_data, cross_attention_cache_data = [], []
|
| 1516 |
+
for combined_cache_data in caches[0]:
|
| 1517 |
+
if len(combined_cache_data) == 6: # two tuple of style (self_attn_k, self_attn_v, self_attn_sliding)
|
| 1518 |
+
self_attention_cache_data.append(combined_cache_data[:3])
|
| 1519 |
+
cross_attention_cache_data.append(combined_cache_data[3:])
|
| 1520 |
+
# To support old DDP-style init, we handle the case where the tuple has no sliding window tensor
|
| 1521 |
+
elif len(combined_cache_data) == 4: # two tuple of style (self_attn_k, self_attn_v)
|
| 1522 |
+
self_attention_cache_data.append(combined_cache_data[:2])
|
| 1523 |
+
cross_attention_cache_data.append(combined_cache_data[2:])
|
| 1524 |
+
else:
|
| 1525 |
+
raise ValueError(f"Expected {len(combined_cache_data) = } to be 4 or 6.\n{combined_cache_data = }")
|
| 1526 |
+
self.self_attention_cache = DynamicCache(self_attention_cache_data)
|
| 1527 |
+
self.cross_attention_cache = DynamicCache(cross_attention_cache_data)
|
| 1528 |
+
# Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
|
| 1529 |
+
elif len(caches) == 2:
|
| 1530 |
+
if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
|
| 1531 |
+
raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
|
| 1532 |
+
self.self_attention_cache = caches[0]
|
| 1533 |
+
self.cross_attention_cache = caches[1]
|
| 1534 |
+
# Error case
|
| 1535 |
+
else:
|
| 1536 |
+
raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
|
| 1537 |
+
|
| 1538 |
+
self.is_updated = {}
|
| 1539 |
+
for layer_idx in range(len(self.cross_attention_cache)):
|
| 1540 |
+
self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
|
| 1541 |
+
|
| 1542 |
+
def __iter__(self):
|
| 1543 |
+
"""Returns tuples of style (self_attn_k, self_attn_v, self_attn_sliding, cross_attn_k, cross_attn_v, cross_attn_sliding)"""
|
| 1544 |
+
for self_attention_layer, cross_attention_layer in zip(self.self_attention_cache, self.cross_attention_cache):
|
| 1545 |
+
yield self_attention_layer + cross_attention_layer
|
| 1546 |
+
|
| 1547 |
+
def __repr__(self) -> str:
|
| 1548 |
+
return (
|
| 1549 |
+
f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
|
| 1550 |
+
f"{self.cross_attention_cache})"
|
| 1551 |
+
)
|
| 1552 |
+
|
| 1553 |
+
def __len__(self):
|
| 1554 |
+
"""
|
| 1555 |
+
Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds
|
| 1556 |
+
to the number of layers in the model.
|
| 1557 |
+
"""
|
| 1558 |
+
return len(self.self_attention_cache)
|
| 1559 |
+
|
| 1560 |
+
def get_seq_length(self, layer_idx: int = 0) -> int:
|
| 1561 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 1562 |
+
return self.self_attention_cache.get_seq_length(layer_idx)
|
| 1563 |
+
|
| 1564 |
+
def reset(self):
|
| 1565 |
+
self.self_attention_cache.reset()
|
| 1566 |
+
self.cross_attention_cache.reset()
|
| 1567 |
+
for layer_idx in self.is_updated:
|
| 1568 |
+
self.is_updated[layer_idx] = False
|
| 1569 |
+
|
| 1570 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 1571 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 1572 |
+
self.self_attention_cache.reorder_cache(beam_idx)
|
| 1573 |
+
self.cross_attention_cache.reorder_cache(beam_idx)
|
| 1574 |
+
|
| 1575 |
+
def check_dynamic_cache(self, method: str):
|
| 1576 |
+
if not (
|
| 1577 |
+
isinstance(self.self_attention_cache, DynamicCache)
|
| 1578 |
+
and isinstance(self.cross_attention_cache, DynamicCache)
|
| 1579 |
+
):
|
| 1580 |
+
raise TypeError(
|
| 1581 |
+
f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
|
| 1582 |
+
f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
|
| 1583 |
+
)
|
| 1584 |
+
|
| 1585 |
+
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
|
| 1586 |
+
def crop(self, maximum_length: int):
|
| 1587 |
+
"""
|
| 1588 |
+
Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
|
| 1589 |
+
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).
|
| 1590 |
+
"""
|
| 1591 |
+
self.check_dynamic_cache(self.crop.__name__)
|
| 1592 |
+
self.self_attention_cache.crop(maximum_length)
|
| 1593 |
+
|
| 1594 |
+
def batch_repeat_interleave(self, repeats: int):
|
| 1595 |
+
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
|
| 1596 |
+
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
|
| 1597 |
+
self.self_attention_cache.batch_repeat_interleave(repeats)
|
| 1598 |
+
self.cross_attention_cache.batch_repeat_interleave(repeats)
|
| 1599 |
+
|
| 1600 |
+
def batch_select_indices(self, indices: torch.Tensor):
|
| 1601 |
+
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub)."""
|
| 1602 |
+
self.check_dynamic_cache(self.batch_select_indices.__name__)
|
| 1603 |
+
self.self_attention_cache.batch_select_indices(indices)
|
| 1604 |
+
self.cross_attention_cache.batch_select_indices(indices)
|
| 1605 |
+
|
| 1606 |
+
def get_max_cache_shape(self) -> int:
|
| 1607 |
+
"""Returns the maximum sequence length (i.e. max capacity) of the cache object"""
|
| 1608 |
+
return self.self_attention_cache.get_max_cache_shape()
|
| 1609 |
+
|
| 1610 |
+
def get_mask_sizes(self, query_length: int, layer_idx: int) -> tuple[int, int]:
|
| 1611 |
+
return self.self_attention_cache.get_mask_sizes(query_length, layer_idx)
|
| 1612 |
+
|
| 1613 |
+
@property
|
| 1614 |
+
def is_sliding(self):
|
| 1615 |
+
return self.self_attention_cache.is_sliding
|
| 1616 |
+
|
| 1617 |
+
@property
|
| 1618 |
+
def is_compileable(self) -> bool:
|
| 1619 |
+
return self.self_attention_cache.is_compileable
|
| 1620 |
+
|
| 1621 |
+
|
| 1622 |
+
# Deprecated alias: SlidingWindowCache was removed in transformers v5. StaticCache is the replacement.
|
| 1623 |
+
SlidingWindowCache = StaticCache
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/configuration_utils.py
ADDED
|
@@ -0,0 +1,1365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 2 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Configuration base class and utilities."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
from collections.abc import Sequence
|
| 22 |
+
from dataclasses import MISSING, dataclass, fields
|
| 23 |
+
from functools import wraps
|
| 24 |
+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar, Union
|
| 25 |
+
|
| 26 |
+
from huggingface_hub import create_repo
|
| 27 |
+
from huggingface_hub.dataclasses import strict
|
| 28 |
+
from packaging import version
|
| 29 |
+
from typing_extensions import dataclass_transform
|
| 30 |
+
|
| 31 |
+
from . import __version__
|
| 32 |
+
from .dynamic_module_utils import custom_object_save
|
| 33 |
+
from .generation.configuration_utils import GenerationConfig
|
| 34 |
+
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
|
| 35 |
+
from .modeling_rope_utils import RotaryEmbeddingConfigMixin
|
| 36 |
+
from .utils import (
|
| 37 |
+
CONFIG_NAME,
|
| 38 |
+
PushToHubMixin,
|
| 39 |
+
cached_file,
|
| 40 |
+
copy_func,
|
| 41 |
+
extract_commit_hash,
|
| 42 |
+
is_torch_available,
|
| 43 |
+
logging,
|
| 44 |
+
)
|
| 45 |
+
from .utils.generic import is_timm_config_dict
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if TYPE_CHECKING:
|
| 49 |
+
import torch
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
logger = logging.get_logger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# type hinting: specifying the type of config class that inherits from PreTrainedConfig
|
| 56 |
+
SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
|
| 57 |
+
|
| 58 |
+
_FLOAT_TAG_KEY = "__float__"
|
| 59 |
+
_FLOAT_TAG_VALUES = {"Infinity": float("inf"), "-Infinity": float("-inf"), "NaN": float("nan")}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
ALLOWED_LAYER_TYPES = (
|
| 63 |
+
"full_attention",
|
| 64 |
+
"sliding_attention",
|
| 65 |
+
"chunked_attention",
|
| 66 |
+
"compressed_sparse_attention", # CSA, used in deepseek_v4
|
| 67 |
+
"heavily_compressed_attention", # HCA, used in deepseek_v4
|
| 68 |
+
"linear_attention", # used in minimax
|
| 69 |
+
"conv", # used in LFMv2
|
| 70 |
+
"mamba",
|
| 71 |
+
"attention",
|
| 72 |
+
"sparse",
|
| 73 |
+
"dense",
|
| 74 |
+
"hybrid", # for layers that have both mamba and attention in zamba and zamba2
|
| 75 |
+
"moe", # for nemotron_h, which uses either attention, mamba or moe
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# copied from huggingface_hub.dataclasses.strict when `accept_kwargs=True`
|
| 80 |
+
def wrap_init_to_accept_kwargs(cls: dataclass):
|
| 81 |
+
# Get the original dataclass-generated __init__
|
| 82 |
+
original_init = cls.__init__
|
| 83 |
+
|
| 84 |
+
@wraps(original_init)
|
| 85 |
+
def __init__(self, *args, **kwargs: Any) -> None:
|
| 86 |
+
# Extract only the fields that are part of the dataclass
|
| 87 |
+
dataclass_fields = {f.name for f in fields(cls)}
|
| 88 |
+
standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
|
| 89 |
+
|
| 90 |
+
# We need to call bare `__init__` without `__post_init__` but the `original_init` of
|
| 91 |
+
# any dataclas contains a call to post-init at the end (without kwargs)
|
| 92 |
+
if len(args) > 0:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"{cls.__name__} accepts only keyword arguments, but found `{len(args)}` positional args."
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
for f in fields(cls): # type: ignore
|
| 98 |
+
if f.name in standard_kwargs:
|
| 99 |
+
setattr(self, f.name, standard_kwargs[f.name])
|
| 100 |
+
elif f.default is not MISSING:
|
| 101 |
+
setattr(self, f.name, f.default)
|
| 102 |
+
elif f.default_factory is not MISSING:
|
| 103 |
+
setattr(self, f.name, f.default_factory())
|
| 104 |
+
else:
|
| 105 |
+
raise TypeError(f"Missing required field - '{f.name}'")
|
| 106 |
+
|
| 107 |
+
# Pass any additional kwargs to `__post_init__` and let the object
|
| 108 |
+
# decide whether to set the attr or use for different purposes (e.g. BC checks)
|
| 109 |
+
additional_kwargs = {}
|
| 110 |
+
for name, value in kwargs.items():
|
| 111 |
+
if name not in dataclass_fields:
|
| 112 |
+
additional_kwargs[name] = value
|
| 113 |
+
|
| 114 |
+
self.__post_init__(**additional_kwargs)
|
| 115 |
+
|
| 116 |
+
cls.__init__ = __init__
|
| 117 |
+
return cls
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass_transform(kw_only_default=True)
|
| 121 |
+
@strict(accept_kwargs=True)
|
| 122 |
+
@dataclass(repr=False)
|
| 123 |
+
class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
|
| 124 |
+
# no-format
|
| 125 |
+
r"""
|
| 126 |
+
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
| 127 |
+
methods for loading/downloading/saving configurations.
|
| 128 |
+
|
| 129 |
+
<Tip>
|
| 130 |
+
|
| 131 |
+
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
|
| 132 |
+
initialize a model does **not** load the model weights. It only affects the model's configuration.
|
| 133 |
+
|
| 134 |
+
</Tip>
|
| 135 |
+
|
| 136 |
+
Class attributes (overridden by derived classes):
|
| 137 |
+
|
| 138 |
+
- **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
|
| 139 |
+
the correct object in [`~transformers.AutoConfig`].
|
| 140 |
+
- **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
|
| 141 |
+
Some configurations requires inputs to be defined at init and have no default values, usually these are composite configs,
|
| 142 |
+
(but not necessarily) such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be initialized from
|
| 143 |
+
two or more configs of type [`~transformers.PreTrainedConfig`].
|
| 144 |
+
- **keys_to_ignore_at_inference** (`list[str]`) -- A list of keys to ignore by default when looking at dictionary
|
| 145 |
+
outputs of the model during inference.
|
| 146 |
+
- **attribute_map** (`dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
|
| 147 |
+
naming of attributes.
|
| 148 |
+
- **base_model_tp_plan** (`dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
|
| 149 |
+
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
|
| 150 |
+
- **base_model_pp_plan** (`dict[str, tuple[list[str]]]`) -- A dict that maps child-modules of a base model to a
|
| 151 |
+
pipeline parallel plan that enables users to place the child-module on the appropriate device.
|
| 152 |
+
|
| 153 |
+
Common attributes (present in all subclasses):
|
| 154 |
+
|
| 155 |
+
- **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
|
| 156 |
+
embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
|
| 157 |
+
- **hidden_size** (`int`) -- The hidden size of the model.
|
| 158 |
+
- **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
|
| 159 |
+
model.
|
| 160 |
+
- **num_hidden_layers** (`int`) -- The number of blocks in the model.
|
| 161 |
+
|
| 162 |
+
<Tip warning={true}>
|
| 163 |
+
|
| 164 |
+
Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
|
| 165 |
+
some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
|
| 166 |
+
them in a [~transformers.GenerationConfig]. Check the documentation of [~transformers.GenerationConfig] for more
|
| 167 |
+
information about the individual parameters.
|
| 168 |
+
|
| 169 |
+
</Tip>
|
| 170 |
+
|
| 171 |
+
Arg:
|
| 172 |
+
name_or_path (`str`, *optional*, defaults to `""`):
|
| 173 |
+
Store the string that was passed to [`PreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path`
|
| 174 |
+
if the configuration was created with such a method.
|
| 175 |
+
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
| 176 |
+
Whether or not the model should return all hidden-states.
|
| 177 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
| 178 |
+
Whether or not the model should returns all attentions.
|
| 179 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 180 |
+
Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
|
| 181 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `False`):
|
| 182 |
+
Whether the model is used as an encoder/decoder or not.
|
| 183 |
+
chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
|
| 184 |
+
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
|
| 185 |
+
the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
|
| 186 |
+
sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
|
| 187 |
+
Forward Chunking work?](../glossary.html#feed-forward-chunking).
|
| 188 |
+
|
| 189 |
+
> Parameters for fine-tuning tasks
|
| 190 |
+
|
| 191 |
+
architectures (`list[str]`, *optional*):
|
| 192 |
+
Model architectures that can be used with the model pretrained weights.
|
| 193 |
+
id2label (`dict[int, str]`, *optional*):
|
| 194 |
+
A map from index (for instance prediction index, or target index) to label.
|
| 195 |
+
label2id (`dict[str, int]`, *optional*):
|
| 196 |
+
A map from label to index for the model.
|
| 197 |
+
num_labels (`int`, *optional*):
|
| 198 |
+
Number of labels to use in the last layer added to the model, typically for a classification task.
|
| 199 |
+
problem_type (`str`, *optional*):
|
| 200 |
+
Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
|
| 201 |
+
`"single_label_classification"` or `"multi_label_classification"`.
|
| 202 |
+
|
| 203 |
+
> PyTorch specific parameters
|
| 204 |
+
|
| 205 |
+
dtype (`str`, *optional*):
|
| 206 |
+
The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
|
| 207 |
+
(which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
|
| 208 |
+
model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
|
| 209 |
+
`float16` weights.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
# Class attributes that we don't want to save or have in `self.__dict__`
|
| 213 |
+
# They are not supposed to be set/changed by users. Each field is set when
|
| 214 |
+
# creating a model class
|
| 215 |
+
base_config_key: ClassVar[str] = ""
|
| 216 |
+
sub_configs: ClassVar[dict[str, type["PreTrainedConfig"]]] = {}
|
| 217 |
+
has_no_defaults_at_init: ClassVar[bool] = False
|
| 218 |
+
keys_to_ignore_at_inference: ClassVar[list[str]] = []
|
| 219 |
+
attribute_map: ClassVar[dict[str, str]] = {}
|
| 220 |
+
base_model_tp_plan: ClassVar[dict[str, Any] | None] = None
|
| 221 |
+
base_model_pp_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
|
| 222 |
+
base_model_ep_plan: ClassVar[dict[str, Sequence[list[str]]] | None] = None
|
| 223 |
+
_auto_class: ClassVar[str | None] = None
|
| 224 |
+
|
| 225 |
+
# Attributes set internally when saving and used to infer model
|
| 226 |
+
# class for `Auto` mapping
|
| 227 |
+
model_type: ClassVar[str] = ""
|
| 228 |
+
transformers_version: str | None = None
|
| 229 |
+
architectures: list[str] | None = None
|
| 230 |
+
|
| 231 |
+
# Common attributes for all models
|
| 232 |
+
output_hidden_states: bool | None = False
|
| 233 |
+
return_dict: bool | None = True
|
| 234 |
+
dtype: Union[str, "torch.dtype"] | None = None
|
| 235 |
+
chunk_size_feed_forward: int = 0
|
| 236 |
+
is_encoder_decoder: bool = False
|
| 237 |
+
|
| 238 |
+
# Fine-tuning task arguments
|
| 239 |
+
id2label: dict[int, str] | dict[str, str] | None = None
|
| 240 |
+
label2id: dict[str, int] | dict[str, str] | None = None
|
| 241 |
+
problem_type: Literal["regression", "single_label_classification", "multi_label_classification"] | None = None
|
| 242 |
+
|
| 243 |
+
def __post_init__(self, **kwargs):
|
| 244 |
+
# BC for the `torch_dtype` argument instead of the simpler `dtype`
|
| 245 |
+
# Do not warn, as it would otherwise always be triggered since most configs on the hub have `torch_dtype`
|
| 246 |
+
if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
|
| 247 |
+
# If both are provided, keep `dtype`
|
| 248 |
+
self.dtype = self.dtype if self.dtype is not None else torch_dtype
|
| 249 |
+
if self.dtype is not None and isinstance(self.dtype, str) and is_torch_available():
|
| 250 |
+
# we will start using self.dtype in v5, but to be consistent with
|
| 251 |
+
# from_pretrained's dtype arg convert it to an actual torch.dtype object
|
| 252 |
+
import torch
|
| 253 |
+
|
| 254 |
+
self.dtype = getattr(torch, self.dtype)
|
| 255 |
+
|
| 256 |
+
# Keep the default value of `num_labels=2` in case users have saved a classfier with 2 labels
|
| 257 |
+
# Our configs prev wouldn't save `id2label` for 2 labels because it is the default. In all other
|
| 258 |
+
# cases we expect the config dict to have an `id2label` field if it's a clf model, or not otherwise
|
| 259 |
+
if self.id2label is None:
|
| 260 |
+
self.num_labels = kwargs.get("num_labels", 2)
|
| 261 |
+
else:
|
| 262 |
+
if kwargs.get("num_labels") is not None and len(self.id2label) != kwargs.get("num_labels"):
|
| 263 |
+
logger.warning(
|
| 264 |
+
f"You passed `num_labels={kwargs.get('num_labels')}` which is incompatible to "
|
| 265 |
+
f"the `id2label` map of length `{len(self.id2label)}`."
|
| 266 |
+
)
|
| 267 |
+
# Keys are always strings in JSON so convert ids to int
|
| 268 |
+
self.id2label = {int(key): value for key, value in self.id2label.items()}
|
| 269 |
+
|
| 270 |
+
if self.problem_type == "single_label_classification" and self.num_labels == 1:
|
| 271 |
+
raise ValueError(
|
| 272 |
+
'`problem_type="single_label_classification"` requires `num_labels > 1`. For binary '
|
| 273 |
+
'classification use `num_labels=2`, or use `problem_type="regression"` for a '
|
| 274 |
+
"single-output regression head."
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# BC for rotary embeddings. We will pop out legacy keys from kwargs and rename to new format
|
| 278 |
+
if hasattr(self, "rope_parameters"):
|
| 279 |
+
kwargs = self.convert_rope_params_to_dict(**kwargs)
|
| 280 |
+
elif kwargs.get("rope_scaling") and kwargs.get("rope_theta"):
|
| 281 |
+
logger.warning(
|
| 282 |
+
f"{self.__class__.__name__} got `key=rope_scaling` in kwargs but hasn't set it as attribute. "
|
| 283 |
+
"For RoPE standardization you need to set `self.rope_parameters` in model's config. "
|
| 284 |
+
)
|
| 285 |
+
kwargs = self.convert_rope_params_to_dict(**kwargs)
|
| 286 |
+
|
| 287 |
+
# Parameters for sequence generation saved in the config are popped instead of loading them.
|
| 288 |
+
for parameter_name in GenerationConfig._get_default_generation_params().keys():
|
| 289 |
+
kwargs.pop(parameter_name, None)
|
| 290 |
+
|
| 291 |
+
# Name or path to the pretrained checkpoint
|
| 292 |
+
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
| 293 |
+
self._commit_hash = kwargs.pop("_commit_hash", None)
|
| 294 |
+
|
| 295 |
+
# Attention/Experts implementation to use, if relevant (it sets it recursively on sub-configs)
|
| 296 |
+
self._output_attentions: bool | None = kwargs.pop("output_attentions", False)
|
| 297 |
+
self._attn_implementation: str | None = kwargs.pop("attn_implementation", None)
|
| 298 |
+
self._experts_implementation: str | None = kwargs.pop("experts_implementation", None)
|
| 299 |
+
|
| 300 |
+
# Additional attributes without default values
|
| 301 |
+
for key, value in kwargs.items():
|
| 302 |
+
# Check this to avoid deserializing problematic fields from hub configs - they should use the public field
|
| 303 |
+
if key not in ("_attn_implementation_internal", "_experts_implementation_internal"):
|
| 304 |
+
try:
|
| 305 |
+
setattr(self, key, value)
|
| 306 |
+
except AttributeError as err:
|
| 307 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
| 308 |
+
raise err
|
| 309 |
+
|
| 310 |
+
def __init_subclass__(cls, *args, **kwargs):
|
| 311 |
+
super().__init_subclass__(*args, **kwargs)
|
| 312 |
+
cls_has_custom_init = "__init__" in cls.__dict__
|
| 313 |
+
# kw_only=True ensures fields without defaults in subclasses can follow
|
| 314 |
+
# parent fields that have defaults (Python dataclass ordering rule).
|
| 315 |
+
# Config fields are always passed as keyword arguments, so this is safe.
|
| 316 |
+
cls = dataclass(cls, repr=False, kw_only=True)
|
| 317 |
+
|
| 318 |
+
if not cls_has_custom_init:
|
| 319 |
+
# Wrap all subclasses to accept arbitrary kwargs for BC
|
| 320 |
+
# only if the subclass has no custom `__init__`. Most
|
| 321 |
+
# remote code has an init defined, but some model are not
|
| 322 |
+
# See https://huggingface.co/hmellor/Ilama-3.2-1B/blob/main/configuration_ilama.py
|
| 323 |
+
cls = wrap_init_to_accept_kwargs(cls)
|
| 324 |
+
|
| 325 |
+
@property
|
| 326 |
+
def name_or_path(self) -> str | None:
|
| 327 |
+
return getattr(self, "_name_or_path", None)
|
| 328 |
+
|
| 329 |
+
@name_or_path.setter
|
| 330 |
+
def name_or_path(self, value):
|
| 331 |
+
self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
|
| 332 |
+
|
| 333 |
+
@property
|
| 334 |
+
def num_labels(self) -> int:
|
| 335 |
+
"""
|
| 336 |
+
`int`: The number of labels for classification models.
|
| 337 |
+
"""
|
| 338 |
+
return len(self.id2label) if self.id2label is not None else None
|
| 339 |
+
|
| 340 |
+
@num_labels.setter
|
| 341 |
+
def num_labels(self, num_labels: int):
|
| 342 |
+
# we do not store `num_labels` attribute in config, but instead
|
| 343 |
+
# compute it based on the length of the `id2label` map
|
| 344 |
+
if self.id2label is None or self.num_labels != num_labels:
|
| 345 |
+
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
|
| 346 |
+
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
| 347 |
+
|
| 348 |
+
@property
|
| 349 |
+
def output_attentions(self):
|
| 350 |
+
"""
|
| 351 |
+
`bool`: Whether or not the model should returns all attentions.
|
| 352 |
+
"""
|
| 353 |
+
return self._output_attentions
|
| 354 |
+
|
| 355 |
+
@output_attentions.setter
|
| 356 |
+
def output_attentions(self, value: bool):
|
| 357 |
+
# If we set `output_attentions` explicitly before the attn implementation, dispatch eager
|
| 358 |
+
if value and self._attn_implementation is None:
|
| 359 |
+
self._attn_implementation = "eager"
|
| 360 |
+
if value and self._attn_implementation != "eager":
|
| 361 |
+
raise ValueError(
|
| 362 |
+
"The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
|
| 363 |
+
f"{self._attn_implementation}. Please set it to 'eager' instead."
|
| 364 |
+
)
|
| 365 |
+
self._output_attentions = value
|
| 366 |
+
|
| 367 |
+
@property
|
| 368 |
+
def _attn_implementation(self):
|
| 369 |
+
return self._attn_implementation_internal
|
| 370 |
+
|
| 371 |
+
@_attn_implementation.setter
|
| 372 |
+
def _attn_implementation(self, value: str | dict | None):
|
| 373 |
+
"""We set it recursively on the sub-configs as well"""
|
| 374 |
+
# Set if for current config
|
| 375 |
+
current_attn = getattr(self, "_attn_implementation", None)
|
| 376 |
+
attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
|
| 377 |
+
self._attn_implementation_internal = attn_implementation
|
| 378 |
+
|
| 379 |
+
# Set it recursively on the subconfigs
|
| 380 |
+
for subconfig_key in self.sub_configs:
|
| 381 |
+
subconfig = getattr(self, subconfig_key, None)
|
| 382 |
+
if subconfig is not None:
|
| 383 |
+
current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
|
| 384 |
+
sub_implementation = (
|
| 385 |
+
value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
|
| 386 |
+
)
|
| 387 |
+
subconfig._attn_implementation = sub_implementation
|
| 388 |
+
|
| 389 |
+
@property
|
| 390 |
+
def _experts_implementation(self):
|
| 391 |
+
return self._experts_implementation_internal
|
| 392 |
+
|
| 393 |
+
@_experts_implementation.setter
|
| 394 |
+
def _experts_implementation(self, value: str | dict | None):
|
| 395 |
+
"""We set it recursively on the sub-configs as well"""
|
| 396 |
+
# Set if for current config
|
| 397 |
+
current_moe = getattr(self, "_experts_implementation", None)
|
| 398 |
+
experts_implementation = value if not isinstance(value, dict) else value.get("", current_moe)
|
| 399 |
+
self._experts_implementation_internal = experts_implementation
|
| 400 |
+
|
| 401 |
+
# Set it recursively on the subconfigs
|
| 402 |
+
for subconfig_key in self.sub_configs:
|
| 403 |
+
subconfig = getattr(self, subconfig_key, None)
|
| 404 |
+
if subconfig is not None:
|
| 405 |
+
current_subconfig_moe = getattr(subconfig, "_experts_implementation", None)
|
| 406 |
+
sub_implementation = (
|
| 407 |
+
value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_moe)
|
| 408 |
+
)
|
| 409 |
+
subconfig._experts_implementation = sub_implementation
|
| 410 |
+
|
| 411 |
+
@property
|
| 412 |
+
def torch_dtype(self):
|
| 413 |
+
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
|
| 414 |
+
return self.dtype
|
| 415 |
+
|
| 416 |
+
@property
|
| 417 |
+
def use_return_dict(self):
|
| 418 |
+
logger.warning_once("`use_return_dict` is deprecated! Use `return_dict` instead!")
|
| 419 |
+
return self.return_dict
|
| 420 |
+
|
| 421 |
+
@torch_dtype.setter
|
| 422 |
+
def torch_dtype(self, value):
|
| 423 |
+
logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
|
| 424 |
+
self.dtype = value
|
| 425 |
+
|
| 426 |
+
def __setattr__(self, key, value):
|
| 427 |
+
if key in super().__getattribute__("attribute_map"):
|
| 428 |
+
key = super().__getattribute__("attribute_map")[key]
|
| 429 |
+
super().__setattr__(key, value)
|
| 430 |
+
|
| 431 |
+
def __getattribute__(self, key):
|
| 432 |
+
if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
|
| 433 |
+
key = super().__getattribute__("attribute_map")[key]
|
| 434 |
+
return super().__getattribute__(key)
|
| 435 |
+
|
| 436 |
+
def validate_output_attentions(self):
|
| 437 |
+
if self.output_attentions and self._attn_implementation not in ["eager", None]:
|
| 438 |
+
raise ValueError(
|
| 439 |
+
"The `output_attentions` attribute is not supported when using the `attn_implementation` set to "
|
| 440 |
+
f"{self._attn_implementation}. Please set it to 'eager' instead."
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
def validate_architecture(self):
|
| 444 |
+
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
| 445 |
+
if (
|
| 446 |
+
hasattr(self, "head_dim")
|
| 447 |
+
and hasattr(self, "num_heads")
|
| 448 |
+
and hasattr(self, "embed_dim")
|
| 449 |
+
and self.head_dim * self.num_heads != self.embed_dim
|
| 450 |
+
):
|
| 451 |
+
raise ValueError(
|
| 452 |
+
f"The embed_dim ({self.embed_dim}) is not a multiple of the number of attention "
|
| 453 |
+
f"heads ({self.num_heads})."
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
def validate_token_ids(self):
|
| 457 |
+
"""Part of `@strict`-powered validation. Validates the contents of the special tokens."""
|
| 458 |
+
text_config = self.get_text_config(decoder=True)
|
| 459 |
+
vocab_size = getattr(text_config, "vocab_size", None)
|
| 460 |
+
if vocab_size is not None:
|
| 461 |
+
# Check for all special tokens, e..g. pad_token_id, image_token_id, audio_token_id
|
| 462 |
+
for name in text_config:
|
| 463 |
+
value = getattr(text_config, name)
|
| 464 |
+
if name.endswith("_token_id") and isinstance(value, int) and not 0 <= value < vocab_size:
|
| 465 |
+
# Can't be an exception until we can load configs that fail validation: several configs on the Hub
|
| 466 |
+
# store invalid special tokens, e.g. `pad_token_id=-1`
|
| 467 |
+
logger.warning_once(
|
| 468 |
+
f"Model config: {name} must be `None` or an integer within the vocabulary (between 0 "
|
| 469 |
+
f"and {vocab_size - 1}), got {value}. This may result in unexpected behavior."
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
def validate_layer_type(self):
|
| 473 |
+
"""Check that `layer_types` is correctly defined."""
|
| 474 |
+
if not (getattr(self, "layer_types", None) is not None and hasattr(self, "num_hidden_layers")):
|
| 475 |
+
return
|
| 476 |
+
elif not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in self.layer_types):
|
| 477 |
+
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES} but got {self.layer_types}")
|
| 478 |
+
elif self.num_hidden_layers is not None and self.num_hidden_layers != len(self.layer_types):
|
| 479 |
+
raise ValueError(
|
| 480 |
+
f"`num_hidden_layers` ({self.num_hidden_layers}) must be equal to the number of layer types "
|
| 481 |
+
f"({len(self.layer_types)})"
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
@property
|
| 485 |
+
def rope_scaling(self):
|
| 486 |
+
return self.rope_parameters
|
| 487 |
+
|
| 488 |
+
@rope_scaling.setter
|
| 489 |
+
def rope_scaling(self, value):
|
| 490 |
+
self.rope_parameters = value
|
| 491 |
+
|
| 492 |
+
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
| 493 |
+
"""
|
| 494 |
+
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
| 495 |
+
[`~PreTrainedConfig.from_pretrained`] class method.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
save_directory (`str` or `os.PathLike`):
|
| 499 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
| 500 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 501 |
+
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
| 502 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 503 |
+
namespace).
|
| 504 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 505 |
+
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 506 |
+
"""
|
| 507 |
+
if os.path.isfile(save_directory):
|
| 508 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 509 |
+
|
| 510 |
+
generation_parameters = self._get_generation_parameters()
|
| 511 |
+
if len(generation_parameters) > 0:
|
| 512 |
+
raise ValueError(
|
| 513 |
+
"Some generation parameters are set in the model config. These should go into `model.generation_config`"
|
| 514 |
+
f"as opposed to `model.config`. \nGeneration parameters found: {str(generation_parameters)}",
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 518 |
+
|
| 519 |
+
if push_to_hub:
|
| 520 |
+
commit_message = kwargs.pop("commit_message", None)
|
| 521 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
| 522 |
+
repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
|
| 523 |
+
files_timestamps = self._get_files_timestamps(save_directory)
|
| 524 |
+
|
| 525 |
+
# This attribute is important to know on load, but should not be serialized on save.
|
| 526 |
+
if "transformers_weights" in self:
|
| 527 |
+
delattr(self, "transformers_weights")
|
| 528 |
+
|
| 529 |
+
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
| 530 |
+
# loaded from the Hub.
|
| 531 |
+
if self._auto_class is not None:
|
| 532 |
+
custom_object_save(self, save_directory, config=self)
|
| 533 |
+
|
| 534 |
+
# If we save using the predefined names, we can load using `from_pretrained`
|
| 535 |
+
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
| 536 |
+
|
| 537 |
+
# Strict validation at save-time: prevent bad patterns from propagating
|
| 538 |
+
# Using `strict` decorator guarantees that `self.validate` exists , but not all
|
| 539 |
+
# model config might have the decorator added
|
| 540 |
+
if hasattr(self, "validate"):
|
| 541 |
+
self.validate()
|
| 542 |
+
self.to_json_file(output_config_file, use_diff=True)
|
| 543 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
| 544 |
+
|
| 545 |
+
if push_to_hub:
|
| 546 |
+
self._upload_modified_files(
|
| 547 |
+
save_directory,
|
| 548 |
+
repo_id,
|
| 549 |
+
files_timestamps,
|
| 550 |
+
commit_message=commit_message,
|
| 551 |
+
token=kwargs.get("token"),
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
@classmethod
|
| 555 |
+
def from_pretrained(
|
| 556 |
+
cls: type[SpecificPreTrainedConfigType],
|
| 557 |
+
pretrained_model_name_or_path: str | os.PathLike,
|
| 558 |
+
cache_dir: str | os.PathLike | None = None,
|
| 559 |
+
force_download: bool = False,
|
| 560 |
+
local_files_only: bool = False,
|
| 561 |
+
token: str | bool | None = None,
|
| 562 |
+
revision: str = "main",
|
| 563 |
+
**kwargs,
|
| 564 |
+
) -> SpecificPreTrainedConfigType:
|
| 565 |
+
r"""
|
| 566 |
+
Instantiate a [`PreTrainedConfig`] (or a derived class) from a pretrained model configuration.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 570 |
+
This can be either:
|
| 571 |
+
|
| 572 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
| 573 |
+
huggingface.co.
|
| 574 |
+
- a path to a *directory* containing a configuration file saved using the
|
| 575 |
+
[`~PreTrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
|
| 576 |
+
- a path to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
|
| 577 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
| 578 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 579 |
+
standard cache should not be used.
|
| 580 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 581 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if
|
| 582 |
+
they exist.
|
| 583 |
+
proxies (`dict[str, str]`, *optional*):
|
| 584 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 585 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
| 586 |
+
token (`str` or `bool`, *optional*):
|
| 587 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
| 588 |
+
the token generated when running `hf auth login` (stored in `~/.huggingface`).
|
| 589 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 590 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 591 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 592 |
+
identifier allowed by git.
|
| 593 |
+
|
| 594 |
+
<Tip>
|
| 595 |
+
|
| 596 |
+
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
| 597 |
+
|
| 598 |
+
</Tip>
|
| 599 |
+
|
| 600 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
| 601 |
+
If `False`, then this function returns just the final configuration object.
|
| 602 |
+
|
| 603 |
+
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
| 604 |
+
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
| 605 |
+
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
| 606 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
| 607 |
+
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
| 608 |
+
specify the folder name here.
|
| 609 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 610 |
+
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
| 611 |
+
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
| 612 |
+
by the `return_unused_kwargs` keyword parameter.
|
| 613 |
+
|
| 614 |
+
Returns:
|
| 615 |
+
[`PreTrainedConfig`]: The configuration object instantiated from this pretrained model.
|
| 616 |
+
|
| 617 |
+
Examples:
|
| 618 |
+
|
| 619 |
+
```python
|
| 620 |
+
# We can't instantiate directly the base class *PreTrainedConfig* so let's show the examples on a
|
| 621 |
+
# derived class: BertConfig
|
| 622 |
+
config = BertConfig.from_pretrained(
|
| 623 |
+
"google-bert/bert-base-uncased"
|
| 624 |
+
) # Download configuration from huggingface.co and cache.
|
| 625 |
+
config = BertConfig.from_pretrained(
|
| 626 |
+
"./test/saved_model/"
|
| 627 |
+
) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
|
| 628 |
+
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
|
| 629 |
+
config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
|
| 630 |
+
assert config.output_attentions == True
|
| 631 |
+
config, unused_kwargs = BertConfig.from_pretrained(
|
| 632 |
+
"google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
|
| 633 |
+
)
|
| 634 |
+
assert config.output_attentions == True
|
| 635 |
+
assert unused_kwargs == {"foo": False}
|
| 636 |
+
```"""
|
| 637 |
+
kwargs["cache_dir"] = cache_dir
|
| 638 |
+
kwargs["force_download"] = force_download
|
| 639 |
+
kwargs["local_files_only"] = local_files_only
|
| 640 |
+
kwargs["revision"] = revision
|
| 641 |
+
|
| 642 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 643 |
+
if cls.base_config_key and cls.base_config_key in config_dict:
|
| 644 |
+
config_dict = config_dict[cls.base_config_key]
|
| 645 |
+
|
| 646 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 647 |
+
# sometimes the config has no `base_config_key` if the config is used in several composite models
|
| 648 |
+
# e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
|
| 649 |
+
for v in config_dict.values():
|
| 650 |
+
if isinstance(v, dict) and v.get("model_type") == cls.model_type:
|
| 651 |
+
config_dict = v
|
| 652 |
+
|
| 653 |
+
# raise warning only if we still can't see a match in `model_type`
|
| 654 |
+
if config_dict["model_type"] != cls.model_type:
|
| 655 |
+
logger.warning(
|
| 656 |
+
f"You are using a model of type `{config_dict['model_type']}` to instantiate a model of type "
|
| 657 |
+
f"`{cls.model_type}`. This may be expected if you are loading a checkpoint that shares a subset "
|
| 658 |
+
f"of the architecture (e.g., loading a `sam2_video` checkpoint into `Sam2Model`), but is otherwise "
|
| 659 |
+
f"not supported and can yield errors. Please verify that the checkpoint is compatible with the "
|
| 660 |
+
f"model you are instantiating."
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 664 |
+
|
| 665 |
+
@classmethod
|
| 666 |
+
def get_config_dict(
|
| 667 |
+
cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
|
| 668 |
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 669 |
+
"""
|
| 670 |
+
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
| 671 |
+
[`PreTrainedConfig`] using `from_dict`.
|
| 672 |
+
|
| 673 |
+
Parameters:
|
| 674 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
| 675 |
+
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
| 676 |
+
|
| 677 |
+
Returns:
|
| 678 |
+
`tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
|
| 679 |
+
|
| 680 |
+
"""
|
| 681 |
+
original_kwargs = copy.deepcopy(kwargs)
|
| 682 |
+
# Get config dict associated with the base config file
|
| 683 |
+
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 684 |
+
if config_dict is None:
|
| 685 |
+
return {}, kwargs
|
| 686 |
+
if "_commit_hash" in config_dict:
|
| 687 |
+
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
| 688 |
+
|
| 689 |
+
# That config file may point us toward another config file to use.
|
| 690 |
+
if "configuration_files" in config_dict:
|
| 691 |
+
configuration_file = get_configuration_file(config_dict["configuration_files"])
|
| 692 |
+
config_dict, kwargs = cls._get_config_dict(
|
| 693 |
+
pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
return config_dict, kwargs
|
| 697 |
+
|
| 698 |
+
@classmethod
|
| 699 |
+
def _get_config_dict(
|
| 700 |
+
cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
|
| 701 |
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 702 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 703 |
+
force_download = kwargs.pop("force_download", False)
|
| 704 |
+
proxies = kwargs.pop("proxies", None)
|
| 705 |
+
token = kwargs.pop("token", None)
|
| 706 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
| 707 |
+
revision = kwargs.pop("revision", None)
|
| 708 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
| 709 |
+
subfolder = kwargs.pop("subfolder", "")
|
| 710 |
+
from_pipeline = kwargs.pop("_from_pipeline", None)
|
| 711 |
+
from_auto_class = kwargs.pop("_from_auto", False)
|
| 712 |
+
commit_hash = kwargs.pop("_commit_hash", None)
|
| 713 |
+
|
| 714 |
+
gguf_file = kwargs.get("gguf_file")
|
| 715 |
+
|
| 716 |
+
if trust_remote_code is True:
|
| 717 |
+
logger.warning(
|
| 718 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
|
| 719 |
+
" ignored."
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
|
| 723 |
+
if from_pipeline is not None:
|
| 724 |
+
user_agent["using_pipeline"] = from_pipeline
|
| 725 |
+
|
| 726 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
| 727 |
+
|
| 728 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
| 729 |
+
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
| 730 |
+
# Special case when pretrained_model_name_or_path is a local file
|
| 731 |
+
resolved_config_file = pretrained_model_name_or_path
|
| 732 |
+
is_local = True
|
| 733 |
+
else:
|
| 734 |
+
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
|
| 735 |
+
|
| 736 |
+
try:
|
| 737 |
+
# Load from local folder or from cache or download from model Hub and cache
|
| 738 |
+
resolved_config_file = cached_file(
|
| 739 |
+
pretrained_model_name_or_path,
|
| 740 |
+
configuration_file,
|
| 741 |
+
cache_dir=cache_dir,
|
| 742 |
+
force_download=force_download,
|
| 743 |
+
proxies=proxies,
|
| 744 |
+
local_files_only=local_files_only,
|
| 745 |
+
token=token,
|
| 746 |
+
user_agent=user_agent,
|
| 747 |
+
revision=revision,
|
| 748 |
+
subfolder=subfolder,
|
| 749 |
+
_commit_hash=commit_hash,
|
| 750 |
+
)
|
| 751 |
+
if resolved_config_file is None:
|
| 752 |
+
return None, kwargs
|
| 753 |
+
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
| 754 |
+
except OSError:
|
| 755 |
+
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
| 756 |
+
# the original exception.
|
| 757 |
+
raise
|
| 758 |
+
except Exception:
|
| 759 |
+
# For any other exception, we throw a generic error.
|
| 760 |
+
raise OSError(
|
| 761 |
+
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
|
| 762 |
+
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
|
| 763 |
+
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
|
| 764 |
+
f" containing a {configuration_file} file"
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
try:
|
| 768 |
+
if gguf_file:
|
| 769 |
+
config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
|
| 770 |
+
else:
|
| 771 |
+
# Load config dict
|
| 772 |
+
config_dict = cls._dict_from_json_file(resolved_config_file)
|
| 773 |
+
|
| 774 |
+
config_dict["_commit_hash"] = commit_hash
|
| 775 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
| 776 |
+
raise OSError(f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file.")
|
| 777 |
+
|
| 778 |
+
if is_local:
|
| 779 |
+
logger.info(f"loading configuration file {resolved_config_file}")
|
| 780 |
+
else:
|
| 781 |
+
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
| 782 |
+
|
| 783 |
+
# timm models are not saved with the model_type in the config file
|
| 784 |
+
if "model_type" not in config_dict and is_timm_config_dict(config_dict):
|
| 785 |
+
config_dict["model_type"] = "timm_wrapper"
|
| 786 |
+
|
| 787 |
+
# Some checkpoints may contain the wrong model_type in the config file.
|
| 788 |
+
# Allow the user to override it but warn them that it might not work.
|
| 789 |
+
if "model_type" in kwargs and config_dict["model_type"] != kwargs["model_type"]:
|
| 790 |
+
logger.warning(
|
| 791 |
+
f"{configuration_file} has 'model_type={config_dict['model_type']}' but you overrode "
|
| 792 |
+
f"it with 'model_type={kwargs['model_type']}'. This may lead to unexpected behavior."
|
| 793 |
+
)
|
| 794 |
+
config_dict["model_type"] = kwargs["model_type"]
|
| 795 |
+
|
| 796 |
+
return config_dict, kwargs
|
| 797 |
+
|
| 798 |
+
@classmethod
|
| 799 |
+
def from_dict(
|
| 800 |
+
cls: type[SpecificPreTrainedConfigType], config_dict: dict[str, Any], **kwargs
|
| 801 |
+
) -> SpecificPreTrainedConfigType:
|
| 802 |
+
"""
|
| 803 |
+
Instantiates a [`PreTrainedConfig`] from a Python dictionary of parameters.
|
| 804 |
+
|
| 805 |
+
Args:
|
| 806 |
+
config_dict (`dict[str, Any]`):
|
| 807 |
+
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
|
| 808 |
+
retrieved from a pretrained checkpoint by leveraging the [`~PreTrainedConfig.get_config_dict`] method.
|
| 809 |
+
kwargs (`dict[str, Any]`):
|
| 810 |
+
Additional parameters from which to initialize the configuration object.
|
| 811 |
+
|
| 812 |
+
Returns:
|
| 813 |
+
[`PreTrainedConfig`]: The configuration object instantiated from those parameters.
|
| 814 |
+
"""
|
| 815 |
+
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
| 816 |
+
|
| 817 |
+
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
|
| 818 |
+
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
| 819 |
+
kwargs.setdefault("_commit_hash", config_dict["_commit_hash"])
|
| 820 |
+
|
| 821 |
+
# To remove arg here are those passed along for our internal telemetry but we still need to remove them
|
| 822 |
+
to_remove = ["_from_auto", "_from_pipeline"]
|
| 823 |
+
valid_fields = [
|
| 824 |
+
"num_labels",
|
| 825 |
+
"attn_implementation",
|
| 826 |
+
"experts_implementation",
|
| 827 |
+
"output_attentions",
|
| 828 |
+
"torch_dtype",
|
| 829 |
+
"dtype",
|
| 830 |
+
"name_or_path",
|
| 831 |
+
]
|
| 832 |
+
for key, value in kwargs.items():
|
| 833 |
+
if key in valid_fields:
|
| 834 |
+
if key not in ["torch_dtype", "dtype"]:
|
| 835 |
+
config_dict[key] = value
|
| 836 |
+
to_remove.append(key)
|
| 837 |
+
elif value != "auto":
|
| 838 |
+
config_dict[key] = value
|
| 839 |
+
|
| 840 |
+
config = cls(**config_dict)
|
| 841 |
+
|
| 842 |
+
for key, value in kwargs.items():
|
| 843 |
+
if hasattr(config, key):
|
| 844 |
+
current_attr = getattr(config, key)
|
| 845 |
+
# To authorize passing a custom subconfig as kwarg in models that have nested configs.
|
| 846 |
+
# We need to update only custom kwarg values instead and keep other attr in subconfig.
|
| 847 |
+
if isinstance(current_attr, PreTrainedConfig) and isinstance(value, dict):
|
| 848 |
+
current_attr_updated = current_attr.to_dict()
|
| 849 |
+
current_attr_updated.update(value)
|
| 850 |
+
value = current_attr.__class__(**current_attr_updated)
|
| 851 |
+
setattr(config, key, value)
|
| 852 |
+
to_remove.append(key)
|
| 853 |
+
|
| 854 |
+
for key in to_remove:
|
| 855 |
+
kwargs.pop(key, None)
|
| 856 |
+
|
| 857 |
+
logger.info(f"Model config {config}")
|
| 858 |
+
if return_unused_kwargs:
|
| 859 |
+
return config, kwargs
|
| 860 |
+
else:
|
| 861 |
+
return config
|
| 862 |
+
|
| 863 |
+
@classmethod
|
| 864 |
+
def from_json_file(
|
| 865 |
+
cls: type[SpecificPreTrainedConfigType], json_file: str | os.PathLike
|
| 866 |
+
) -> SpecificPreTrainedConfigType:
|
| 867 |
+
"""
|
| 868 |
+
Instantiates a [`PreTrainedConfig`] from the path to a JSON file of parameters.
|
| 869 |
+
|
| 870 |
+
Args:
|
| 871 |
+
json_file (`str` or `os.PathLike`):
|
| 872 |
+
Path to the JSON file containing the parameters.
|
| 873 |
+
|
| 874 |
+
Returns:
|
| 875 |
+
[`PreTrainedConfig`]: The configuration object instantiated from that JSON file.
|
| 876 |
+
|
| 877 |
+
"""
|
| 878 |
+
config_dict = cls._dict_from_json_file(json_file)
|
| 879 |
+
return cls(**config_dict)
|
| 880 |
+
|
| 881 |
+
@classmethod
|
| 882 |
+
def _dict_from_json_file(cls, json_file: str | os.PathLike):
|
| 883 |
+
with open(json_file, encoding="utf-8") as reader:
|
| 884 |
+
text = reader.read()
|
| 885 |
+
config_dict = json.loads(text)
|
| 886 |
+
|
| 887 |
+
return cls._decode_special_floats(config_dict)
|
| 888 |
+
|
| 889 |
+
@classmethod
|
| 890 |
+
def _encode_special_floats(cls, obj: Any) -> Any:
|
| 891 |
+
"""
|
| 892 |
+
Iterates over the passed object and encode specific floats that cannot be JSON-serialized. Python's JSON
|
| 893 |
+
engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
|
| 894 |
+
|
| 895 |
+
It serializes floats like `Infinity` as an object: `{'__float__': Infinity}`.
|
| 896 |
+
"""
|
| 897 |
+
if isinstance(obj, float):
|
| 898 |
+
if math.isnan(obj):
|
| 899 |
+
return {_FLOAT_TAG_KEY: "NaN"}
|
| 900 |
+
if obj == float("inf"):
|
| 901 |
+
return {_FLOAT_TAG_KEY: "Infinity"}
|
| 902 |
+
if obj == float("-inf"):
|
| 903 |
+
return {_FLOAT_TAG_KEY: "-Infinity"}
|
| 904 |
+
return obj
|
| 905 |
+
|
| 906 |
+
if isinstance(obj, dict):
|
| 907 |
+
return {k: cls._encode_special_floats(v) for k, v in obj.items()}
|
| 908 |
+
|
| 909 |
+
if isinstance(obj, (list, tuple)):
|
| 910 |
+
return [cls._encode_special_floats(v) for v in obj]
|
| 911 |
+
|
| 912 |
+
return obj
|
| 913 |
+
|
| 914 |
+
@classmethod
|
| 915 |
+
def _decode_special_floats(cls, obj: Any) -> Any:
|
| 916 |
+
"""
|
| 917 |
+
Iterates over the passed object and decode specific floats that cannot be JSON-serialized. Python's JSON
|
| 918 |
+
engine saves floats like `Infinity` (+/-) or `NaN` which are not compatible with other JSON engines.
|
| 919 |
+
|
| 920 |
+
This method deserializes objects like `{'__float__': Infinity}` to their float values like `Infinity`.
|
| 921 |
+
"""
|
| 922 |
+
if isinstance(obj, dict):
|
| 923 |
+
if set(obj.keys()) == {_FLOAT_TAG_KEY} and isinstance(obj[_FLOAT_TAG_KEY], str):
|
| 924 |
+
tag = obj[_FLOAT_TAG_KEY]
|
| 925 |
+
if tag in _FLOAT_TAG_VALUES:
|
| 926 |
+
return _FLOAT_TAG_VALUES[tag]
|
| 927 |
+
return obj
|
| 928 |
+
|
| 929 |
+
return {k: cls._decode_special_floats(v) for k, v in obj.items()}
|
| 930 |
+
|
| 931 |
+
if isinstance(obj, list):
|
| 932 |
+
return [cls._decode_special_floats(v) for v in obj]
|
| 933 |
+
|
| 934 |
+
return obj
|
| 935 |
+
|
| 936 |
+
def __eq__(self, other):
|
| 937 |
+
return isinstance(other, PreTrainedConfig) and (self.__dict__ == other.__dict__)
|
| 938 |
+
|
| 939 |
+
def __repr__(self):
|
| 940 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
| 941 |
+
|
| 942 |
+
def __iter__(self):
|
| 943 |
+
yield from self.__dict__
|
| 944 |
+
|
| 945 |
+
def to_diff_dict(self) -> dict[str, Any]:
|
| 946 |
+
"""
|
| 947 |
+
Removes all attributes from the configuration that correspond to the default config attributes for
|
| 948 |
+
better readability, while always retaining the `config` attribute from the class. Serializes to a
|
| 949 |
+
Python dictionary.
|
| 950 |
+
|
| 951 |
+
Returns:
|
| 952 |
+
dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
|
| 953 |
+
"""
|
| 954 |
+
config_dict = self.to_dict()
|
| 955 |
+
|
| 956 |
+
# Get the default config dict (from a fresh PreTrainedConfig instance)
|
| 957 |
+
default_config_dict = PreTrainedConfig().to_dict()
|
| 958 |
+
|
| 959 |
+
# get class specific config dict
|
| 960 |
+
class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
|
| 961 |
+
|
| 962 |
+
serializable_config_dict = {}
|
| 963 |
+
|
| 964 |
+
# Only serialize values that differ from the default config,
|
| 965 |
+
# except always keep the 'config' attribute.
|
| 966 |
+
for key, value in config_dict.items():
|
| 967 |
+
if (
|
| 968 |
+
isinstance(getattr(self, key, None), PreTrainedConfig)
|
| 969 |
+
and key in class_config_dict
|
| 970 |
+
and isinstance(class_config_dict[key], dict)
|
| 971 |
+
):
|
| 972 |
+
# For nested configs we need to clean the diff recursively
|
| 973 |
+
diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
|
| 974 |
+
if "model_type" in value:
|
| 975 |
+
# Needs to be set even if it's not in the diff
|
| 976 |
+
diff["model_type"] = value["model_type"]
|
| 977 |
+
|
| 978 |
+
serializable_config_dict[key] = diff
|
| 979 |
+
elif (
|
| 980 |
+
key not in default_config_dict
|
| 981 |
+
or key == "transformers_version"
|
| 982 |
+
or key == "vocab_file"
|
| 983 |
+
or value != default_config_dict[key]
|
| 984 |
+
or (key in default_config_dict and value != class_config_dict.get(key, value))
|
| 985 |
+
):
|
| 986 |
+
serializable_config_dict[key] = value
|
| 987 |
+
|
| 988 |
+
self._remove_keys_not_serialized(serializable_config_dict)
|
| 989 |
+
|
| 990 |
+
# Key removed only in diff dict
|
| 991 |
+
if "_name_or_path" in serializable_config_dict:
|
| 992 |
+
del serializable_config_dict["_name_or_path"]
|
| 993 |
+
|
| 994 |
+
if hasattr(self, "quantization_config"):
|
| 995 |
+
serializable_config_dict["quantization_config"] = (
|
| 996 |
+
self.quantization_config.to_dict()
|
| 997 |
+
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
| 998 |
+
else self.quantization_config
|
| 999 |
+
)
|
| 1000 |
+
self.dict_dtype_to_str(serializable_config_dict)
|
| 1001 |
+
|
| 1002 |
+
return serializable_config_dict
|
| 1003 |
+
|
| 1004 |
+
def to_dict(self) -> dict[str, Any]:
|
| 1005 |
+
"""
|
| 1006 |
+
Serializes this instance to a Python dictionary.
|
| 1007 |
+
|
| 1008 |
+
Returns:
|
| 1009 |
+
`dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
| 1010 |
+
"""
|
| 1011 |
+
output = copy.deepcopy(self.__dict__)
|
| 1012 |
+
if hasattr(self.__class__, "model_type"):
|
| 1013 |
+
output["model_type"] = self.__class__.model_type
|
| 1014 |
+
|
| 1015 |
+
# Transformers version when serializing the model
|
| 1016 |
+
output["transformers_version"] = __version__
|
| 1017 |
+
|
| 1018 |
+
# Pop "kwargs" since they are unpacked and set in the post init
|
| 1019 |
+
output.pop("kwargs", None)
|
| 1020 |
+
|
| 1021 |
+
def to_list(value):
|
| 1022 |
+
if isinstance(value, tuple):
|
| 1023 |
+
value = [to_list(item) for item in value]
|
| 1024 |
+
return value
|
| 1025 |
+
|
| 1026 |
+
for key, value in output.items():
|
| 1027 |
+
# Deal with nested configs like CLIP
|
| 1028 |
+
if isinstance(value, PreTrainedConfig):
|
| 1029 |
+
value = value.to_dict()
|
| 1030 |
+
del value["transformers_version"]
|
| 1031 |
+
|
| 1032 |
+
# Some models have defaults as tuples because dataclass
|
| 1033 |
+
# doesn't allow mutables. Let's convert back to `list``
|
| 1034 |
+
elif isinstance(value, tuple):
|
| 1035 |
+
value = to_list(value)
|
| 1036 |
+
|
| 1037 |
+
output[key] = value
|
| 1038 |
+
|
| 1039 |
+
self._remove_keys_not_serialized(output)
|
| 1040 |
+
|
| 1041 |
+
if hasattr(self, "quantization_config"):
|
| 1042 |
+
output["quantization_config"] = (
|
| 1043 |
+
self.quantization_config.to_dict()
|
| 1044 |
+
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
| 1045 |
+
else self.quantization_config
|
| 1046 |
+
)
|
| 1047 |
+
self.dict_dtype_to_str(output)
|
| 1048 |
+
|
| 1049 |
+
return output
|
| 1050 |
+
|
| 1051 |
+
def to_json_string(self, use_diff: bool = True) -> str:
|
| 1052 |
+
"""
|
| 1053 |
+
Serializes this instance to a JSON string.
|
| 1054 |
+
|
| 1055 |
+
Args:
|
| 1056 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
| 1057 |
+
If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
|
| 1058 |
+
is serialized to JSON string.
|
| 1059 |
+
|
| 1060 |
+
Returns:
|
| 1061 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
| 1062 |
+
"""
|
| 1063 |
+
if use_diff is True:
|
| 1064 |
+
config_dict = self.to_diff_dict()
|
| 1065 |
+
else:
|
| 1066 |
+
config_dict = self.to_dict()
|
| 1067 |
+
|
| 1068 |
+
# Handle +/-Infinity and NaNs
|
| 1069 |
+
config_dict = self._encode_special_floats(config_dict)
|
| 1070 |
+
|
| 1071 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 1072 |
+
|
| 1073 |
+
def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True):
|
| 1074 |
+
"""
|
| 1075 |
+
Save this instance to a JSON file.
|
| 1076 |
+
|
| 1077 |
+
Args:
|
| 1078 |
+
json_file_path (`str` or `os.PathLike`):
|
| 1079 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
| 1080 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
| 1081 |
+
If set to `True`, only the difference between the config instance and the default `PreTrainedConfig()`
|
| 1082 |
+
is serialized to JSON file.
|
| 1083 |
+
"""
|
| 1084 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 1085 |
+
writer.write(self.to_json_string(use_diff=use_diff))
|
| 1086 |
+
|
| 1087 |
+
def update(self, config_dict: dict[str, Any]):
|
| 1088 |
+
"""
|
| 1089 |
+
Updates attributes of this class with attributes from `config_dict`.
|
| 1090 |
+
|
| 1091 |
+
Args:
|
| 1092 |
+
config_dict (`dict[str, Any]`): Dictionary of attributes that should be updated for this class.
|
| 1093 |
+
"""
|
| 1094 |
+
for key, value in config_dict.items():
|
| 1095 |
+
setattr(self, key, value)
|
| 1096 |
+
|
| 1097 |
+
def update_from_string(self, update_str: str):
|
| 1098 |
+
"""
|
| 1099 |
+
Updates attributes of this class with attributes from `update_str`.
|
| 1100 |
+
|
| 1101 |
+
The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
|
| 1102 |
+
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
| 1103 |
+
|
| 1104 |
+
The keys to change have to already exist in the config object.
|
| 1105 |
+
|
| 1106 |
+
Args:
|
| 1107 |
+
update_str (`str`): String with attributes that should be updated for this class.
|
| 1108 |
+
|
| 1109 |
+
"""
|
| 1110 |
+
|
| 1111 |
+
d = dict(x.split("=") for x in update_str.split(","))
|
| 1112 |
+
for k, v in d.items():
|
| 1113 |
+
if not hasattr(self, k):
|
| 1114 |
+
raise ValueError(f"key {k} isn't in the original config dict")
|
| 1115 |
+
|
| 1116 |
+
old_v = getattr(self, k)
|
| 1117 |
+
if isinstance(old_v, bool):
|
| 1118 |
+
if v.lower() in ["true", "1", "y", "yes"]:
|
| 1119 |
+
v = True
|
| 1120 |
+
elif v.lower() in ["false", "0", "n", "no"]:
|
| 1121 |
+
v = False
|
| 1122 |
+
else:
|
| 1123 |
+
raise ValueError(f"can't derive true or false from {v} (key {k})")
|
| 1124 |
+
elif isinstance(old_v, int):
|
| 1125 |
+
v = int(v)
|
| 1126 |
+
elif isinstance(old_v, float):
|
| 1127 |
+
v = float(v)
|
| 1128 |
+
elif not isinstance(old_v, str):
|
| 1129 |
+
raise TypeError(
|
| 1130 |
+
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
setattr(self, k, v)
|
| 1134 |
+
|
| 1135 |
+
def dict_dtype_to_str(self, d: dict[str, Any]) -> None:
|
| 1136 |
+
"""
|
| 1137 |
+
Checks whether the passed dictionary and its nested dicts have a *dtype* key and if it's not None,
|
| 1138 |
+
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
| 1139 |
+
string, which can then be stored in the json format.
|
| 1140 |
+
"""
|
| 1141 |
+
if d.get("dtype") is not None:
|
| 1142 |
+
if isinstance(d["dtype"], dict):
|
| 1143 |
+
d["dtype"] = {k: str(v).split(".")[-1] for k, v in d["dtype"].items()}
|
| 1144 |
+
# models like Emu3 can have "dtype" as token in config's vocabulary map,
|
| 1145 |
+
# so we also exclude int type here to avoid error in this special case.
|
| 1146 |
+
elif not isinstance(d["dtype"], (str, int)):
|
| 1147 |
+
d["dtype"] = str(d["dtype"]).split(".")[1]
|
| 1148 |
+
for value in d.values():
|
| 1149 |
+
if isinstance(value, dict):
|
| 1150 |
+
self.dict_dtype_to_str(value)
|
| 1151 |
+
|
| 1152 |
+
def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
|
| 1153 |
+
"""
|
| 1154 |
+
Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
|
| 1155 |
+
Runs recursive check on the dict, to remove from all sub configs.
|
| 1156 |
+
"""
|
| 1157 |
+
|
| 1158 |
+
for key_to_remove in [
|
| 1159 |
+
"_is_quantized",
|
| 1160 |
+
"_auto_class",
|
| 1161 |
+
"_commit_hash",
|
| 1162 |
+
"_attn_implementation_internal",
|
| 1163 |
+
"_experts_implementation_internal",
|
| 1164 |
+
"ignore_keys_at_rope_validation",
|
| 1165 |
+
"base_model_tp_plan",
|
| 1166 |
+
"base_model_pp_plan",
|
| 1167 |
+
]:
|
| 1168 |
+
d.pop(key_to_remove, None)
|
| 1169 |
+
|
| 1170 |
+
if "_output_attentions" in d:
|
| 1171 |
+
d["output_attentions"] = d.pop("_output_attentions")
|
| 1172 |
+
|
| 1173 |
+
for value in d.values():
|
| 1174 |
+
if isinstance(value, dict):
|
| 1175 |
+
self._remove_keys_not_serialized(value)
|
| 1176 |
+
|
| 1177 |
+
@classmethod
|
| 1178 |
+
def register_for_auto_class(cls, auto_class="AutoConfig"):
|
| 1179 |
+
"""
|
| 1180 |
+
Register this class with a given auto class. This should only be used for custom configurations as the ones in
|
| 1181 |
+
the library are already mapped with `AutoConfig`.
|
| 1182 |
+
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
Args:
|
| 1186 |
+
auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
|
| 1187 |
+
The auto class to register this new configuration with.
|
| 1188 |
+
"""
|
| 1189 |
+
if not isinstance(auto_class, str):
|
| 1190 |
+
auto_class = auto_class.__name__
|
| 1191 |
+
|
| 1192 |
+
import transformers.models.auto as auto_module
|
| 1193 |
+
|
| 1194 |
+
if not hasattr(auto_module, auto_class):
|
| 1195 |
+
raise ValueError(f"{auto_class} is not a valid auto class.")
|
| 1196 |
+
|
| 1197 |
+
cls._auto_class = auto_class
|
| 1198 |
+
|
| 1199 |
+
def _get_generation_parameters(self) -> dict[str, Any]:
|
| 1200 |
+
"""
|
| 1201 |
+
Checks if there are generation parameters in `PreTrainedConfig` instance. Note that
|
| 1202 |
+
we should not save generation params in PreTrainedConfig, and we will raise error
|
| 1203 |
+
if there are any.
|
| 1204 |
+
"""
|
| 1205 |
+
generation_params = {}
|
| 1206 |
+
default_config = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
|
| 1207 |
+
for key in GenerationConfig._get_default_generation_params().keys():
|
| 1208 |
+
if key == "use_cache":
|
| 1209 |
+
continue # common key for most models
|
| 1210 |
+
if hasattr(self, key) and getattr(self, key) is not None and key not in default_config:
|
| 1211 |
+
generation_params[key] = getattr(self, key)
|
| 1212 |
+
|
| 1213 |
+
return generation_params
|
| 1214 |
+
|
| 1215 |
+
def get_text_config(self, decoder=None, encoder=None) -> "PreTrainedConfig":
|
| 1216 |
+
"""
|
| 1217 |
+
Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
|
| 1218 |
+
`decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
|
| 1219 |
+
which is useful on models that have both text input and output modalities.
|
| 1220 |
+
|
| 1221 |
+
There are three possible outcomes of using this method:
|
| 1222 |
+
1. On most models, it returns the original config instance itself.
|
| 1223 |
+
2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
|
| 1224 |
+
of valid names.
|
| 1225 |
+
3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
|
| 1226 |
+
|
| 1227 |
+
Args:
|
| 1228 |
+
decoder (`Optional[bool]`, *optional*):
|
| 1229 |
+
If set to `True`, then only search for decoder config names.
|
| 1230 |
+
encoder (`Optional[bool]`, *optional*):
|
| 1231 |
+
If set to `True`, then only search for encoder config names.
|
| 1232 |
+
"""
|
| 1233 |
+
return_both = decoder == encoder # both unset or both set -> search all possible names
|
| 1234 |
+
|
| 1235 |
+
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
| 1236 |
+
encoder_possible_text_config_names = ("text_encoder",)
|
| 1237 |
+
if return_both:
|
| 1238 |
+
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
|
| 1239 |
+
elif decoder:
|
| 1240 |
+
possible_text_config_names = decoder_possible_text_config_names
|
| 1241 |
+
else:
|
| 1242 |
+
possible_text_config_names = encoder_possible_text_config_names
|
| 1243 |
+
|
| 1244 |
+
valid_text_config_names = []
|
| 1245 |
+
for text_config_name in possible_text_config_names:
|
| 1246 |
+
if hasattr(self, text_config_name):
|
| 1247 |
+
text_config = getattr(self, text_config_name, None)
|
| 1248 |
+
if text_config is not None:
|
| 1249 |
+
valid_text_config_names += [text_config_name]
|
| 1250 |
+
|
| 1251 |
+
if len(valid_text_config_names) > 1:
|
| 1252 |
+
raise ValueError(
|
| 1253 |
+
f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
|
| 1254 |
+
"case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
|
| 1255 |
+
"e.g. `text_config = config.sub_config_name`"
|
| 1256 |
+
)
|
| 1257 |
+
elif len(valid_text_config_names) == 1:
|
| 1258 |
+
config_to_return = getattr(self, valid_text_config_names[0])
|
| 1259 |
+
else:
|
| 1260 |
+
config_to_return = self
|
| 1261 |
+
|
| 1262 |
+
# handle legacy models with flat config structure, when we only want one of the configs
|
| 1263 |
+
if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
|
| 1264 |
+
config_to_return = copy.deepcopy(config_to_return)
|
| 1265 |
+
prefix_to_keep = "decoder" if decoder else "encoder"
|
| 1266 |
+
for key in config_to_return.to_dict():
|
| 1267 |
+
# NOTE: We can't discard keys because:
|
| 1268 |
+
# 1) we can't truly delete a cls attribte on a dataclass; 2) we can't set the value to `None` due to
|
| 1269 |
+
# strict validation. So we just keep it as is, since there are only a couple old models falling in this condition
|
| 1270 |
+
if key.startswith(prefix_to_keep):
|
| 1271 |
+
# [encoder/decoder]_layers -> num_hidden_layers
|
| 1272 |
+
if key == prefix_to_keep + "_layers":
|
| 1273 |
+
new_key = "num_hidden_layers"
|
| 1274 |
+
# [encoder/decoder]_attention_heads -> num_attention_heads
|
| 1275 |
+
elif key == prefix_to_keep + "_attention_heads":
|
| 1276 |
+
new_key = "num_attention_heads"
|
| 1277 |
+
# e.g. encoder_hidden_act -> hidden_act
|
| 1278 |
+
else:
|
| 1279 |
+
new_key = key[len(prefix_to_keep) + 1 :]
|
| 1280 |
+
|
| 1281 |
+
# Does the class map the new key into a different attribute name at read time? if so, let's write
|
| 1282 |
+
# into that attribute instead
|
| 1283 |
+
if new_key in config_to_return.attribute_map:
|
| 1284 |
+
new_key = config_to_return.attribute_map[new_key]
|
| 1285 |
+
|
| 1286 |
+
value = getattr(config_to_return, key)
|
| 1287 |
+
delattr(config_to_return, key)
|
| 1288 |
+
setattr(config_to_return, new_key, value)
|
| 1289 |
+
|
| 1290 |
+
return config_to_return
|
| 1291 |
+
|
| 1292 |
+
|
| 1293 |
+
def get_configuration_file(configuration_files: list[str]) -> str:
|
| 1294 |
+
"""
|
| 1295 |
+
Get the configuration file to use for this version of transformers.
|
| 1296 |
+
|
| 1297 |
+
Args:
|
| 1298 |
+
configuration_files (`list[str]`): The list of available configuration files.
|
| 1299 |
+
|
| 1300 |
+
Returns:
|
| 1301 |
+
`str`: The configuration file to use.
|
| 1302 |
+
"""
|
| 1303 |
+
configuration_files_map = {}
|
| 1304 |
+
for file_name in configuration_files:
|
| 1305 |
+
if file_name.startswith("config.") and file_name.endswith(".json") and file_name != "config.json":
|
| 1306 |
+
v = file_name.removeprefix("config.").removesuffix(".json")
|
| 1307 |
+
configuration_files_map[v] = file_name
|
| 1308 |
+
available_versions = sorted(configuration_files_map.keys())
|
| 1309 |
+
|
| 1310 |
+
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
|
| 1311 |
+
configuration_file = CONFIG_NAME
|
| 1312 |
+
transformers_version = version.parse(__version__)
|
| 1313 |
+
for v in available_versions:
|
| 1314 |
+
if version.parse(v) <= transformers_version:
|
| 1315 |
+
configuration_file = configuration_files_map[v]
|
| 1316 |
+
else:
|
| 1317 |
+
# No point going further since the versions are sorted.
|
| 1318 |
+
break
|
| 1319 |
+
|
| 1320 |
+
return configuration_file
|
| 1321 |
+
|
| 1322 |
+
|
| 1323 |
+
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
|
| 1324 |
+
"""
|
| 1325 |
+
Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
|
| 1326 |
+
values from `dict_a` that are different from values in `dict_b`.
|
| 1327 |
+
|
| 1328 |
+
dict_b : the default config dictionary. We want to remove values that are in this one
|
| 1329 |
+
"""
|
| 1330 |
+
diff = {}
|
| 1331 |
+
default = config_obj.__class__().to_dict() if config_obj is not None else {}
|
| 1332 |
+
for key, value in dict_a.items():
|
| 1333 |
+
obj_value = getattr(config_obj, str(key), None)
|
| 1334 |
+
if isinstance(obj_value, PreTrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
|
| 1335 |
+
diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
|
| 1336 |
+
diff[key] = diff_value
|
| 1337 |
+
elif key not in dict_b or (value != default[key]):
|
| 1338 |
+
diff[key] = value
|
| 1339 |
+
return diff
|
| 1340 |
+
|
| 1341 |
+
|
| 1342 |
+
PreTrainedConfig.push_to_hub = copy_func(PreTrainedConfig.push_to_hub)
|
| 1343 |
+
if PreTrainedConfig.push_to_hub.__doc__ is not None:
|
| 1344 |
+
PreTrainedConfig.push_to_hub.__doc__ = PreTrainedConfig.push_to_hub.__doc__.format(
|
| 1345 |
+
object="config", object_class="AutoConfig", object_files="configuration file"
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
|
| 1349 |
+
# The alias is only here for BC - we did not have the correct CamelCasing before
|
| 1350 |
+
PretrainedConfig = PreTrainedConfig
|
| 1351 |
+
|
| 1352 |
+
|
| 1353 |
+
def layer_type_validation(layer_types: list[str], num_hidden_layers: int | None = None, attention: bool = True):
|
| 1354 |
+
logger.warning(
|
| 1355 |
+
"`layer_type_validation` is deprecated and will be removed in v5.20. "
|
| 1356 |
+
"Use `PreTrainedConfig.validate_layer_type` instead"
|
| 1357 |
+
)
|
| 1358 |
+
|
| 1359 |
+
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
|
| 1360 |
+
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
|
| 1361 |
+
if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
|
| 1362 |
+
raise ValueError(
|
| 1363 |
+
f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
|
| 1364 |
+
f"({len(layer_types)})"
|
| 1365 |
+
)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/file_utils.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
File utilities: utilities related to download and cache models
|
| 16 |
+
|
| 17 |
+
This module should not be update anymore and is only left for backward compatibility.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from . import __version__
|
| 21 |
+
|
| 22 |
+
# Backward compatibility imports, to make sure all those objects can be found in file_utils
|
| 23 |
+
from .utils import (
|
| 24 |
+
CLOUDFRONT_DISTRIB_PREFIX,
|
| 25 |
+
CONFIG_NAME,
|
| 26 |
+
DUMMY_INPUTS,
|
| 27 |
+
DUMMY_MASK,
|
| 28 |
+
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
| 29 |
+
ENV_VARS_TRUE_VALUES,
|
| 30 |
+
FEATURE_EXTRACTOR_NAME,
|
| 31 |
+
HF_MODULES_CACHE,
|
| 32 |
+
MODEL_CARD_NAME,
|
| 33 |
+
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
| 34 |
+
S3_BUCKET_PREFIX,
|
| 35 |
+
SENTENCEPIECE_UNDERLINE,
|
| 36 |
+
SPIECE_UNDERLINE,
|
| 37 |
+
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
| 38 |
+
WEIGHTS_INDEX_NAME,
|
| 39 |
+
WEIGHTS_NAME,
|
| 40 |
+
ContextManagers,
|
| 41 |
+
DummyObject,
|
| 42 |
+
EntryNotFoundError,
|
| 43 |
+
ExplicitEnum,
|
| 44 |
+
ModelOutput,
|
| 45 |
+
PaddingStrategy,
|
| 46 |
+
PushToHubMixin,
|
| 47 |
+
RepositoryNotFoundError,
|
| 48 |
+
RevisionNotFoundError,
|
| 49 |
+
TensorType,
|
| 50 |
+
_LazyModule,
|
| 51 |
+
add_code_sample_docstrings,
|
| 52 |
+
add_end_docstrings,
|
| 53 |
+
add_start_docstrings,
|
| 54 |
+
add_start_docstrings_to_model_forward,
|
| 55 |
+
copy_func,
|
| 56 |
+
define_sagemaker_information,
|
| 57 |
+
get_torch_version,
|
| 58 |
+
has_file,
|
| 59 |
+
http_user_agent,
|
| 60 |
+
is_apex_available,
|
| 61 |
+
is_bs4_available,
|
| 62 |
+
is_coloredlogs_available,
|
| 63 |
+
is_datasets_available,
|
| 64 |
+
is_detectron2_available,
|
| 65 |
+
is_faiss_available,
|
| 66 |
+
is_g2p_en_available,
|
| 67 |
+
is_in_notebook,
|
| 68 |
+
is_librosa_available,
|
| 69 |
+
is_onnx_available,
|
| 70 |
+
is_pandas_available,
|
| 71 |
+
is_phonemizer_available,
|
| 72 |
+
is_protobuf_available,
|
| 73 |
+
is_psutil_available,
|
| 74 |
+
is_py3nvml_available,
|
| 75 |
+
is_pyctcdecode_available,
|
| 76 |
+
is_pytesseract_available,
|
| 77 |
+
is_pytorch_quantization_available,
|
| 78 |
+
is_rjieba_available,
|
| 79 |
+
is_sagemaker_dp_enabled,
|
| 80 |
+
is_sagemaker_mp_enabled,
|
| 81 |
+
is_scipy_available,
|
| 82 |
+
is_sentencepiece_available,
|
| 83 |
+
is_seqio_available,
|
| 84 |
+
is_sklearn_available,
|
| 85 |
+
is_soundfile_available,
|
| 86 |
+
is_spacy_available,
|
| 87 |
+
is_speech_available,
|
| 88 |
+
is_tensor,
|
| 89 |
+
is_timm_available,
|
| 90 |
+
is_tokenizers_available,
|
| 91 |
+
is_torch_available,
|
| 92 |
+
is_torch_cuda_available,
|
| 93 |
+
is_torch_fx_proxy,
|
| 94 |
+
is_torch_mps_available,
|
| 95 |
+
is_torch_tf32_available,
|
| 96 |
+
is_torch_xla_available,
|
| 97 |
+
is_torchaudio_available,
|
| 98 |
+
is_training_run_on_sagemaker,
|
| 99 |
+
is_vision_available,
|
| 100 |
+
replace_return_docstrings,
|
| 101 |
+
requires_backends,
|
| 102 |
+
to_numpy,
|
| 103 |
+
to_py_obj,
|
| 104 |
+
torch_only_method,
|
| 105 |
+
)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/fusion_mapping.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Fusion registration helpers.
|
| 16 |
+
|
| 17 |
+
See `docs/source/en/fusion_mapping.md` for the design overview and extension guide.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import re
|
| 22 |
+
from collections.abc import Mapping
|
| 23 |
+
from typing import TYPE_CHECKING, Any
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch import nn
|
| 27 |
+
|
| 28 |
+
from .conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
|
| 29 |
+
from .core_model_loading import Conv3dToLinear, WeightConverter, WeightRenaming, WeightTransform
|
| 30 |
+
from .monkey_patching import register_patch_mapping
|
| 31 |
+
from .utils import logging
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if TYPE_CHECKING:
|
| 35 |
+
from .configuration_utils import PretrainedConfig
|
| 36 |
+
from .modeling_utils import PreTrainedModel
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
_FUSION_DISCOVERY_CACHE: dict[str, dict[type, dict[str, type[nn.Module]]]] = {}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ModuleFusionSpec:
|
| 45 |
+
"""Base recipe for a fusion family.
|
| 46 |
+
|
| 47 |
+
A fusion spec decides which modules are eligible for a fusion, how to build
|
| 48 |
+
the runtime replacement class, and which weight transforms are needed to map
|
| 49 |
+
checkpoints between the original and fused layouts.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
target_modules_patterns: tuple[str, ...] = ()
|
| 53 |
+
|
| 54 |
+
def get_empty_log(self, model_name: str) -> str:
|
| 55 |
+
"""Return the log message emitted when no compatible modules are found."""
|
| 56 |
+
return f"No compatible {type(self).__name__} classes found to fuse for {model_name}"
|
| 57 |
+
|
| 58 |
+
def is_fusable(self, module: nn.Module) -> bool:
|
| 59 |
+
"""Return whether `module` is compatible with this fusion family."""
|
| 60 |
+
raise NotImplementedError
|
| 61 |
+
|
| 62 |
+
def make_fused_class(self, original_cls: type[nn.Module]) -> type[nn.Module]:
|
| 63 |
+
"""Build the runtime replacement class for a compatible module class."""
|
| 64 |
+
raise NotImplementedError
|
| 65 |
+
|
| 66 |
+
def make_transforms(self, config: "PretrainedConfig") -> list[WeightTransform]:
|
| 67 |
+
"""Build the weight transforms needed to load and save the fused runtime layout."""
|
| 68 |
+
raise NotImplementedError
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class _FusedPatchEmbeddingMixin:
|
| 72 |
+
def __init__(self, *args, **kwargs):
|
| 73 |
+
# call the original_cls.__init__()
|
| 74 |
+
super().__init__(*args, **kwargs)
|
| 75 |
+
self.patch_volume = self.proj.in_channels * math.prod(self.proj.kernel_size)
|
| 76 |
+
|
| 77 |
+
self.linear_proj = nn.Linear(
|
| 78 |
+
self.patch_volume,
|
| 79 |
+
self.proj.out_channels,
|
| 80 |
+
bias=self.proj.bias is not None,
|
| 81 |
+
device=self.proj.weight.device,
|
| 82 |
+
dtype=self.proj.weight.dtype,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
del self.proj
|
| 86 |
+
|
| 87 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
target_dtype = self.linear_proj.weight.dtype
|
| 89 |
+
hidden_states = hidden_states.view(-1, self.patch_volume)
|
| 90 |
+
hidden_states = self.linear_proj(hidden_states.to(dtype=target_dtype))
|
| 91 |
+
return hidden_states.view(-1, self.embed_dim)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class PatchEmbeddingsFusionSpec(ModuleFusionSpec):
|
| 95 |
+
"""Fuse compatible Conv3d patch embeddings into flattened Linear projections."""
|
| 96 |
+
|
| 97 |
+
target_modules_patterns = (r"(^|\.)patch_embed$",)
|
| 98 |
+
|
| 99 |
+
def is_fusable(self, module: nn.Module) -> bool:
|
| 100 |
+
if not isinstance(proj := getattr(module, "proj", None), nn.Conv3d):
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
# no overlap between the patches
|
| 104 |
+
return (
|
| 105 |
+
proj.stride == proj.kernel_size
|
| 106 |
+
and proj.padding == (0, 0, 0)
|
| 107 |
+
and proj.dilation == (1, 1, 1)
|
| 108 |
+
and proj.groups == 1
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def make_fused_class(self, original_cls: type[nn.Module]) -> type[nn.Module]:
|
| 112 |
+
fused_cls = type(f"Fused{original_cls.__name__}", (_FusedPatchEmbeddingMixin, original_cls), {})
|
| 113 |
+
fused_cls.__qualname__ = f"Fused{original_cls.__qualname__}"
|
| 114 |
+
return fused_cls
|
| 115 |
+
|
| 116 |
+
def make_transforms(self, config: "PretrainedConfig") -> list[WeightTransform]:
|
| 117 |
+
vision_config = getattr(config, "vision_config", config)
|
| 118 |
+
patch_size = vision_config.patch_size
|
| 119 |
+
if isinstance(patch_size, int):
|
| 120 |
+
patch_size = (patch_size, patch_size)
|
| 121 |
+
kernel_size = (vision_config.temporal_patch_size, *tuple(patch_size))
|
| 122 |
+
in_channels = vision_config.in_channels
|
| 123 |
+
|
| 124 |
+
return [
|
| 125 |
+
WeightConverter(
|
| 126 |
+
source_patterns=r"patch_embed\.proj\.weight$",
|
| 127 |
+
target_patterns=r"patch_embed\.linear_proj\.weight$",
|
| 128 |
+
operations=[
|
| 129 |
+
Conv3dToLinear(
|
| 130 |
+
in_channels=in_channels,
|
| 131 |
+
kernel_size=kernel_size,
|
| 132 |
+
)
|
| 133 |
+
],
|
| 134 |
+
),
|
| 135 |
+
WeightRenaming(
|
| 136 |
+
source_patterns=r"patch_embed\.proj\.bias$",
|
| 137 |
+
target_patterns=r"patch_embed\.linear_proj\.bias$",
|
| 138 |
+
),
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _discover_fusable_modules(
|
| 143 |
+
cls: "type[PreTrainedModel]",
|
| 144 |
+
config: "PretrainedConfig",
|
| 145 |
+
fusion_name: str,
|
| 146 |
+
spec: ModuleFusionSpec,
|
| 147 |
+
) -> dict[str, type[nn.Module]]:
|
| 148 |
+
"""Discover compatible module classes for one fusion family on a meta-initialized model.
|
| 149 |
+
|
| 150 |
+
This function:
|
| 151 |
+
- instantiates `cls(config)` on the meta device
|
| 152 |
+
- scans `named_modules()` for candidate modules
|
| 153 |
+
- optionally pre-filters them with `target_modules_patterns`
|
| 154 |
+
- uses `is_fusable(...)` as the final structural check
|
| 155 |
+
- builds the class-level patch mapping used by monkey patching
|
| 156 |
+
|
| 157 |
+
Results are cached per `(fusion_name, cls)` to avoid repeated meta-initialization.
|
| 158 |
+
This matches the current class-level fusion behavior, where one compatible
|
| 159 |
+
module class maps to one fused replacement class.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
cache = _FUSION_DISCOVERY_CACHE.setdefault(fusion_name, {})
|
| 163 |
+
if cls in cache:
|
| 164 |
+
return cache[cls]
|
| 165 |
+
|
| 166 |
+
with torch.device("meta"):
|
| 167 |
+
model = cls(config)
|
| 168 |
+
|
| 169 |
+
seen_classes = set()
|
| 170 |
+
patch_mapping = {}
|
| 171 |
+
target_module_pattern = (
|
| 172 |
+
re.compile("|".join(spec.target_modules_patterns)) if spec.target_modules_patterns else None
|
| 173 |
+
)
|
| 174 |
+
for module_name, module in model.named_modules():
|
| 175 |
+
module_cls = type(module)
|
| 176 |
+
if module_cls in seen_classes:
|
| 177 |
+
continue
|
| 178 |
+
if target_module_pattern is not None and target_module_pattern.search(module_name) is None:
|
| 179 |
+
continue
|
| 180 |
+
if not spec.is_fusable(module):
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
seen_classes.add(module_cls)
|
| 184 |
+
patch_mapping[module_cls.__name__] = spec.make_fused_class(module_cls)
|
| 185 |
+
|
| 186 |
+
cache[cls] = patch_mapping
|
| 187 |
+
return patch_mapping
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _register_module_fusion(
|
| 191 |
+
cls: "type[PreTrainedModel]", config: "PretrainedConfig", fusion_name: str, spec: ModuleFusionSpec
|
| 192 |
+
) -> None:
|
| 193 |
+
"""Register one fusion family for `cls`.
|
| 194 |
+
|
| 195 |
+
This function updates the two global registries used by fused loading:
|
| 196 |
+
- the monkey-patching registry, so compatible module classes are replaced before initialization
|
| 197 |
+
- the checkpoint conversion mapping, so fused runtime modules still load from the original checkpoint layout
|
| 198 |
+
|
| 199 |
+
Notes:
|
| 200 |
+
- conflicting checkpoint transforms fail fast
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
fusable_classes = _discover_fusable_modules(cls, config, fusion_name=fusion_name, spec=spec)
|
| 204 |
+
if not fusable_classes:
|
| 205 |
+
logger.info(spec.get_empty_log(cls.__name__))
|
| 206 |
+
return
|
| 207 |
+
|
| 208 |
+
register_patch_mapping(fusable_classes, overwrite=True)
|
| 209 |
+
|
| 210 |
+
if not hasattr(cls, "config_class") or not hasattr(cls.config_class, "model_type"):
|
| 211 |
+
raise ValueError(f"Model {cls.__name__} has no config class or model type")
|
| 212 |
+
model_type = cls.config_class.model_type
|
| 213 |
+
converters = spec.make_transforms(config)
|
| 214 |
+
|
| 215 |
+
existing_converters = get_checkpoint_conversion_mapping(model_type)
|
| 216 |
+
if existing_converters is not None:
|
| 217 |
+
# WeightConverter matching stops at the first matching source pattern, so
|
| 218 |
+
# conflicting converters must fail fast instead of being appended.
|
| 219 |
+
existing_converter_sources = {tuple(existing.source_patterns): existing for existing in existing_converters}
|
| 220 |
+
for converter in converters:
|
| 221 |
+
source_patterns = tuple(converter.source_patterns)
|
| 222 |
+
existing_converter = existing_converter_sources.get(source_patterns)
|
| 223 |
+
if existing_converter is not None:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
f"Fusion {fusion_name} for model type {model_type} conflicts with an existing conversion mapping "
|
| 226 |
+
f"for source patterns {source_patterns}."
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# TODO: allow compatible fusions mentioned https://github.com/huggingface/transformers/pull/45041#discussion_r3028989716
|
| 230 |
+
converters = existing_converters + converters
|
| 231 |
+
|
| 232 |
+
register_checkpoint_conversion_mapping(model_type, converters, overwrite=True)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
_FUSION_REGISTRY: dict[str, ModuleFusionSpec] = {"patch_embeddings": PatchEmbeddingsFusionSpec()}
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _iter_enabled_fusions(fusion_config: Mapping[str, bool | Mapping[str, Any]]) -> list[str]:
|
| 239 |
+
"""Validate `fusion_config` and return enabled fusion names in user-specified order."""
|
| 240 |
+
|
| 241 |
+
enabled_fusions = []
|
| 242 |
+
for fusion_name, fusion_options in fusion_config.items():
|
| 243 |
+
if fusion_name not in _FUSION_REGISTRY:
|
| 244 |
+
raise ValueError(f"Unknown fusion type: {fusion_name}")
|
| 245 |
+
if fusion_options is False:
|
| 246 |
+
continue
|
| 247 |
+
if fusion_options is not True and not isinstance(fusion_options, Mapping):
|
| 248 |
+
raise ValueError(
|
| 249 |
+
f"Invalid fusion config for {fusion_name}: expected `True`, `False`, or a mapping of options."
|
| 250 |
+
)
|
| 251 |
+
enabled_fusions.append(fusion_name)
|
| 252 |
+
return enabled_fusions
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def register_fusion_patches(
|
| 256 |
+
cls: "type[PreTrainedModel]", config, fusion_config: Mapping[str, bool | Mapping[str, Any]] | None = None
|
| 257 |
+
) -> None:
|
| 258 |
+
"""Register requested runtime fusions for `cls`.
|
| 259 |
+
|
| 260 |
+
This function:
|
| 261 |
+
- validates `fusion_config` against `_FUSION_REGISTRY`
|
| 262 |
+
- resolves the enabled fusion families in user order
|
| 263 |
+
- registers monkey patches and checkpoint transforms before model instantiation
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
if not fusion_config:
|
| 267 |
+
return
|
| 268 |
+
|
| 269 |
+
for fusion_name in _iter_enabled_fusions(fusion_config):
|
| 270 |
+
_register_module_fusion(cls, config, fusion_name, _FUSION_REGISTRY[fusion_name])
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_backends.py
ADDED
|
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from collections.abc import Iterable
|
| 16 |
+
from functools import lru_cache
|
| 17 |
+
from typing import Any, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from .image_processing_base import BatchFeature
|
| 22 |
+
from .image_processing_utils import BaseImageProcessor
|
| 23 |
+
from .image_transforms import (
|
| 24 |
+
center_crop as np_center_crop,
|
| 25 |
+
)
|
| 26 |
+
from .image_transforms import (
|
| 27 |
+
convert_to_rgb,
|
| 28 |
+
divide_to_patches, # noqa: F401 - re-exported for backward compat with image_processing_utils_fast
|
| 29 |
+
get_resize_output_image_size,
|
| 30 |
+
get_size_with_aspect_ratio,
|
| 31 |
+
group_images_by_shape,
|
| 32 |
+
reorder_images,
|
| 33 |
+
)
|
| 34 |
+
from .image_transforms import (
|
| 35 |
+
normalize as np_normalize,
|
| 36 |
+
)
|
| 37 |
+
from .image_transforms import (
|
| 38 |
+
rescale as np_rescale,
|
| 39 |
+
)
|
| 40 |
+
from .image_transforms import (
|
| 41 |
+
resize as np_resize,
|
| 42 |
+
)
|
| 43 |
+
from .image_utils import (
|
| 44 |
+
ChannelDimension,
|
| 45 |
+
ImageInput,
|
| 46 |
+
ImageType,
|
| 47 |
+
SizeDict,
|
| 48 |
+
get_image_size,
|
| 49 |
+
get_image_size_for_max_height_width,
|
| 50 |
+
get_image_type,
|
| 51 |
+
get_max_height_width,
|
| 52 |
+
infer_channel_dimension_format,
|
| 53 |
+
is_valid_image,
|
| 54 |
+
load_image_as_tensor,
|
| 55 |
+
)
|
| 56 |
+
from .processing_utils import ImagesKwargs, Unpack
|
| 57 |
+
from .utils import (
|
| 58 |
+
TensorType,
|
| 59 |
+
is_torch_available,
|
| 60 |
+
is_torchvision_available,
|
| 61 |
+
is_vision_available,
|
| 62 |
+
logging,
|
| 63 |
+
)
|
| 64 |
+
from .utils.import_utils import is_rocm_platform, is_torchdynamo_compiling, requires
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if is_vision_available():
|
| 68 |
+
from .image_utils import PILImageResampling
|
| 69 |
+
|
| 70 |
+
if is_torch_available():
|
| 71 |
+
import torch
|
| 72 |
+
|
| 73 |
+
if is_torchvision_available():
|
| 74 |
+
from torchvision.transforms.v2 import functional as tvF
|
| 75 |
+
|
| 76 |
+
from .image_utils import pil_torch_interpolation_mapping, torch_pil_interpolation_mapping
|
| 77 |
+
else:
|
| 78 |
+
pil_torch_interpolation_mapping = None
|
| 79 |
+
torch_pil_interpolation_mapping = None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
logger = logging.get_logger(__name__)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@requires(backends=("torch", "torchvision"))
|
| 86 |
+
class TorchvisionBackend(BaseImageProcessor):
|
| 87 |
+
"""Torchvision backend for GPU-accelerated batched image processing."""
|
| 88 |
+
|
| 89 |
+
def __init__(self, **kwargs: Unpack[ImagesKwargs]):
|
| 90 |
+
super().__init__(**kwargs)
|
| 91 |
+
self._set_attributes(**kwargs)
|
| 92 |
+
|
| 93 |
+
@property
|
| 94 |
+
def is_fast(self) -> bool:
|
| 95 |
+
"""
|
| 96 |
+
`bool`: Whether or not this image processor is using the fast (Torchvision) backend.
|
| 97 |
+
The `is_fast` property is deprecated and will be removed in v5.3 of Transformers.
|
| 98 |
+
Use the `backend` attribute instead (e.g., `processor.backend == "torchvision"`).
|
| 99 |
+
"""
|
| 100 |
+
logger.warning_once(
|
| 101 |
+
"The `is_fast` property is deprecated and will be removed in v5.3 of Transformers. "
|
| 102 |
+
"Use the `backend` attribute instead (e.g., `processor.backend == 'torchvision'`)."
|
| 103 |
+
)
|
| 104 |
+
return True
|
| 105 |
+
|
| 106 |
+
@property
|
| 107 |
+
def backend(self) -> str:
|
| 108 |
+
"""
|
| 109 |
+
`str`: The backend used by this image processor.
|
| 110 |
+
"""
|
| 111 |
+
return "torchvision"
|
| 112 |
+
|
| 113 |
+
def fetch_images(self, image_url_or_urls: str | list[str] | list[list[str]]):
|
| 114 |
+
"""
|
| 115 |
+
Convert a single or a list of URLs / paths into `torch.Tensor` objects.
|
| 116 |
+
|
| 117 |
+
Already-valid image objects (tensors, numpy arrays, PIL Images) are passed through
|
| 118 |
+
unchanged so that callers who pre-load images are unaffected.
|
| 119 |
+
"""
|
| 120 |
+
if isinstance(image_url_or_urls, (list, tuple)):
|
| 121 |
+
return [self.fetch_images(x) for x in image_url_or_urls]
|
| 122 |
+
elif isinstance(image_url_or_urls, str):
|
| 123 |
+
return load_image_as_tensor(image_url_or_urls)
|
| 124 |
+
elif is_valid_image(image_url_or_urls):
|
| 125 |
+
return image_url_or_urls
|
| 126 |
+
else:
|
| 127 |
+
raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
|
| 128 |
+
|
| 129 |
+
def process_image(
|
| 130 |
+
self,
|
| 131 |
+
image: ImageInput,
|
| 132 |
+
do_convert_rgb: bool | None = None,
|
| 133 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 134 |
+
device: Optional["torch.device"] = None,
|
| 135 |
+
**kwargs: Unpack[ImagesKwargs],
|
| 136 |
+
) -> "torch.Tensor":
|
| 137 |
+
"""Process a single image for torchvision backend."""
|
| 138 |
+
image_type = get_image_type(image)
|
| 139 |
+
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
|
| 140 |
+
raise ValueError(f"Unsupported input image type {image_type}")
|
| 141 |
+
|
| 142 |
+
if do_convert_rgb:
|
| 143 |
+
image = self.convert_to_rgb(image)
|
| 144 |
+
|
| 145 |
+
if image_type == ImageType.PIL:
|
| 146 |
+
image = tvF.pil_to_tensor(image)
|
| 147 |
+
elif image_type == ImageType.NUMPY:
|
| 148 |
+
image = torch.from_numpy(image).contiguous()
|
| 149 |
+
|
| 150 |
+
if image.ndim == 2:
|
| 151 |
+
image = image.unsqueeze(0)
|
| 152 |
+
|
| 153 |
+
if input_data_format is None:
|
| 154 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 155 |
+
|
| 156 |
+
if input_data_format == ChannelDimension.LAST:
|
| 157 |
+
image = image.permute(2, 0, 1).contiguous()
|
| 158 |
+
|
| 159 |
+
if device is not None:
|
| 160 |
+
image = image.to(device)
|
| 161 |
+
|
| 162 |
+
return image
|
| 163 |
+
|
| 164 |
+
def convert_to_rgb(self, image: ImageInput) -> ImageInput:
|
| 165 |
+
"""Convert an image to RGB format."""
|
| 166 |
+
return convert_to_rgb(image)
|
| 167 |
+
|
| 168 |
+
def pad(
|
| 169 |
+
self,
|
| 170 |
+
images: list["torch.Tensor"],
|
| 171 |
+
pad_size: SizeDict = None,
|
| 172 |
+
fill_value: int | None = 0,
|
| 173 |
+
padding_mode: str | None = "constant",
|
| 174 |
+
return_mask: bool = False,
|
| 175 |
+
disable_grouping: bool | None = False,
|
| 176 |
+
is_nested: bool | None = False,
|
| 177 |
+
**kwargs,
|
| 178 |
+
) -> Union[tuple["torch.Tensor", "torch.Tensor"], "torch.Tensor"]:
|
| 179 |
+
"""Pad images using Torchvision with batched operations."""
|
| 180 |
+
if pad_size is not None:
|
| 181 |
+
if not (pad_size.height and pad_size.width):
|
| 182 |
+
raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
|
| 183 |
+
pad_size = (pad_size.height, pad_size.width)
|
| 184 |
+
else:
|
| 185 |
+
pad_size = get_max_height_width(images)
|
| 186 |
+
|
| 187 |
+
grouped_images, grouped_images_index = group_images_by_shape(
|
| 188 |
+
images, disable_grouping=disable_grouping, is_nested=is_nested
|
| 189 |
+
)
|
| 190 |
+
processed_images_grouped = {}
|
| 191 |
+
processed_masks_grouped = {}
|
| 192 |
+
for shape, stacked_images in grouped_images.items():
|
| 193 |
+
image_size = stacked_images.shape[-2:]
|
| 194 |
+
padding_height = pad_size[0] - image_size[0]
|
| 195 |
+
padding_width = pad_size[1] - image_size[1]
|
| 196 |
+
if padding_height < 0 or padding_width < 0:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
|
| 199 |
+
f"image size. Got pad_size={pad_size}, image_size={image_size}."
|
| 200 |
+
)
|
| 201 |
+
if image_size != pad_size:
|
| 202 |
+
padding = (0, 0, padding_width, padding_height)
|
| 203 |
+
stacked_images = tvF.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
|
| 204 |
+
processed_images_grouped[shape] = stacked_images
|
| 205 |
+
|
| 206 |
+
if return_mask:
|
| 207 |
+
stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
|
| 208 |
+
stacked_masks[..., : image_size[0], : image_size[1]] = 1
|
| 209 |
+
processed_masks_grouped[shape] = stacked_masks
|
| 210 |
+
|
| 211 |
+
processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=is_nested)
|
| 212 |
+
if return_mask:
|
| 213 |
+
processed_masks = reorder_images(processed_masks_grouped, grouped_images_index, is_nested=is_nested)
|
| 214 |
+
return processed_images, processed_masks
|
| 215 |
+
|
| 216 |
+
return processed_images
|
| 217 |
+
|
| 218 |
+
def resize(
|
| 219 |
+
self,
|
| 220 |
+
image: "torch.Tensor",
|
| 221 |
+
size: SizeDict,
|
| 222 |
+
resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
|
| 223 |
+
antialias: bool = True,
|
| 224 |
+
**kwargs,
|
| 225 |
+
) -> "torch.Tensor":
|
| 226 |
+
"""Resize an image using Torchvision."""
|
| 227 |
+
# Convert PIL resample to torchvision interpolation if needed
|
| 228 |
+
if resample is not None:
|
| 229 |
+
if isinstance(resample, (PILImageResampling, int)):
|
| 230 |
+
interpolation = pil_torch_interpolation_mapping[resample]
|
| 231 |
+
else:
|
| 232 |
+
interpolation = resample
|
| 233 |
+
else:
|
| 234 |
+
interpolation = tvF.InterpolationMode.BILINEAR
|
| 235 |
+
if interpolation == tvF.InterpolationMode.LANCZOS:
|
| 236 |
+
logger.warning_once(
|
| 237 |
+
"You have used a torchvision backend image processor with LANCZOS resample which not yet supported for torch.Tensor. "
|
| 238 |
+
"BICUBIC resample will be used as an alternative. Please fall back to a pil backend image processor if you "
|
| 239 |
+
"want full consistency with the original model."
|
| 240 |
+
)
|
| 241 |
+
interpolation = tvF.InterpolationMode.BICUBIC
|
| 242 |
+
|
| 243 |
+
if size.shortest_edge and size.longest_edge:
|
| 244 |
+
new_size = get_size_with_aspect_ratio(
|
| 245 |
+
image.size()[-2:],
|
| 246 |
+
size.shortest_edge,
|
| 247 |
+
size.longest_edge,
|
| 248 |
+
)
|
| 249 |
+
elif size.shortest_edge:
|
| 250 |
+
new_size = get_resize_output_image_size(
|
| 251 |
+
image,
|
| 252 |
+
size=size.shortest_edge,
|
| 253 |
+
default_to_square=False,
|
| 254 |
+
input_data_format=ChannelDimension.FIRST,
|
| 255 |
+
)
|
| 256 |
+
elif size.max_height and size.max_width:
|
| 257 |
+
new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
|
| 258 |
+
elif size.height and size.width:
|
| 259 |
+
new_size = (size.height, size.width)
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError(
|
| 262 |
+
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
|
| 263 |
+
f" {size}."
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Workaround for torch.compile issue with uint8 on AMD GPUs
|
| 267 |
+
if is_torchdynamo_compiling() and is_rocm_platform():
|
| 268 |
+
return self._compile_friendly_resize(image, new_size, interpolation, antialias)
|
| 269 |
+
return tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
|
| 270 |
+
|
| 271 |
+
@staticmethod
|
| 272 |
+
def _compile_friendly_resize(
|
| 273 |
+
image: "torch.Tensor",
|
| 274 |
+
new_size: tuple[int, int],
|
| 275 |
+
interpolation: Optional["tvF.InterpolationMode"] = None,
|
| 276 |
+
antialias: bool = True,
|
| 277 |
+
) -> "torch.Tensor":
|
| 278 |
+
"""A wrapper around tvF.resize for torch.compile compatibility with uint8 tensors."""
|
| 279 |
+
if image.dtype == torch.uint8:
|
| 280 |
+
image = image.float() / 256
|
| 281 |
+
image = tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
|
| 282 |
+
image = image * 256
|
| 283 |
+
image = torch.where(image > 255, 255, image)
|
| 284 |
+
image = torch.where(image < 0, 0, image)
|
| 285 |
+
image = image.round().to(torch.uint8)
|
| 286 |
+
else:
|
| 287 |
+
image = tvF.resize(image, new_size, interpolation=interpolation, antialias=antialias)
|
| 288 |
+
return image
|
| 289 |
+
|
| 290 |
+
def rescale(
|
| 291 |
+
self,
|
| 292 |
+
image: "torch.Tensor",
|
| 293 |
+
scale: float,
|
| 294 |
+
**kwargs,
|
| 295 |
+
) -> "torch.Tensor":
|
| 296 |
+
"""Rescale an image by a scale factor using Torchvision."""
|
| 297 |
+
return image * scale
|
| 298 |
+
|
| 299 |
+
def normalize(
|
| 300 |
+
self,
|
| 301 |
+
image: "torch.Tensor",
|
| 302 |
+
mean: float | Iterable[float],
|
| 303 |
+
std: float | Iterable[float],
|
| 304 |
+
**kwargs,
|
| 305 |
+
) -> "torch.Tensor":
|
| 306 |
+
"""Normalize an image using Torchvision."""
|
| 307 |
+
return tvF.normalize(image, mean, std)
|
| 308 |
+
|
| 309 |
+
@lru_cache(maxsize=10)
|
| 310 |
+
def _fuse_mean_std_and_rescale_factor(
|
| 311 |
+
self,
|
| 312 |
+
do_normalize: bool | None = None,
|
| 313 |
+
image_mean: float | list[float] | None = None,
|
| 314 |
+
image_std: float | list[float] | None = None,
|
| 315 |
+
do_rescale: bool | None = None,
|
| 316 |
+
rescale_factor: float | None = None,
|
| 317 |
+
device: Optional["torch.device"] = None,
|
| 318 |
+
) -> tuple:
|
| 319 |
+
if do_rescale and do_normalize:
|
| 320 |
+
# Fused rescale and normalize
|
| 321 |
+
image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
|
| 322 |
+
image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
|
| 323 |
+
do_rescale = False
|
| 324 |
+
return image_mean, image_std, do_rescale
|
| 325 |
+
|
| 326 |
+
def rescale_and_normalize(
|
| 327 |
+
self,
|
| 328 |
+
images: "torch.Tensor",
|
| 329 |
+
do_rescale: bool,
|
| 330 |
+
rescale_factor: float,
|
| 331 |
+
do_normalize: bool,
|
| 332 |
+
image_mean: float | list[float],
|
| 333 |
+
image_std: float | list[float],
|
| 334 |
+
) -> "torch.Tensor":
|
| 335 |
+
"""Rescale and normalize images using Torchvision (fused for efficiency)."""
|
| 336 |
+
image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
|
| 337 |
+
do_normalize=do_normalize,
|
| 338 |
+
image_mean=image_mean,
|
| 339 |
+
image_std=image_std,
|
| 340 |
+
do_rescale=do_rescale,
|
| 341 |
+
rescale_factor=rescale_factor,
|
| 342 |
+
device=images.device,
|
| 343 |
+
)
|
| 344 |
+
if do_normalize:
|
| 345 |
+
images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
|
| 346 |
+
elif do_rescale:
|
| 347 |
+
images = self.rescale(images, rescale_factor)
|
| 348 |
+
|
| 349 |
+
return images
|
| 350 |
+
|
| 351 |
+
def center_crop(
|
| 352 |
+
self,
|
| 353 |
+
image: "torch.Tensor",
|
| 354 |
+
size: SizeDict,
|
| 355 |
+
**kwargs,
|
| 356 |
+
) -> "torch.Tensor":
|
| 357 |
+
"""Center crop an image using Torchvision."""
|
| 358 |
+
if size.height is None or size.width is None:
|
| 359 |
+
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
|
| 360 |
+
image_height, image_width = image.shape[-2:]
|
| 361 |
+
crop_height, crop_width = size.height, size.width
|
| 362 |
+
|
| 363 |
+
if crop_width > image_width or crop_height > image_height:
|
| 364 |
+
padding_ltrb = [
|
| 365 |
+
(crop_width - image_width) // 2 if crop_width > image_width else 0,
|
| 366 |
+
(crop_height - image_height) // 2 if crop_height > image_height else 0,
|
| 367 |
+
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
|
| 368 |
+
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
|
| 369 |
+
]
|
| 370 |
+
image = tvF.pad(image, padding_ltrb, fill=0)
|
| 371 |
+
image_height, image_width = image.shape[-2:]
|
| 372 |
+
if crop_width == image_width and crop_height == image_height:
|
| 373 |
+
return image
|
| 374 |
+
|
| 375 |
+
crop_top = int((image_height - crop_height) / 2.0)
|
| 376 |
+
crop_left = int((image_width - crop_width) / 2.0)
|
| 377 |
+
return tvF.crop(image, crop_top, crop_left, crop_height, crop_width)
|
| 378 |
+
|
| 379 |
+
def _preprocess(
|
| 380 |
+
self,
|
| 381 |
+
images: list["torch.Tensor"],
|
| 382 |
+
do_resize: bool,
|
| 383 |
+
size: SizeDict,
|
| 384 |
+
resample: "PILImageResampling | tvF.InterpolationMode | int | None",
|
| 385 |
+
do_center_crop: bool,
|
| 386 |
+
crop_size: SizeDict,
|
| 387 |
+
do_rescale: bool,
|
| 388 |
+
rescale_factor: float,
|
| 389 |
+
do_normalize: bool,
|
| 390 |
+
image_mean: float | list[float] | None,
|
| 391 |
+
image_std: float | list[float] | None,
|
| 392 |
+
do_pad: bool | None,
|
| 393 |
+
pad_size: SizeDict | None,
|
| 394 |
+
disable_grouping: bool | None,
|
| 395 |
+
return_tensors: str | TensorType | None,
|
| 396 |
+
**kwargs,
|
| 397 |
+
) -> BatchFeature:
|
| 398 |
+
"""Preprocess using Torchvision backend (fast, GPU-accelerated)."""
|
| 399 |
+
# Group images by size for batched resizing
|
| 400 |
+
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
| 401 |
+
resized_images_grouped = {}
|
| 402 |
+
for shape, stacked_images in grouped_images.items():
|
| 403 |
+
if do_resize:
|
| 404 |
+
stacked_images = self.resize(image=stacked_images, size=size, resample=resample)
|
| 405 |
+
resized_images_grouped[shape] = stacked_images
|
| 406 |
+
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
| 407 |
+
|
| 408 |
+
# Group images by size for further processing
|
| 409 |
+
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
| 410 |
+
processed_images_grouped = {}
|
| 411 |
+
for shape, stacked_images in grouped_images.items():
|
| 412 |
+
if do_center_crop:
|
| 413 |
+
stacked_images = self.center_crop(stacked_images, crop_size)
|
| 414 |
+
# Fused rescale and normalize
|
| 415 |
+
stacked_images = self.rescale_and_normalize(
|
| 416 |
+
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
| 417 |
+
)
|
| 418 |
+
processed_images_grouped[shape] = stacked_images
|
| 419 |
+
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
|
| 420 |
+
|
| 421 |
+
if do_pad:
|
| 422 |
+
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
|
| 423 |
+
|
| 424 |
+
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@requires(backends=("vision",))
|
| 428 |
+
class PilBackend(BaseImageProcessor):
|
| 429 |
+
"""PIL/NumPy backend for portable CPU-only image processing."""
|
| 430 |
+
|
| 431 |
+
def __init__(self, **kwargs: Unpack[ImagesKwargs]):
|
| 432 |
+
super().__init__(**kwargs)
|
| 433 |
+
self._set_attributes(**kwargs)
|
| 434 |
+
|
| 435 |
+
@property
|
| 436 |
+
def is_fast(self) -> bool:
|
| 437 |
+
"""
|
| 438 |
+
`bool`: Whether or not this image processor is using the fast (Torchvision) backend.
|
| 439 |
+
The `is_fast` property is deprecated and will be removed in v5.3 of Transformers.
|
| 440 |
+
Use the `backend` attribute instead (e.g., `processor.backend == "torchvision"`).
|
| 441 |
+
"""
|
| 442 |
+
logger.warning_once(
|
| 443 |
+
"The `is_fast` property is deprecated and will be removed in v5.3 of Transformers. "
|
| 444 |
+
"Use the `backend` attribute instead (e.g., `processor.backend == 'torchvision'`)."
|
| 445 |
+
)
|
| 446 |
+
return False
|
| 447 |
+
|
| 448 |
+
@property
|
| 449 |
+
def backend(self) -> str:
|
| 450 |
+
"""
|
| 451 |
+
`str`: The backend used by this image processor.
|
| 452 |
+
"""
|
| 453 |
+
return "pil"
|
| 454 |
+
|
| 455 |
+
def process_image(
|
| 456 |
+
self,
|
| 457 |
+
image: ImageInput,
|
| 458 |
+
do_convert_rgb: bool | None = None,
|
| 459 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 460 |
+
**kwargs: Unpack[ImagesKwargs],
|
| 461 |
+
) -> np.ndarray:
|
| 462 |
+
"""Process a single image for PIL backend."""
|
| 463 |
+
image_type = get_image_type(image)
|
| 464 |
+
if image_type not in [ImageType.PIL, ImageType.TORCH, ImageType.NUMPY]:
|
| 465 |
+
raise ValueError(f"Unsupported input image type {image_type}")
|
| 466 |
+
|
| 467 |
+
if do_convert_rgb:
|
| 468 |
+
image = self.convert_to_rgb(image)
|
| 469 |
+
|
| 470 |
+
if image_type == ImageType.PIL:
|
| 471 |
+
image = np.array(image)
|
| 472 |
+
# Set LAST only for multi-channel PIL images (H, W, C); for grayscale (H, W), leave as is to avoid shape errors after expand_dims.
|
| 473 |
+
if image.ndim >= 3:
|
| 474 |
+
input_data_format = ChannelDimension.LAST if input_data_format is None else input_data_format
|
| 475 |
+
elif image_type == ImageType.TORCH:
|
| 476 |
+
image = image.numpy()
|
| 477 |
+
|
| 478 |
+
if image.ndim == 2:
|
| 479 |
+
image = np.expand_dims(image, axis=0)
|
| 480 |
+
|
| 481 |
+
if input_data_format is None:
|
| 482 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 483 |
+
|
| 484 |
+
if input_data_format == ChannelDimension.LAST:
|
| 485 |
+
# Convert from channels-last to channels-first
|
| 486 |
+
if isinstance(image, np.ndarray):
|
| 487 |
+
image = np.transpose(image, (2, 0, 1))
|
| 488 |
+
|
| 489 |
+
return image
|
| 490 |
+
|
| 491 |
+
def convert_to_rgb(self, image: ImageInput) -> ImageInput:
|
| 492 |
+
"""Convert an image to RGB format."""
|
| 493 |
+
return convert_to_rgb(image)
|
| 494 |
+
|
| 495 |
+
def pad(
|
| 496 |
+
self,
|
| 497 |
+
images: list[np.ndarray],
|
| 498 |
+
pad_size: SizeDict = None,
|
| 499 |
+
fill_value: int | None = 0,
|
| 500 |
+
padding_mode: str | None = "constant",
|
| 501 |
+
return_mask: bool = False,
|
| 502 |
+
**kwargs,
|
| 503 |
+
) -> tuple[list[np.ndarray], list[np.ndarray]] | list[np.ndarray]:
|
| 504 |
+
"""Pad images to specified size using NumPy."""
|
| 505 |
+
if pad_size is not None:
|
| 506 |
+
if not (pad_size.height and pad_size.width):
|
| 507 |
+
raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
|
| 508 |
+
target_height, target_width = pad_size.height, pad_size.width
|
| 509 |
+
else:
|
| 510 |
+
target_height, target_width = get_max_height_width(images)
|
| 511 |
+
|
| 512 |
+
processed_images = []
|
| 513 |
+
processed_masks = []
|
| 514 |
+
|
| 515 |
+
for image in images:
|
| 516 |
+
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
| 517 |
+
padding_height = target_height - height
|
| 518 |
+
padding_width = target_width - width
|
| 519 |
+
|
| 520 |
+
if padding_height < 0 or padding_width < 0:
|
| 521 |
+
raise ValueError(
|
| 522 |
+
f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
|
| 523 |
+
f"image size. Got pad_size=({target_height}, {target_width}), image_size=({height}, {width})."
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
if height != target_height or width != target_width:
|
| 527 |
+
# Pad format: ((before_1, after_1), (before_2, after_2), ...)
|
| 528 |
+
# For CHW format: ((0, 0), (0, padding_height), (0, padding_width))
|
| 529 |
+
pad_width = ((0, 0), (0, padding_height), (0, padding_width))
|
| 530 |
+
if padding_mode == "constant":
|
| 531 |
+
image = np.pad(image, pad_width, mode="constant", constant_values=fill_value)
|
| 532 |
+
else:
|
| 533 |
+
image = np.pad(image, pad_width, mode=padding_mode)
|
| 534 |
+
|
| 535 |
+
processed_images.append(image)
|
| 536 |
+
|
| 537 |
+
if return_mask:
|
| 538 |
+
mask = np.zeros((target_height, target_width), dtype=np.int64)
|
| 539 |
+
mask[:height, :width] = 1
|
| 540 |
+
processed_masks.append(mask)
|
| 541 |
+
|
| 542 |
+
if return_mask:
|
| 543 |
+
return processed_images, processed_masks
|
| 544 |
+
return processed_images
|
| 545 |
+
|
| 546 |
+
def resize(
|
| 547 |
+
self,
|
| 548 |
+
image: np.ndarray,
|
| 549 |
+
size: SizeDict,
|
| 550 |
+
resample: "PILImageResampling | None" = None,
|
| 551 |
+
reducing_gap: int | None = None,
|
| 552 |
+
**kwargs,
|
| 553 |
+
) -> np.ndarray:
|
| 554 |
+
"""Resize an image using PIL/NumPy."""
|
| 555 |
+
# PIL backend only supports PILImageResampling
|
| 556 |
+
if resample is not None and not isinstance(resample, (PILImageResampling, int)):
|
| 557 |
+
if torch_pil_interpolation_mapping is not None and resample in torch_pil_interpolation_mapping:
|
| 558 |
+
resample = torch_pil_interpolation_mapping[resample]
|
| 559 |
+
else:
|
| 560 |
+
resample = PILImageResampling.BILINEAR
|
| 561 |
+
resample = resample if resample is not None else PILImageResampling.BILINEAR
|
| 562 |
+
|
| 563 |
+
if size.shortest_edge and size.longest_edge:
|
| 564 |
+
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
| 565 |
+
new_size = get_size_with_aspect_ratio(
|
| 566 |
+
(height, width),
|
| 567 |
+
size.shortest_edge,
|
| 568 |
+
size.longest_edge,
|
| 569 |
+
)
|
| 570 |
+
elif size.shortest_edge:
|
| 571 |
+
new_size = get_resize_output_image_size(
|
| 572 |
+
image,
|
| 573 |
+
size=size.shortest_edge,
|
| 574 |
+
default_to_square=False,
|
| 575 |
+
input_data_format=ChannelDimension.FIRST,
|
| 576 |
+
)
|
| 577 |
+
elif size.max_height and size.max_width:
|
| 578 |
+
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
| 579 |
+
new_size = get_image_size_for_max_height_width((height, width), size.max_height, size.max_width)
|
| 580 |
+
elif size.height and size.width:
|
| 581 |
+
new_size = (size.height, size.width)
|
| 582 |
+
else:
|
| 583 |
+
raise ValueError(
|
| 584 |
+
"Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
|
| 585 |
+
f" {size}."
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
return np_resize(
|
| 589 |
+
image,
|
| 590 |
+
size=new_size,
|
| 591 |
+
resample=resample,
|
| 592 |
+
reducing_gap=reducing_gap,
|
| 593 |
+
data_format=ChannelDimension.FIRST,
|
| 594 |
+
input_data_format=ChannelDimension.FIRST,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
def rescale(
|
| 598 |
+
self,
|
| 599 |
+
image: np.ndarray,
|
| 600 |
+
scale: float,
|
| 601 |
+
**kwargs,
|
| 602 |
+
) -> np.ndarray:
|
| 603 |
+
"""Rescale an image by a scale factor using NumPy."""
|
| 604 |
+
return np_rescale(
|
| 605 |
+
image,
|
| 606 |
+
scale=scale,
|
| 607 |
+
data_format=ChannelDimension.FIRST,
|
| 608 |
+
input_data_format=ChannelDimension.FIRST,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
def normalize(
|
| 612 |
+
self,
|
| 613 |
+
image: np.ndarray,
|
| 614 |
+
mean: float | Iterable[float],
|
| 615 |
+
std: float | Iterable[float],
|
| 616 |
+
**kwargs,
|
| 617 |
+
) -> np.ndarray:
|
| 618 |
+
"""Normalize an image using NumPy."""
|
| 619 |
+
return np_normalize(
|
| 620 |
+
image,
|
| 621 |
+
mean=mean,
|
| 622 |
+
std=std,
|
| 623 |
+
data_format=ChannelDimension.FIRST,
|
| 624 |
+
input_data_format=ChannelDimension.FIRST,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
def center_crop(
|
| 628 |
+
self,
|
| 629 |
+
image: np.ndarray,
|
| 630 |
+
size: SizeDict,
|
| 631 |
+
**kwargs,
|
| 632 |
+
) -> np.ndarray:
|
| 633 |
+
"""Center crop an image using NumPy."""
|
| 634 |
+
if size.height is None or size.width is None:
|
| 635 |
+
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
|
| 636 |
+
|
| 637 |
+
return np_center_crop(
|
| 638 |
+
image,
|
| 639 |
+
size=(size.height, size.width),
|
| 640 |
+
data_format=ChannelDimension.FIRST,
|
| 641 |
+
input_data_format=ChannelDimension.FIRST,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
def _preprocess(
|
| 645 |
+
self,
|
| 646 |
+
images: list[np.ndarray],
|
| 647 |
+
do_resize: bool,
|
| 648 |
+
size: SizeDict,
|
| 649 |
+
resample: "PILImageResampling | None",
|
| 650 |
+
do_center_crop: bool,
|
| 651 |
+
crop_size: SizeDict,
|
| 652 |
+
do_rescale: bool,
|
| 653 |
+
rescale_factor: float,
|
| 654 |
+
do_normalize: bool,
|
| 655 |
+
image_mean: float | list[float] | None,
|
| 656 |
+
image_std: float | list[float] | None,
|
| 657 |
+
do_pad: bool | None,
|
| 658 |
+
pad_size: SizeDict | None,
|
| 659 |
+
return_tensors: str | TensorType | None,
|
| 660 |
+
**kwargs,
|
| 661 |
+
) -> BatchFeature:
|
| 662 |
+
"""Preprocess using PIL backend (portable, CPU-only)."""
|
| 663 |
+
processed_images = []
|
| 664 |
+
for image in images:
|
| 665 |
+
if do_resize:
|
| 666 |
+
image = self.resize(image=image, size=size, resample=resample)
|
| 667 |
+
if do_center_crop:
|
| 668 |
+
image = self.center_crop(image, crop_size)
|
| 669 |
+
if do_rescale:
|
| 670 |
+
image = self.rescale(image, rescale_factor)
|
| 671 |
+
if do_normalize:
|
| 672 |
+
image = self.normalize(image, image_mean, image_std)
|
| 673 |
+
processed_images.append(image)
|
| 674 |
+
|
| 675 |
+
if do_pad:
|
| 676 |
+
processed_images = self.pad(processed_images, pad_size=pad_size)
|
| 677 |
+
|
| 678 |
+
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
| 679 |
+
|
| 680 |
+
def to_dict(self) -> dict[str, Any]:
|
| 681 |
+
processor_dict = super().to_dict()
|
| 682 |
+
# Remove the "Pil" suffix from the image processor type
|
| 683 |
+
if processor_dict.get("image_processor_type", "").endswith("Pil"):
|
| 684 |
+
processor_dict["image_processor_type"] = processor_dict["image_processor_type"][:-3]
|
| 685 |
+
return processor_dict
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
# Backward-compatible alias: allow referring to TorchvisionBackend as BaseImageProcessorFast
|
| 689 |
+
BaseImageProcessorFast = TorchvisionBackend
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_processing_utils.py
ADDED
|
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from collections.abc import Iterable
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
from functools import partial
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
from huggingface_hub.dataclasses import validate_typed_dict
|
| 23 |
+
|
| 24 |
+
from .image_processing_base import BatchFeature, ImageProcessingMixin
|
| 25 |
+
from .image_transforms import center_crop, normalize, rescale
|
| 26 |
+
from .image_utils import (
|
| 27 |
+
ChannelDimension,
|
| 28 |
+
ImageInput,
|
| 29 |
+
SizeDict,
|
| 30 |
+
get_image_size,
|
| 31 |
+
make_flat_list_of_images,
|
| 32 |
+
validate_preprocess_arguments,
|
| 33 |
+
)
|
| 34 |
+
from .processing_utils import ImagesKwargs, Unpack
|
| 35 |
+
from .utils import (
|
| 36 |
+
auto_docstring,
|
| 37 |
+
is_torchvision_available,
|
| 38 |
+
is_vision_available,
|
| 39 |
+
logging,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if is_vision_available():
|
| 44 |
+
from .image_utils import PILImageResampling
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_torchvision_available():
|
| 48 |
+
from torchvision.transforms.v2 import functional as tvF
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
logger = logging.get_logger(__name__)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
INIT_SERVICE_KWARGS = [
|
| 55 |
+
"processor_class",
|
| 56 |
+
"image_processor_type",
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class BaseImageProcessor(ImageProcessingMixin):
|
| 61 |
+
r"""
|
| 62 |
+
Base class for image processors with an inheritance-based backend architecture.
|
| 63 |
+
|
| 64 |
+
This class defines the preprocessing pipeline: kwargs validation, input preparation, and dispatching to the
|
| 65 |
+
backend's `_preprocess` method. Backend subclasses (`TorchvisionBackend`, `PilBackend`) inherit from this class
|
| 66 |
+
and implement the actual image operations (resize, crop, rescale, normalize, etc.). Model-specific image
|
| 67 |
+
processors then inherit from the appropriate backend class.
|
| 68 |
+
|
| 69 |
+
Architecture Overview
|
| 70 |
+
---------------------
|
| 71 |
+
|
| 72 |
+
The class hierarchy is:
|
| 73 |
+
|
| 74 |
+
BaseImageProcessor (this class)
|
| 75 |
+
├── TorchvisionBackend (GPU-accelerated, torch.Tensor)
|
| 76 |
+
│ └── ModelImageProcessor (e.g. LlavaNextImageProcessor)
|
| 77 |
+
└── PilBackend (portable CPU, np.ndarray)
|
| 78 |
+
└── ModelImageProcessorPil (e.g. CLIPImageProcessorPil)
|
| 79 |
+
|
| 80 |
+
The preprocessing flow is:
|
| 81 |
+
|
| 82 |
+
__call__() → preprocess() → _preprocess_image_like_inputs() → _prepare_image_like_inputs()
|
| 83 |
+
(calls process_image per image)
|
| 84 |
+
→ _preprocess()
|
| 85 |
+
(batch operations: resize, crop, etc.)
|
| 86 |
+
|
| 87 |
+
- `process_image`: Implemented by backends. Converts a single raw input (PIL, NumPy, or Tensor) to the
|
| 88 |
+
backend's working format (torch.Tensor or np.ndarray), handles RGB conversion and channel reordering.
|
| 89 |
+
- `_preprocess`: Implemented by backends. Performs the actual batch processing (resize, center crop, rescale,
|
| 90 |
+
normalize, pad) and returns a `BatchFeature`.
|
| 91 |
+
|
| 92 |
+
Basic Implementation
|
| 93 |
+
--------------------
|
| 94 |
+
|
| 95 |
+
For processors that only need standard operations (resize, center crop, rescale, normalize), inherit from
|
| 96 |
+
a backend and define class attributes:
|
| 97 |
+
|
| 98 |
+
from transformers.image_processing_backends import PilBackend
|
| 99 |
+
|
| 100 |
+
class MyImageProcessorPil(PilBackend):
|
| 101 |
+
resample = PILImageResampling.BILINEAR
|
| 102 |
+
image_mean = IMAGENET_DEFAULT_MEAN
|
| 103 |
+
image_std = IMAGENET_DEFAULT_STD
|
| 104 |
+
size = {"height": 224, "width": 224}
|
| 105 |
+
do_resize = True
|
| 106 |
+
do_rescale = True
|
| 107 |
+
do_normalize = True
|
| 108 |
+
|
| 109 |
+
The backend's `_preprocess` method handles the standard pipeline automatically.
|
| 110 |
+
|
| 111 |
+
Custom Processing
|
| 112 |
+
-----------------
|
| 113 |
+
|
| 114 |
+
For processors that need custom logic (e.g., patch-based processing, multiple input types), override
|
| 115 |
+
`_preprocess` in your model-specific processor. The `_preprocess` method receives already-prepared images
|
| 116 |
+
(converted to the backend format with channels-first ordering) and performs the actual processing:
|
| 117 |
+
|
| 118 |
+
class MyImageProcessor(TorchvisionBackend):
|
| 119 |
+
def _preprocess(self, images, do_resize, size, do_normalize, image_mean, image_std, **kwargs):
|
| 120 |
+
# Group images by shape for efficient batched operations
|
| 121 |
+
grouped_images, grouped_images_index = group_images_by_shape(images)
|
| 122 |
+
processed_groups = {}
|
| 123 |
+
for shape, stacked_images in grouped_images.items():
|
| 124 |
+
if do_resize:
|
| 125 |
+
stacked_images = self.resize(stacked_images, size=size)
|
| 126 |
+
if do_normalize:
|
| 127 |
+
stacked_images = self.normalize(stacked_images, mean=image_mean, std=image_std)
|
| 128 |
+
processed_groups[shape] = stacked_images
|
| 129 |
+
processed_images = reorder_images(processed_groups, grouped_images_index)
|
| 130 |
+
return BatchFeature(data={"pixel_values": processed_images})
|
| 131 |
+
|
| 132 |
+
For processors handling multiple input types (e.g., images + segmentation maps), override
|
| 133 |
+
`_preprocess_image_like_inputs`:
|
| 134 |
+
|
| 135 |
+
def _preprocess_image_like_inputs(
|
| 136 |
+
self,
|
| 137 |
+
images: ImageInput,
|
| 138 |
+
segmentation_maps: ImageInput | None = None,
|
| 139 |
+
**kwargs,
|
| 140 |
+
) -> BatchFeature:
|
| 141 |
+
images = self._prepare_image_like_inputs(images, **kwargs)
|
| 142 |
+
batch_feature = self._preprocess(images, **kwargs)
|
| 143 |
+
|
| 144 |
+
if segmentation_maps is not None:
|
| 145 |
+
maps = self._prepare_image_like_inputs(segmentation_maps, **kwargs)
|
| 146 |
+
batch_feature["labels"] = self._preprocess(maps, **kwargs).pixel_values
|
| 147 |
+
|
| 148 |
+
return batch_feature
|
| 149 |
+
|
| 150 |
+
Extending Backend Behavior
|
| 151 |
+
--------------------------
|
| 152 |
+
|
| 153 |
+
To customize operations for a specific backend, subclass the backend and override its methods:
|
| 154 |
+
|
| 155 |
+
from transformers.image_processing_backends import TorchvisionBackend, PilBackend
|
| 156 |
+
|
| 157 |
+
class MyTorchvisionProcessor(TorchvisionBackend):
|
| 158 |
+
def resize(self, image, size, **kwargs):
|
| 159 |
+
# Custom resize logic for torchvision
|
| 160 |
+
return super().resize(image, size, **kwargs)
|
| 161 |
+
|
| 162 |
+
class MyPilProcessor(PilBackend):
|
| 163 |
+
def resize(self, image, size, **kwargs):
|
| 164 |
+
# Custom resize logic for PIL
|
| 165 |
+
return super().resize(image, size, **kwargs)
|
| 166 |
+
|
| 167 |
+
Custom Parameters
|
| 168 |
+
-----------------
|
| 169 |
+
|
| 170 |
+
To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:
|
| 171 |
+
|
| 172 |
+
class MyImageProcessorKwargs(ImagesKwargs):
|
| 173 |
+
custom_param: int | None = None
|
| 174 |
+
|
| 175 |
+
class MyImageProcessor(TorchvisionBackend):
|
| 176 |
+
valid_kwargs = MyImageProcessorKwargs
|
| 177 |
+
custom_param = 10 # default value
|
| 178 |
+
|
| 179 |
+
Key Notes
|
| 180 |
+
---------
|
| 181 |
+
|
| 182 |
+
- Backend selection is done at the class level: inherit from `TorchvisionBackend` or `PilBackend`
|
| 183 |
+
- Backends receive images as `torch.Tensor` (Torchvision) or `np.ndarray` (PIL), always channels-first
|
| 184 |
+
- All images have channel dimension first during processing, regardless of backend
|
| 185 |
+
- Arguments not provided by users default to class attribute values
|
| 186 |
+
- Backend classes encapsulate backend-specific logic (resize, normalize, etc.) and can be overridden
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
valid_kwargs = ImagesKwargs
|
| 190 |
+
|
| 191 |
+
default_to_square = True
|
| 192 |
+
rescale_factor = 1 / 255
|
| 193 |
+
model_input_names = ["pixel_values"]
|
| 194 |
+
|
| 195 |
+
def __init__(self, **kwargs: Unpack[ImagesKwargs]):
|
| 196 |
+
super().__init__(**kwargs)
|
| 197 |
+
# We don't call self._set_attributes in BaseImageProcessor for backward compatibility with remote code
|
| 198 |
+
# We call it instead in the backend subclasses' __init__ methods.
|
| 199 |
+
|
| 200 |
+
def _set_attributes(self, **kwargs):
|
| 201 |
+
"""Resolve and set instance attributes from kwargs and class-level defaults for all valid kwargs."""
|
| 202 |
+
attributes = {}
|
| 203 |
+
for key in self.valid_kwargs.__annotations__:
|
| 204 |
+
kwarg = kwargs.pop(key, None)
|
| 205 |
+
if kwarg is not None:
|
| 206 |
+
attributes[key] = kwarg
|
| 207 |
+
else:
|
| 208 |
+
attributes[key] = deepcopy(getattr(self, key, None))
|
| 209 |
+
attributes = self._standardize_kwargs(**attributes)
|
| 210 |
+
for key, value in attributes.items():
|
| 211 |
+
setattr(self, key, value)
|
| 212 |
+
|
| 213 |
+
self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
|
| 214 |
+
|
| 215 |
+
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
|
| 216 |
+
"""Preprocess an image or a batch of images."""
|
| 217 |
+
return self.preprocess(images, *args, **kwargs)
|
| 218 |
+
|
| 219 |
+
def process_image(self, *args, **kwargs):
|
| 220 |
+
"""
|
| 221 |
+
Process a single raw image into the backend's working format.
|
| 222 |
+
|
| 223 |
+
Implemented by backend subclasses (`TorchvisionBackend`, `PilBackend`). Converts a raw input
|
| 224 |
+
(PIL Image, NumPy array, or torch Tensor) to the backend's internal format (`torch.Tensor` for
|
| 225 |
+
Torchvision, `np.ndarray` for PIL), handles RGB conversion and ensures channels-first ordering.
|
| 226 |
+
"""
|
| 227 |
+
raise NotImplementedError
|
| 228 |
+
|
| 229 |
+
def _preprocess(self, *args, **kwargs):
|
| 230 |
+
"""
|
| 231 |
+
Perform the actual batch image preprocessing (resize, center crop, rescale, normalize, pad).
|
| 232 |
+
|
| 233 |
+
Implemented by backend subclasses (`TorchvisionBackend`, `PilBackend`). Receives a list of
|
| 234 |
+
already-prepared images (in the backend's format, channels-first) and applies the configured
|
| 235 |
+
preprocessing operations. Returns a `BatchFeature` with the processed pixel values.
|
| 236 |
+
|
| 237 |
+
Model-specific processors can override this method to implement custom preprocessing logic
|
| 238 |
+
(e.g., patch-based processing in LLaVA-NeXT).
|
| 239 |
+
"""
|
| 240 |
+
raise NotImplementedError
|
| 241 |
+
|
| 242 |
+
def _prepare_images_structure(
|
| 243 |
+
self,
|
| 244 |
+
images: ImageInput,
|
| 245 |
+
expected_ndims: int = 3,
|
| 246 |
+
) -> ImageInput:
|
| 247 |
+
"""
|
| 248 |
+
Prepare the images structure for processing.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
images (`ImageInput`):
|
| 252 |
+
The input images to process.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
`ImageInput`: The images with a valid nesting.
|
| 256 |
+
"""
|
| 257 |
+
images = self.fetch_images(images)
|
| 258 |
+
return make_flat_list_of_images(images, expected_ndims=expected_ndims)
|
| 259 |
+
|
| 260 |
+
def _prepare_image_like_inputs(
|
| 261 |
+
self,
|
| 262 |
+
images: ImageInput,
|
| 263 |
+
*args,
|
| 264 |
+
expected_ndims: int = 3,
|
| 265 |
+
**kwargs: Unpack[ImagesKwargs],
|
| 266 |
+
) -> list[Any]:
|
| 267 |
+
"""
|
| 268 |
+
Prepare image-like inputs for processing by converting each image via `process_image`.
|
| 269 |
+
|
| 270 |
+
Flattens the input structure and applies `process_image` (implemented by the backend) to each
|
| 271 |
+
individual image, converting raw inputs (PIL, NumPy, Tensor) into the backend's working format
|
| 272 |
+
with channels-first ordering.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
images (`ImageInput`):
|
| 276 |
+
The image-like inputs to process.
|
| 277 |
+
expected_ndims (`int`, *optional*, defaults to 3):
|
| 278 |
+
The expected number of dimensions for the images.
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
`list[torch.Tensor]` or `list[np.ndarray]`: The prepared images in the backend's format,
|
| 282 |
+
with channels-first ordering.
|
| 283 |
+
"""
|
| 284 |
+
images = self._prepare_images_structure(images, expected_ndims=expected_ndims)
|
| 285 |
+
|
| 286 |
+
process_image_partial = partial(self.process_image, *args, **kwargs)
|
| 287 |
+
|
| 288 |
+
has_nested_structure = len(images) > 0 and isinstance(images[0], list | tuple)
|
| 289 |
+
|
| 290 |
+
if has_nested_structure:
|
| 291 |
+
processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
|
| 292 |
+
else:
|
| 293 |
+
processed_images = [process_image_partial(img) for img in images]
|
| 294 |
+
|
| 295 |
+
return processed_images
|
| 296 |
+
|
| 297 |
+
def _preprocess_image_like_inputs(
|
| 298 |
+
self,
|
| 299 |
+
images: ImageInput,
|
| 300 |
+
*args,
|
| 301 |
+
**kwargs: Unpack[ImagesKwargs],
|
| 302 |
+
) -> BatchFeature:
|
| 303 |
+
"""
|
| 304 |
+
Preprocess image-like inputs by preparing them and dispatching to `_preprocess`.
|
| 305 |
+
|
| 306 |
+
This method first calls `_prepare_image_like_inputs` to convert raw inputs into the backend's
|
| 307 |
+
format, then calls `_preprocess` for the actual batch processing. Override this method in
|
| 308 |
+
model-specific processors that need to handle multiple image-like input types (e.g., images
|
| 309 |
+
and segmentation maps) or need custom orchestration of the preprocessing pipeline.
|
| 310 |
+
"""
|
| 311 |
+
images = self._prepare_image_like_inputs(images, **kwargs)
|
| 312 |
+
return self._preprocess(images, *args, **kwargs)
|
| 313 |
+
|
| 314 |
+
def _standardize_kwargs(
|
| 315 |
+
self,
|
| 316 |
+
size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
|
| 317 |
+
crop_size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
|
| 318 |
+
pad_size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
|
| 319 |
+
default_to_square: bool | None = None,
|
| 320 |
+
image_mean: float | list[float] | None = None,
|
| 321 |
+
image_std: float | list[float] | None = None,
|
| 322 |
+
**kwargs,
|
| 323 |
+
) -> dict:
|
| 324 |
+
"""
|
| 325 |
+
Standardize kwargs to canonical format before validation.
|
| 326 |
+
Can be overridden by subclasses to customize the processing of kwargs.
|
| 327 |
+
"""
|
| 328 |
+
if kwargs is None:
|
| 329 |
+
kwargs = {}
|
| 330 |
+
if size is not None and not isinstance(size, SizeDict):
|
| 331 |
+
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
|
| 332 |
+
if crop_size is not None and not isinstance(crop_size, SizeDict):
|
| 333 |
+
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
|
| 334 |
+
if pad_size is not None and not isinstance(pad_size, SizeDict):
|
| 335 |
+
pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
|
| 336 |
+
if isinstance(image_mean, list):
|
| 337 |
+
image_mean = tuple(image_mean)
|
| 338 |
+
if isinstance(image_std, list):
|
| 339 |
+
image_std = tuple(image_std)
|
| 340 |
+
|
| 341 |
+
kwargs["size"] = size
|
| 342 |
+
kwargs["crop_size"] = crop_size
|
| 343 |
+
kwargs["pad_size"] = pad_size
|
| 344 |
+
kwargs["image_mean"] = image_mean
|
| 345 |
+
kwargs["image_std"] = image_std
|
| 346 |
+
|
| 347 |
+
return kwargs
|
| 348 |
+
|
| 349 |
+
# Backwards compatibility for method that was renamed
|
| 350 |
+
_further_process_kwargs = _standardize_kwargs
|
| 351 |
+
|
| 352 |
+
def _validate_preprocess_kwargs(
|
| 353 |
+
self,
|
| 354 |
+
do_rescale: bool | None = None,
|
| 355 |
+
rescale_factor: float | None = None,
|
| 356 |
+
do_normalize: bool | None = None,
|
| 357 |
+
image_mean: float | tuple[float] | None = None,
|
| 358 |
+
image_std: float | tuple[float] | None = None,
|
| 359 |
+
do_resize: bool | None = None,
|
| 360 |
+
size: SizeDict | None = None,
|
| 361 |
+
do_center_crop: bool | None = None,
|
| 362 |
+
crop_size: SizeDict | None = None,
|
| 363 |
+
resample: "PILImageResampling | tvF.InterpolationMode | int | None" = None,
|
| 364 |
+
**kwargs,
|
| 365 |
+
):
|
| 366 |
+
"""
|
| 367 |
+
Validate the kwargs for the preprocess method.
|
| 368 |
+
"""
|
| 369 |
+
validate_preprocess_arguments(
|
| 370 |
+
do_rescale=do_rescale,
|
| 371 |
+
rescale_factor=rescale_factor,
|
| 372 |
+
do_normalize=do_normalize,
|
| 373 |
+
image_mean=image_mean,
|
| 374 |
+
image_std=image_std,
|
| 375 |
+
do_center_crop=do_center_crop,
|
| 376 |
+
crop_size=crop_size,
|
| 377 |
+
do_resize=do_resize,
|
| 378 |
+
size=size,
|
| 379 |
+
resample=resample,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
@auto_docstring
|
| 383 |
+
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
|
| 384 |
+
"""
|
| 385 |
+
Preprocess an image or a batch of images.
|
| 386 |
+
"""
|
| 387 |
+
# Perform type validation on received kwargs
|
| 388 |
+
validate_typed_dict(self.valid_kwargs, kwargs)
|
| 389 |
+
|
| 390 |
+
# Set default kwargs from self
|
| 391 |
+
for kwarg_name in self._valid_kwargs_names:
|
| 392 |
+
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
| 393 |
+
|
| 394 |
+
# Update kwargs that need further processing before being validated
|
| 395 |
+
kwargs = self._standardize_kwargs(**kwargs)
|
| 396 |
+
|
| 397 |
+
# Validate kwargs
|
| 398 |
+
self._validate_preprocess_kwargs(**kwargs)
|
| 399 |
+
|
| 400 |
+
return self._preprocess_image_like_inputs(images, *args, **kwargs)
|
| 401 |
+
|
| 402 |
+
def to_dict(self) -> dict[str, Any]:
|
| 403 |
+
processor_dict = super().to_dict()
|
| 404 |
+
|
| 405 |
+
# Filter out None values that are class defaults
|
| 406 |
+
filtered_dict = {}
|
| 407 |
+
for key, value in processor_dict.items():
|
| 408 |
+
if isinstance(value, SizeDict):
|
| 409 |
+
value = dict(value)
|
| 410 |
+
if value is None:
|
| 411 |
+
class_default = getattr(type(self), key, "NOT_FOUND")
|
| 412 |
+
# Keep None if user explicitly set it (class default is non-None)
|
| 413 |
+
if class_default != "NOT_FOUND" and class_default is not None:
|
| 414 |
+
filtered_dict[key] = value
|
| 415 |
+
else:
|
| 416 |
+
filtered_dict[key] = value
|
| 417 |
+
|
| 418 |
+
filtered_dict.pop("_valid_processor_keys", None)
|
| 419 |
+
filtered_dict.pop("_valid_kwargs_names", None)
|
| 420 |
+
return filtered_dict
|
| 421 |
+
|
| 422 |
+
def rescale(
|
| 423 |
+
self,
|
| 424 |
+
image: np.ndarray,
|
| 425 |
+
scale: float,
|
| 426 |
+
data_format: str | ChannelDimension | None = None,
|
| 427 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 428 |
+
**kwargs,
|
| 429 |
+
) -> np.ndarray:
|
| 430 |
+
"""
|
| 431 |
+
Rescale an image by a scale factor. image = image * scale.
|
| 432 |
+
|
| 433 |
+
Args:
|
| 434 |
+
image (`np.ndarray`):
|
| 435 |
+
Image to rescale.
|
| 436 |
+
scale (`float`):
|
| 437 |
+
The scaling factor to rescale pixel values by.
|
| 438 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 439 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 440 |
+
image is used. Can be one of:
|
| 441 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 442 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 443 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 444 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 445 |
+
from the input image. Can be one of:
|
| 446 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 447 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
`np.ndarray`: The rescaled image.
|
| 451 |
+
"""
|
| 452 |
+
return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
|
| 453 |
+
|
| 454 |
+
# The next methods are kept for backwards compatibility with remote code, but are overriden by backends.
|
| 455 |
+
def normalize(
|
| 456 |
+
self,
|
| 457 |
+
image: np.ndarray,
|
| 458 |
+
mean: float | Iterable[float],
|
| 459 |
+
std: float | Iterable[float],
|
| 460 |
+
data_format: str | ChannelDimension | None = None,
|
| 461 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 462 |
+
**kwargs,
|
| 463 |
+
) -> np.ndarray:
|
| 464 |
+
"""
|
| 465 |
+
Normalize an image. image = (image - image_mean) / image_std.
|
| 466 |
+
|
| 467 |
+
Args:
|
| 468 |
+
image (`np.ndarray`):
|
| 469 |
+
Image to normalize.
|
| 470 |
+
mean (`float` or `Iterable[float]`):
|
| 471 |
+
Image mean to use for normalization.
|
| 472 |
+
std (`float` or `Iterable[float]`):
|
| 473 |
+
Image standard deviation to use for normalization.
|
| 474 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 475 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 476 |
+
image is used. Can be one of:
|
| 477 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 478 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 479 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 480 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 481 |
+
from the input image. Can be one of:
|
| 482 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 483 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 484 |
+
|
| 485 |
+
Returns:
|
| 486 |
+
`np.ndarray`: The normalized image.
|
| 487 |
+
"""
|
| 488 |
+
return normalize(
|
| 489 |
+
image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
def center_crop(
|
| 493 |
+
self,
|
| 494 |
+
image: np.ndarray,
|
| 495 |
+
size: dict[str, int],
|
| 496 |
+
data_format: str | ChannelDimension | None = None,
|
| 497 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 498 |
+
**kwargs,
|
| 499 |
+
) -> np.ndarray:
|
| 500 |
+
"""
|
| 501 |
+
Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
|
| 502 |
+
any edge, the image is padded with 0's and then center cropped.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
image (`np.ndarray`):
|
| 506 |
+
Image to center crop.
|
| 507 |
+
size (`dict[str, int]`):
|
| 508 |
+
Size of the output image.
|
| 509 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 510 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 511 |
+
image is used. Can be one of:
|
| 512 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 513 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 514 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 515 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 516 |
+
from the input image. Can be one of:
|
| 517 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 518 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 519 |
+
"""
|
| 520 |
+
size = get_size_dict(size)
|
| 521 |
+
if "height" not in size or "width" not in size:
|
| 522 |
+
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
|
| 523 |
+
return center_crop(
|
| 524 |
+
image,
|
| 525 |
+
size=(size["height"], size["width"]),
|
| 526 |
+
data_format=data_format,
|
| 527 |
+
input_data_format=input_data_format,
|
| 528 |
+
**kwargs,
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
VALID_SIZE_DICT_KEYS = (
|
| 533 |
+
{"height", "width"},
|
| 534 |
+
{"shortest_edge"},
|
| 535 |
+
{"shortest_edge", "longest_edge"},
|
| 536 |
+
{"longest_edge"},
|
| 537 |
+
{"max_height", "max_width"},
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def is_valid_size_dict(size_dict):
|
| 542 |
+
if not isinstance(size_dict, dict):
|
| 543 |
+
return False
|
| 544 |
+
|
| 545 |
+
size_dict_keys = set(size_dict.keys())
|
| 546 |
+
for allowed_keys in VALID_SIZE_DICT_KEYS:
|
| 547 |
+
if size_dict_keys == allowed_keys:
|
| 548 |
+
return True
|
| 549 |
+
return False
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def convert_to_size_dict(
|
| 553 |
+
size: int | Iterable[int] | None = None,
|
| 554 |
+
max_size: int | None = None,
|
| 555 |
+
default_to_square: bool = True,
|
| 556 |
+
height_width_order: bool = True,
|
| 557 |
+
) -> dict[str, int]:
|
| 558 |
+
# By default, if size is an int we assume it represents a tuple of (size, size).
|
| 559 |
+
if isinstance(size, int) and default_to_square:
|
| 560 |
+
if max_size is not None:
|
| 561 |
+
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
|
| 562 |
+
return {"height": size, "width": size}
|
| 563 |
+
# In other configs, if size is an int and default_to_square is False, size represents the length of
|
| 564 |
+
# the shortest edge after resizing.
|
| 565 |
+
elif isinstance(size, int) and not default_to_square:
|
| 566 |
+
size_dict = {"shortest_edge": size}
|
| 567 |
+
if max_size is not None:
|
| 568 |
+
size_dict["longest_edge"] = max_size
|
| 569 |
+
return size_dict
|
| 570 |
+
# Otherwise, if size is a tuple it's either (height, width) or (width, height)
|
| 571 |
+
elif isinstance(size, (tuple, list)) and height_width_order:
|
| 572 |
+
return {"height": size[0], "width": size[1]}
|
| 573 |
+
elif isinstance(size, (tuple, list)) and not height_width_order:
|
| 574 |
+
return {"height": size[1], "width": size[0]}
|
| 575 |
+
elif size is None and max_size is not None:
|
| 576 |
+
if default_to_square:
|
| 577 |
+
raise ValueError("Cannot specify both default_to_square=True and max_size")
|
| 578 |
+
return {"longest_edge": max_size}
|
| 579 |
+
|
| 580 |
+
raise ValueError(f"Could not convert size input to size dict: {size}")
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def get_size_dict(
|
| 584 |
+
size: int | Iterable[int] | dict[str, int] | SizeDict | None = None,
|
| 585 |
+
max_size: int | None = None,
|
| 586 |
+
height_width_order: bool = True,
|
| 587 |
+
default_to_square: bool = True,
|
| 588 |
+
param_name="size",
|
| 589 |
+
) -> dict:
|
| 590 |
+
"""
|
| 591 |
+
Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
|
| 592 |
+
compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
|
| 593 |
+
width) or (width, height) format.
|
| 594 |
+
|
| 595 |
+
- If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
|
| 596 |
+
size[0]}` if `height_width_order` is `False`.
|
| 597 |
+
- If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
|
| 598 |
+
- If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
|
| 599 |
+
is set, it is added to the dict as `{"longest_edge": max_size}`.
|
| 600 |
+
- If `size` is `None` and `default_to_square` is False, the result is `{"longest_edge": max_size}` (requires
|
| 601 |
+
`max_size` to be set). Tuple/list/SizeDict/dict `size` values do not use `max_size`.
|
| 602 |
+
|
| 603 |
+
Args:
|
| 604 |
+
size (`int | Iterable[int] | dict[str, int] | SizeDict`, *optional*):
|
| 605 |
+
The `size` parameter to be cast into a size dictionary.
|
| 606 |
+
max_size (`int | None`, *optional*):
|
| 607 |
+
With `default_to_square=False`, sets `longest_edge` when `size` is an int or `None`; unused for dict,
|
| 608 |
+
`SizeDict`, or tuple/list `size`. Raises if set with `default_to_square=True` when `size` is an int or `None`.
|
| 609 |
+
height_width_order (`bool`, *optional*, defaults to `True`):
|
| 610 |
+
If `size` is a tuple, whether it's in (height, width) or (width, height) order.
|
| 611 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 612 |
+
If `size` is an int, whether to default to a square image or not.
|
| 613 |
+
"""
|
| 614 |
+
if not isinstance(size, dict | SizeDict):
|
| 615 |
+
size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
|
| 616 |
+
logger.info(
|
| 617 |
+
f"{param_name} should be a dictionary with one of the following sets of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
|
| 618 |
+
f" Converted to {size_dict}.",
|
| 619 |
+
)
|
| 620 |
+
# Some remote code bypasses or overrides `_standardize_kwargs`, so handle `SizeDict` `size` here too.
|
| 621 |
+
elif isinstance(size, SizeDict):
|
| 622 |
+
size_dict = dict(size)
|
| 623 |
+
else:
|
| 624 |
+
size_dict = size
|
| 625 |
+
|
| 626 |
+
if not is_valid_size_dict(size_dict):
|
| 627 |
+
raise ValueError(
|
| 628 |
+
f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
|
| 629 |
+
)
|
| 630 |
+
return size_dict
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
|
| 634 |
+
"""
|
| 635 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
| 636 |
+
|
| 637 |
+
This is done by calculating the effective and wasted resolution for each possible resolution.
|
| 638 |
+
|
| 639 |
+
The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
original_size (tuple):
|
| 643 |
+
The original size of the image in the format (height, width).
|
| 644 |
+
possible_resolutions (list):
|
| 645 |
+
A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
|
| 646 |
+
|
| 647 |
+
Returns:
|
| 648 |
+
tuple: The best fit resolution in the format (height, width).
|
| 649 |
+
"""
|
| 650 |
+
original_height, original_width = original_size
|
| 651 |
+
best_fit = None
|
| 652 |
+
max_effective_resolution = 0
|
| 653 |
+
min_wasted_resolution = float("inf")
|
| 654 |
+
|
| 655 |
+
for height, width in possible_resolutions:
|
| 656 |
+
scale = min(width / original_width, height / original_height)
|
| 657 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
| 658 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
| 659 |
+
wasted_resolution = (width * height) - effective_resolution
|
| 660 |
+
|
| 661 |
+
if effective_resolution > max_effective_resolution or (
|
| 662 |
+
effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
|
| 663 |
+
):
|
| 664 |
+
max_effective_resolution = effective_resolution
|
| 665 |
+
min_wasted_resolution = wasted_resolution
|
| 666 |
+
best_fit = (height, width)
|
| 667 |
+
|
| 668 |
+
return best_fit
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def get_patch_output_size(image, target_resolution, input_data_format):
|
| 672 |
+
"""
|
| 673 |
+
Given an image and a target resolution, calculate the output size of the image after cropping to the target
|
| 674 |
+
"""
|
| 675 |
+
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
|
| 676 |
+
target_height, target_width = target_resolution
|
| 677 |
+
|
| 678 |
+
scale_w = target_width / original_width
|
| 679 |
+
scale_h = target_height / original_height
|
| 680 |
+
|
| 681 |
+
if scale_w < scale_h:
|
| 682 |
+
new_width = target_width
|
| 683 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
| 684 |
+
else:
|
| 685 |
+
new_height = target_height
|
| 686 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
| 687 |
+
|
| 688 |
+
return new_height, new_width
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/image_utils.py
ADDED
|
@@ -0,0 +1,1069 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import base64
|
| 16 |
+
import os
|
| 17 |
+
from collections.abc import Iterable
|
| 18 |
+
from dataclasses import dataclass, fields
|
| 19 |
+
from io import BytesIO
|
| 20 |
+
from typing import Any, Union
|
| 21 |
+
|
| 22 |
+
import httpx
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
from .utils import (
|
| 26 |
+
ExplicitEnum,
|
| 27 |
+
is_numpy_array,
|
| 28 |
+
is_torch_available,
|
| 29 |
+
is_torch_tensor,
|
| 30 |
+
is_torchvision_available,
|
| 31 |
+
is_vision_available,
|
| 32 |
+
logging,
|
| 33 |
+
requires_backends,
|
| 34 |
+
to_numpy,
|
| 35 |
+
)
|
| 36 |
+
from .utils.constants import ( # noqa: F401
|
| 37 |
+
IMAGENET_DEFAULT_MEAN,
|
| 38 |
+
IMAGENET_DEFAULT_STD,
|
| 39 |
+
IMAGENET_STANDARD_MEAN,
|
| 40 |
+
IMAGENET_STANDARD_STD,
|
| 41 |
+
OPENAI_CLIP_MEAN,
|
| 42 |
+
OPENAI_CLIP_STD,
|
| 43 |
+
)
|
| 44 |
+
from .utils.import_utils import requires
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_vision_available():
|
| 48 |
+
import PIL.Image
|
| 49 |
+
import PIL.ImageOps
|
| 50 |
+
|
| 51 |
+
PILImageResampling = PIL.Image.Resampling
|
| 52 |
+
|
| 53 |
+
if is_torchvision_available():
|
| 54 |
+
from torchvision.io import ImageReadMode, decode_image
|
| 55 |
+
from torchvision.transforms import InterpolationMode
|
| 56 |
+
from torchvision.transforms.functional import pil_to_tensor
|
| 57 |
+
|
| 58 |
+
pil_torch_interpolation_mapping = {
|
| 59 |
+
PILImageResampling.NEAREST: InterpolationMode.NEAREST_EXACT,
|
| 60 |
+
PILImageResampling.BOX: InterpolationMode.BOX,
|
| 61 |
+
PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
|
| 62 |
+
PILImageResampling.HAMMING: InterpolationMode.HAMMING,
|
| 63 |
+
PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
|
| 64 |
+
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
|
| 65 |
+
}
|
| 66 |
+
# Create inverse mapping: InterpolationMode -> PILImageResampling
|
| 67 |
+
torch_pil_interpolation_mapping = {v: k for k, v in pil_torch_interpolation_mapping.items()}
|
| 68 |
+
else:
|
| 69 |
+
pil_torch_interpolation_mapping = {}
|
| 70 |
+
torch_pil_interpolation_mapping = {}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if is_torch_available():
|
| 74 |
+
import torch
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
logger = logging.get_logger(__name__)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
ImageInput = Union[
|
| 81 |
+
"PIL.Image.Image", np.ndarray, "torch.Tensor", list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ChannelDimension(ExplicitEnum):
|
| 86 |
+
FIRST = "channels_first"
|
| 87 |
+
LAST = "channels_last"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class AnnotationFormat(ExplicitEnum):
|
| 91 |
+
COCO_DETECTION = "coco_detection"
|
| 92 |
+
COCO_PANOPTIC = "coco_panoptic"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
AnnotationType = dict[str, int | str | list[dict]]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def is_pil_image(img):
|
| 99 |
+
return is_vision_available() and isinstance(img, PIL.Image.Image)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ImageType(ExplicitEnum):
|
| 103 |
+
PIL = "pillow"
|
| 104 |
+
TORCH = "torch"
|
| 105 |
+
NUMPY = "numpy"
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_image_type(image):
|
| 109 |
+
if is_pil_image(image):
|
| 110 |
+
return ImageType.PIL
|
| 111 |
+
if is_torch_tensor(image):
|
| 112 |
+
return ImageType.TORCH
|
| 113 |
+
if is_numpy_array(image):
|
| 114 |
+
return ImageType.NUMPY
|
| 115 |
+
raise ValueError(f"Unrecognized image type {type(image)}")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def is_valid_image(img):
|
| 119 |
+
return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def is_valid_list_of_images(images: list):
|
| 123 |
+
return images and all(is_valid_image(image) for image in images)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def concatenate_list(input_list):
|
| 127 |
+
if isinstance(input_list[0], list):
|
| 128 |
+
return [item for sublist in input_list for item in sublist]
|
| 129 |
+
elif isinstance(input_list[0], np.ndarray):
|
| 130 |
+
return np.concatenate(input_list, axis=0)
|
| 131 |
+
elif isinstance(input_list[0], torch.Tensor):
|
| 132 |
+
return torch.cat(input_list, dim=0)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def valid_images(imgs):
|
| 136 |
+
# If we have an list of images, make sure every image is valid
|
| 137 |
+
if isinstance(imgs, (list, tuple)):
|
| 138 |
+
for img in imgs:
|
| 139 |
+
if not valid_images(img):
|
| 140 |
+
return False
|
| 141 |
+
# If not a list of tuple, we have been given a single image or batched tensor of images
|
| 142 |
+
elif not is_valid_image(imgs):
|
| 143 |
+
return False
|
| 144 |
+
return True
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def is_batched(img):
|
| 148 |
+
if isinstance(img, (list, tuple)):
|
| 149 |
+
return is_valid_image(img[0])
|
| 150 |
+
return False
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def is_scaled_image(image: np.ndarray) -> bool:
|
| 154 |
+
"""
|
| 155 |
+
Checks to see whether the pixel values have already been rescaled to [0, 1].
|
| 156 |
+
"""
|
| 157 |
+
if image.dtype == np.uint8:
|
| 158 |
+
return False
|
| 159 |
+
|
| 160 |
+
# It's possible the image has pixel values in [0, 255] but is of floating type
|
| 161 |
+
return np.min(image) >= 0 and np.max(image) <= 1
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:
|
| 165 |
+
"""
|
| 166 |
+
Ensure that the output is a list of images. If the input is a single image, it is converted to a list of length 1.
|
| 167 |
+
If the input is a batch of images, it is converted to a list of images.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
images (`ImageInput`):
|
| 171 |
+
Image of images to turn into a list of images.
|
| 172 |
+
expected_ndims (`int`, *optional*, defaults to 3):
|
| 173 |
+
Expected number of dimensions for a single input image. If the input image has a different number of
|
| 174 |
+
dimensions, an error is raised.
|
| 175 |
+
"""
|
| 176 |
+
if is_batched(images):
|
| 177 |
+
return images
|
| 178 |
+
|
| 179 |
+
# Either the input is a single image, in which case we create a list of length 1
|
| 180 |
+
if is_pil_image(images):
|
| 181 |
+
# PIL images are never batched
|
| 182 |
+
return [images]
|
| 183 |
+
|
| 184 |
+
if is_valid_image(images):
|
| 185 |
+
if images.ndim == expected_ndims + 1:
|
| 186 |
+
# Batch of images
|
| 187 |
+
images = list(images)
|
| 188 |
+
elif images.ndim == expected_ndims:
|
| 189 |
+
# Single image
|
| 190 |
+
images = [images]
|
| 191 |
+
else:
|
| 192 |
+
raise ValueError(
|
| 193 |
+
f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
|
| 194 |
+
f" {images.ndim} dimensions."
|
| 195 |
+
)
|
| 196 |
+
return images
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, or torch.Tensor, but got {type(images)}."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def make_flat_list_of_images(
|
| 203 |
+
images: list[ImageInput] | ImageInput,
|
| 204 |
+
expected_ndims: int = 3,
|
| 205 |
+
) -> ImageInput:
|
| 206 |
+
"""
|
| 207 |
+
Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
|
| 208 |
+
If the input is a nested list of images, it is converted to a flat list of images.
|
| 209 |
+
Args:
|
| 210 |
+
images (`Union[list[ImageInput], ImageInput]`):
|
| 211 |
+
The input image.
|
| 212 |
+
expected_ndims (`int`, *optional*, defaults to 3):
|
| 213 |
+
The expected number of dimensions for a single input image.
|
| 214 |
+
Returns:
|
| 215 |
+
list: A list of images or a 4d array of images.
|
| 216 |
+
"""
|
| 217 |
+
# If the input is a nested list of images, we flatten it
|
| 218 |
+
if (
|
| 219 |
+
isinstance(images, (list, tuple))
|
| 220 |
+
and all(isinstance(images_i, (list, tuple)) for images_i in images)
|
| 221 |
+
and all(is_valid_list_of_images(images_i) or not images_i for images_i in images)
|
| 222 |
+
):
|
| 223 |
+
return [img for img_list in images for img in img_list]
|
| 224 |
+
|
| 225 |
+
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
|
| 226 |
+
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
|
| 227 |
+
return images
|
| 228 |
+
if images[0].ndim == expected_ndims + 1:
|
| 229 |
+
return [img for img_list in images for img in img_list]
|
| 230 |
+
|
| 231 |
+
if is_valid_image(images):
|
| 232 |
+
if is_pil_image(images) or images.ndim == expected_ndims:
|
| 233 |
+
return [images]
|
| 234 |
+
if images.ndim == expected_ndims + 1:
|
| 235 |
+
return list(images)
|
| 236 |
+
|
| 237 |
+
raise ValueError(f"Could not make a flat list of images from {images}")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def make_nested_list_of_images(
|
| 241 |
+
images: list[ImageInput] | ImageInput,
|
| 242 |
+
expected_ndims: int = 3,
|
| 243 |
+
) -> list[ImageInput]:
|
| 244 |
+
"""
|
| 245 |
+
Ensure that the output is a nested list of images.
|
| 246 |
+
Args:
|
| 247 |
+
images (`Union[list[ImageInput], ImageInput]`):
|
| 248 |
+
The input image.
|
| 249 |
+
expected_ndims (`int`, *optional*, defaults to 3):
|
| 250 |
+
The expected number of dimensions for a single input image.
|
| 251 |
+
Returns:
|
| 252 |
+
list: A list of list of images or a list of 4d array of images.
|
| 253 |
+
"""
|
| 254 |
+
# If it's a list of batches, it's already in the right format
|
| 255 |
+
if (
|
| 256 |
+
isinstance(images, (list, tuple))
|
| 257 |
+
and all(isinstance(images_i, (list, tuple)) for images_i in images)
|
| 258 |
+
and all(is_valid_list_of_images(images_i) or not images_i for images_i in images)
|
| 259 |
+
):
|
| 260 |
+
return images
|
| 261 |
+
|
| 262 |
+
# If it's a list of images, it's a single batch, so convert it to a list of lists
|
| 263 |
+
if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
|
| 264 |
+
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
|
| 265 |
+
return [images]
|
| 266 |
+
if images[0].ndim == expected_ndims + 1:
|
| 267 |
+
return [list(image) for image in images]
|
| 268 |
+
|
| 269 |
+
# If it's a single image, convert it to a list of lists
|
| 270 |
+
if is_valid_image(images):
|
| 271 |
+
if is_pil_image(images) or images.ndim == expected_ndims:
|
| 272 |
+
return [[images]]
|
| 273 |
+
if images.ndim == expected_ndims + 1:
|
| 274 |
+
return [list(images)]
|
| 275 |
+
|
| 276 |
+
raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def to_numpy_array(img) -> np.ndarray:
|
| 280 |
+
if not is_valid_image(img):
|
| 281 |
+
raise ValueError(f"Invalid image type: {type(img)}")
|
| 282 |
+
|
| 283 |
+
if is_vision_available() and isinstance(img, PIL.Image.Image):
|
| 284 |
+
return np.array(img)
|
| 285 |
+
return to_numpy(img)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def infer_channel_dimension_format(
|
| 289 |
+
image: np.ndarray, num_channels: int | tuple[int, ...] | None = None
|
| 290 |
+
) -> ChannelDimension:
|
| 291 |
+
"""
|
| 292 |
+
Infers the channel dimension format of `image`.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
image (`np.ndarray`):
|
| 296 |
+
The image to infer the channel dimension of.
|
| 297 |
+
num_channels (`int` or `tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
|
| 298 |
+
The number of channels of the image.
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
The channel dimension of the image.
|
| 302 |
+
"""
|
| 303 |
+
num_channels = num_channels if num_channels is not None else (1, 3)
|
| 304 |
+
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
|
| 305 |
+
|
| 306 |
+
if image.ndim == 3:
|
| 307 |
+
first_dim, last_dim = 0, 2
|
| 308 |
+
elif image.ndim == 4:
|
| 309 |
+
first_dim, last_dim = 1, 3
|
| 310 |
+
elif image.ndim == 5:
|
| 311 |
+
first_dim, last_dim = 2, 4
|
| 312 |
+
else:
|
| 313 |
+
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
| 314 |
+
|
| 315 |
+
if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
|
| 316 |
+
logger.warning(
|
| 317 |
+
f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
|
| 318 |
+
)
|
| 319 |
+
return ChannelDimension.FIRST
|
| 320 |
+
elif image.shape[first_dim] in num_channels:
|
| 321 |
+
return ChannelDimension.FIRST
|
| 322 |
+
elif image.shape[last_dim] in num_channels:
|
| 323 |
+
return ChannelDimension.LAST
|
| 324 |
+
raise ValueError("Unable to infer channel dimension format")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def get_channel_dimension_axis(image: np.ndarray, input_data_format: ChannelDimension | str | None = None) -> int:
|
| 328 |
+
"""
|
| 329 |
+
Returns the channel dimension axis of the image.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
image (`np.ndarray`):
|
| 333 |
+
The image to get the channel dimension axis of.
|
| 334 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 335 |
+
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
The channel dimension axis of the image.
|
| 339 |
+
"""
|
| 340 |
+
if input_data_format is None:
|
| 341 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 342 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 343 |
+
return image.ndim - 3
|
| 344 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 345 |
+
return image.ndim - 1
|
| 346 |
+
raise ValueError(f"Unsupported data format: {input_data_format}")
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension | None = None) -> tuple[int, int]:
|
| 350 |
+
"""
|
| 351 |
+
Returns the (height, width) dimensions of the image.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
image (`np.ndarray`):
|
| 355 |
+
The image to get the dimensions of.
|
| 356 |
+
channel_dim (`ChannelDimension`, *optional*):
|
| 357 |
+
Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
|
| 358 |
+
|
| 359 |
+
Returns:
|
| 360 |
+
A tuple of the image's height and width.
|
| 361 |
+
"""
|
| 362 |
+
if channel_dim is None:
|
| 363 |
+
channel_dim = infer_channel_dimension_format(image)
|
| 364 |
+
|
| 365 |
+
if channel_dim == ChannelDimension.FIRST:
|
| 366 |
+
return image.shape[-2], image.shape[-1]
|
| 367 |
+
elif channel_dim == ChannelDimension.LAST:
|
| 368 |
+
return image.shape[-3], image.shape[-2]
|
| 369 |
+
else:
|
| 370 |
+
raise ValueError(f"Unsupported data format: {channel_dim}")
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def get_image_size_for_max_height_width(
|
| 374 |
+
image_size: tuple[int, int],
|
| 375 |
+
max_height: int,
|
| 376 |
+
max_width: int,
|
| 377 |
+
) -> tuple[int, int]:
|
| 378 |
+
"""
|
| 379 |
+
Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
|
| 380 |
+
Important, even if image_height < max_height and image_width < max_width, the image will be resized
|
| 381 |
+
to at least one of the edges be equal to max_height or max_width.
|
| 382 |
+
|
| 383 |
+
For example:
|
| 384 |
+
- input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
|
| 385 |
+
- input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
image_size (`tuple[int, int]`):
|
| 389 |
+
The image to resize.
|
| 390 |
+
max_height (`int`):
|
| 391 |
+
The maximum allowed height.
|
| 392 |
+
max_width (`int`):
|
| 393 |
+
The maximum allowed width.
|
| 394 |
+
"""
|
| 395 |
+
height, width = image_size
|
| 396 |
+
height_scale = max_height / height
|
| 397 |
+
width_scale = max_width / width
|
| 398 |
+
min_scale = min(height_scale, width_scale)
|
| 399 |
+
new_height = int(height * min_scale)
|
| 400 |
+
new_width = int(width * min_scale)
|
| 401 |
+
return new_height, new_width
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def max_across_indices(values: Iterable[Any]) -> list[Any]:
|
| 405 |
+
"""
|
| 406 |
+
Return the maximum value across all indices of an iterable of values.
|
| 407 |
+
"""
|
| 408 |
+
return [max(values_i) for values_i in zip(*values)]
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def get_max_height_width(
|
| 412 |
+
images: list[Union["torch.Tensor", np.ndarray]], input_data_format: str | ChannelDimension = ChannelDimension.FIRST
|
| 413 |
+
) -> list[int]:
|
| 414 |
+
"""
|
| 415 |
+
Get the maximum height and width across all images in a batch.
|
| 416 |
+
"""
|
| 417 |
+
if input_data_format == ChannelDimension.FIRST:
|
| 418 |
+
_, max_height, max_width = max_across_indices([img.shape for img in images])
|
| 419 |
+
elif input_data_format == ChannelDimension.LAST:
|
| 420 |
+
max_height, max_width, _ = max_across_indices([img.shape for img in images])
|
| 421 |
+
else:
|
| 422 |
+
raise ValueError(f"Invalid channel dimension format: {input_data_format}")
|
| 423 |
+
return (max_height, max_width)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def is_valid_annotation_coco_detection(annotation: dict[str, list | tuple]) -> bool:
|
| 427 |
+
if (
|
| 428 |
+
isinstance(annotation, dict)
|
| 429 |
+
and "image_id" in annotation
|
| 430 |
+
and "annotations" in annotation
|
| 431 |
+
and isinstance(annotation["annotations"], (list, tuple))
|
| 432 |
+
and (
|
| 433 |
+
# an image can have no annotations
|
| 434 |
+
len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
|
| 435 |
+
)
|
| 436 |
+
):
|
| 437 |
+
return True
|
| 438 |
+
return False
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def is_valid_annotation_coco_panoptic(annotation: dict[str, list | tuple]) -> bool:
|
| 442 |
+
if (
|
| 443 |
+
isinstance(annotation, dict)
|
| 444 |
+
and "image_id" in annotation
|
| 445 |
+
and "segments_info" in annotation
|
| 446 |
+
and "file_name" in annotation
|
| 447 |
+
and isinstance(annotation["segments_info"], (list, tuple))
|
| 448 |
+
and (
|
| 449 |
+
# an image can have no segments
|
| 450 |
+
len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
|
| 451 |
+
)
|
| 452 |
+
):
|
| 453 |
+
return True
|
| 454 |
+
return False
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def valid_coco_detection_annotations(annotations: Iterable[dict[str, list | tuple]]) -> bool:
|
| 458 |
+
return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def valid_coco_panoptic_annotations(annotations: Iterable[dict[str, list | tuple]]) -> bool:
|
| 462 |
+
return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def load_image(
|
| 466 |
+
image: Union[str, "PIL.Image.Image"],
|
| 467 |
+
timeout: float | None = None,
|
| 468 |
+
) -> "PIL.Image.Image":
|
| 469 |
+
"""
|
| 470 |
+
Loads `image` to a PIL Image.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
image (`str` or `PIL.Image.Image`):
|
| 474 |
+
The image to convert to the PIL Image format.
|
| 475 |
+
timeout (`float`, *optional*):
|
| 476 |
+
The timeout value in seconds for the URL request.
|
| 477 |
+
|
| 478 |
+
Returns:
|
| 479 |
+
`PIL.Image.Image`: A PIL Image.
|
| 480 |
+
"""
|
| 481 |
+
requires_backends(load_image, ["vision"])
|
| 482 |
+
if isinstance(image, str):
|
| 483 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 484 |
+
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
| 485 |
+
# like http_huggingface_co.png
|
| 486 |
+
image = PIL.Image.open(BytesIO(httpx.get(image, timeout=timeout, follow_redirects=True).content))
|
| 487 |
+
elif os.path.isfile(image):
|
| 488 |
+
image = PIL.Image.open(image)
|
| 489 |
+
else:
|
| 490 |
+
if image.startswith("data:image/"):
|
| 491 |
+
image = image.split(",")[1]
|
| 492 |
+
|
| 493 |
+
# Try to load as base64
|
| 494 |
+
try:
|
| 495 |
+
b64 = base64.decodebytes(image.encode())
|
| 496 |
+
image = PIL.Image.open(BytesIO(b64))
|
| 497 |
+
except Exception as e:
|
| 498 |
+
raise ValueError(
|
| 499 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 500 |
+
)
|
| 501 |
+
elif not isinstance(image, PIL.Image.Image):
|
| 502 |
+
raise TypeError(
|
| 503 |
+
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
|
| 504 |
+
)
|
| 505 |
+
image = PIL.ImageOps.exif_transpose(image)
|
| 506 |
+
image = image.convert("RGB")
|
| 507 |
+
return image
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@requires(backends=("torchvision",))
|
| 511 |
+
def load_image_as_tensor(
|
| 512 |
+
image: Union[str, "PIL.Image.Image"],
|
| 513 |
+
timeout: float | None = None,
|
| 514 |
+
) -> "torch.Tensor":
|
| 515 |
+
"""
|
| 516 |
+
Loads `image` directly to a `torch.Tensor` using torchvision.
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
image (`str` or `PIL.Image.Image`):
|
| 520 |
+
The image to convert to the PIL Image format.
|
| 521 |
+
timeout (`float`, *optional*):
|
| 522 |
+
The timeout value in seconds for the URL request.
|
| 523 |
+
|
| 524 |
+
Returns:
|
| 525 |
+
`torch.Tensor`: A `[C, H, W]` uint8 tensor in RGB channel order.
|
| 526 |
+
"""
|
| 527 |
+
import torch
|
| 528 |
+
|
| 529 |
+
if isinstance(image, str):
|
| 530 |
+
if image.startswith("http://") or image.startswith("https://"):
|
| 531 |
+
raw = httpx.get(image, timeout=timeout, follow_redirects=True).content
|
| 532 |
+
buf = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
| 533 |
+
return decode_image(buf, mode=ImageReadMode.RGB)
|
| 534 |
+
elif os.path.isfile(image):
|
| 535 |
+
return decode_image(image, mode=ImageReadMode.RGB)
|
| 536 |
+
else:
|
| 537 |
+
if image.startswith("data:image/"):
|
| 538 |
+
image = image.split(",")[1]
|
| 539 |
+
try:
|
| 540 |
+
raw = base64.decodebytes(image.encode())
|
| 541 |
+
except Exception as e:
|
| 542 |
+
raise ValueError(
|
| 543 |
+
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
|
| 544 |
+
)
|
| 545 |
+
buf = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
| 546 |
+
return decode_image(buf, mode=ImageReadMode.RGB)
|
| 547 |
+
elif isinstance(image, PIL.Image.Image):
|
| 548 |
+
image = PIL.ImageOps.exif_transpose(image)
|
| 549 |
+
return pil_to_tensor(image.convert("RGB"))
|
| 550 |
+
else:
|
| 551 |
+
raise TypeError(
|
| 552 |
+
"Incorrect format used for image. Should be a URL, a local path, a base64 string, or a PIL image."
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def load_images(
|
| 557 |
+
images: Union[list, tuple, str, "PIL.Image.Image"], timeout: float | None = None
|
| 558 |
+
) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
|
| 559 |
+
"""Loads images, handling different levels of nesting.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
images: A single image, a list of images, or a list of lists of images to load.
|
| 563 |
+
timeout: Timeout for loading images.
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
A single image, a list of images, a list of lists of images.
|
| 567 |
+
"""
|
| 568 |
+
if isinstance(images, (list, tuple)):
|
| 569 |
+
if len(images) and isinstance(images[0], (list, tuple)):
|
| 570 |
+
return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
|
| 571 |
+
else:
|
| 572 |
+
return [load_image(image, timeout=timeout) for image in images]
|
| 573 |
+
else:
|
| 574 |
+
return load_image(images, timeout=timeout)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def validate_preprocess_arguments(
|
| 578 |
+
do_rescale: bool | None = None,
|
| 579 |
+
rescale_factor: float | None = None,
|
| 580 |
+
do_normalize: bool | None = None,
|
| 581 |
+
image_mean: float | list[float] | None = None,
|
| 582 |
+
image_std: float | list[float] | None = None,
|
| 583 |
+
do_pad: bool | None = None,
|
| 584 |
+
pad_size: dict[str, int] | int | None = None,
|
| 585 |
+
do_center_crop: bool | None = None,
|
| 586 |
+
crop_size: dict[str, int] | None = None,
|
| 587 |
+
do_resize: bool | None = None,
|
| 588 |
+
size: dict[str, int] | None = None,
|
| 589 |
+
resample: Union["PILImageResampling", "InterpolationMode", int] | None = None,
|
| 590 |
+
):
|
| 591 |
+
"""
|
| 592 |
+
Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
|
| 593 |
+
Raises `ValueError` if arguments incompatibility is caught.
|
| 594 |
+
Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
|
| 595 |
+
sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
|
| 596 |
+
existing arguments when possible.
|
| 597 |
+
|
| 598 |
+
"""
|
| 599 |
+
if do_rescale and rescale_factor is None:
|
| 600 |
+
raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
|
| 601 |
+
|
| 602 |
+
if do_pad and pad_size is None:
|
| 603 |
+
# Processors pad images using different args depending on the model, so the below check is pointless
|
| 604 |
+
# but we keep it for BC for now. TODO: remove in v5
|
| 605 |
+
# Usually padding can be called with:
|
| 606 |
+
# - "pad_size/size" if we're padding to specific values
|
| 607 |
+
# - "size_divisor" if we're padding to any value divisible by X
|
| 608 |
+
# - "None" if we're padding to the maximum size image in batch
|
| 609 |
+
raise ValueError(
|
| 610 |
+
"Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
if do_normalize and (image_mean is None or image_std is None):
|
| 614 |
+
raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
|
| 615 |
+
|
| 616 |
+
if do_center_crop and crop_size is None:
|
| 617 |
+
raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
|
| 618 |
+
|
| 619 |
+
if do_resize and not (size is not None and resample is not None):
|
| 620 |
+
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
class ImageFeatureExtractionMixin:
|
| 624 |
+
"""
|
| 625 |
+
Mixin that contain utilities for preparing image features.
|
| 626 |
+
"""
|
| 627 |
+
|
| 628 |
+
def _ensure_format_supported(self, image):
|
| 629 |
+
if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
|
| 630 |
+
raise ValueError(
|
| 631 |
+
f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and "
|
| 632 |
+
"`torch.Tensor` are."
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
def to_pil_image(self, image, rescale=None):
|
| 636 |
+
"""
|
| 637 |
+
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
| 638 |
+
needed.
|
| 639 |
+
|
| 640 |
+
Args:
|
| 641 |
+
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
|
| 642 |
+
The image to convert to the PIL Image format.
|
| 643 |
+
rescale (`bool`, *optional*):
|
| 644 |
+
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
|
| 645 |
+
default to `True` if the image type is a floating type, `False` otherwise.
|
| 646 |
+
"""
|
| 647 |
+
self._ensure_format_supported(image)
|
| 648 |
+
|
| 649 |
+
if is_torch_tensor(image):
|
| 650 |
+
image = image.numpy()
|
| 651 |
+
|
| 652 |
+
if isinstance(image, np.ndarray):
|
| 653 |
+
if rescale is None:
|
| 654 |
+
# rescale default to the array being of floating type.
|
| 655 |
+
rescale = isinstance(image.flat[0], np.floating)
|
| 656 |
+
# If the channel as been moved to first dim, we put it back at the end.
|
| 657 |
+
if image.ndim == 3 and image.shape[0] in [1, 3]:
|
| 658 |
+
image = image.transpose(1, 2, 0)
|
| 659 |
+
if rescale:
|
| 660 |
+
image = image * 255
|
| 661 |
+
image = image.astype(np.uint8)
|
| 662 |
+
return PIL.Image.fromarray(image)
|
| 663 |
+
return image
|
| 664 |
+
|
| 665 |
+
def convert_rgb(self, image):
|
| 666 |
+
"""
|
| 667 |
+
Converts `PIL.Image.Image` to RGB format.
|
| 668 |
+
|
| 669 |
+
Args:
|
| 670 |
+
image (`PIL.Image.Image`):
|
| 671 |
+
The image to convert.
|
| 672 |
+
"""
|
| 673 |
+
self._ensure_format_supported(image)
|
| 674 |
+
if not isinstance(image, PIL.Image.Image):
|
| 675 |
+
return image
|
| 676 |
+
|
| 677 |
+
return image.convert("RGB")
|
| 678 |
+
|
| 679 |
+
def rescale(self, image: np.ndarray, scale: float | int) -> np.ndarray:
|
| 680 |
+
"""
|
| 681 |
+
Rescale a numpy image by scale amount
|
| 682 |
+
"""
|
| 683 |
+
self._ensure_format_supported(image)
|
| 684 |
+
return image * scale
|
| 685 |
+
|
| 686 |
+
def to_numpy_array(self, image, rescale=None, channel_first=True):
|
| 687 |
+
"""
|
| 688 |
+
Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
|
| 689 |
+
dimension.
|
| 690 |
+
|
| 691 |
+
Args:
|
| 692 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 693 |
+
The image to convert to a NumPy array.
|
| 694 |
+
rescale (`bool`, *optional*):
|
| 695 |
+
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
|
| 696 |
+
default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
|
| 697 |
+
channel_first (`bool`, *optional*, defaults to `True`):
|
| 698 |
+
Whether or not to permute the dimensions of the image to put the channel dimension first.
|
| 699 |
+
"""
|
| 700 |
+
self._ensure_format_supported(image)
|
| 701 |
+
|
| 702 |
+
if isinstance(image, PIL.Image.Image):
|
| 703 |
+
image = np.array(image)
|
| 704 |
+
|
| 705 |
+
if is_torch_tensor(image):
|
| 706 |
+
image = image.numpy()
|
| 707 |
+
|
| 708 |
+
rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
|
| 709 |
+
|
| 710 |
+
if rescale:
|
| 711 |
+
image = self.rescale(image.astype(np.float32), 1 / 255.0)
|
| 712 |
+
|
| 713 |
+
if channel_first and image.ndim == 3:
|
| 714 |
+
image = image.transpose(2, 0, 1)
|
| 715 |
+
|
| 716 |
+
return image
|
| 717 |
+
|
| 718 |
+
def expand_dims(self, image):
|
| 719 |
+
"""
|
| 720 |
+
Expands 2-dimensional `image` to 3 dimensions.
|
| 721 |
+
|
| 722 |
+
Args:
|
| 723 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 724 |
+
The image to expand.
|
| 725 |
+
"""
|
| 726 |
+
self._ensure_format_supported(image)
|
| 727 |
+
|
| 728 |
+
# Do nothing if PIL image
|
| 729 |
+
if isinstance(image, PIL.Image.Image):
|
| 730 |
+
return image
|
| 731 |
+
|
| 732 |
+
if is_torch_tensor(image):
|
| 733 |
+
image = image.unsqueeze(0)
|
| 734 |
+
else:
|
| 735 |
+
image = np.expand_dims(image, axis=0)
|
| 736 |
+
return image
|
| 737 |
+
|
| 738 |
+
def normalize(self, image, mean, std, rescale=False):
|
| 739 |
+
"""
|
| 740 |
+
Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
|
| 741 |
+
if it's a PIL Image.
|
| 742 |
+
|
| 743 |
+
Args:
|
| 744 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 745 |
+
The image to normalize.
|
| 746 |
+
mean (`list[float]` or `np.ndarray` or `torch.Tensor`):
|
| 747 |
+
The mean (per channel) to use for normalization.
|
| 748 |
+
std (`list[float]` or `np.ndarray` or `torch.Tensor`):
|
| 749 |
+
The standard deviation (per channel) to use for normalization.
|
| 750 |
+
rescale (`bool`, *optional*, defaults to `False`):
|
| 751 |
+
Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
|
| 752 |
+
happen automatically.
|
| 753 |
+
"""
|
| 754 |
+
self._ensure_format_supported(image)
|
| 755 |
+
|
| 756 |
+
if isinstance(image, PIL.Image.Image):
|
| 757 |
+
image = self.to_numpy_array(image, rescale=True)
|
| 758 |
+
# If the input image is a PIL image, it automatically gets rescaled. If it's another
|
| 759 |
+
# type it may need rescaling.
|
| 760 |
+
elif rescale:
|
| 761 |
+
if isinstance(image, np.ndarray):
|
| 762 |
+
image = self.rescale(image.astype(np.float32), 1 / 255.0)
|
| 763 |
+
elif is_torch_tensor(image):
|
| 764 |
+
image = self.rescale(image.float(), 1 / 255.0)
|
| 765 |
+
|
| 766 |
+
if isinstance(image, np.ndarray):
|
| 767 |
+
if not isinstance(mean, np.ndarray):
|
| 768 |
+
mean = np.array(mean).astype(image.dtype)
|
| 769 |
+
if not isinstance(std, np.ndarray):
|
| 770 |
+
std = np.array(std).astype(image.dtype)
|
| 771 |
+
elif is_torch_tensor(image):
|
| 772 |
+
import torch
|
| 773 |
+
|
| 774 |
+
if not isinstance(mean, torch.Tensor):
|
| 775 |
+
if isinstance(mean, np.ndarray):
|
| 776 |
+
mean = torch.from_numpy(mean)
|
| 777 |
+
else:
|
| 778 |
+
mean = torch.tensor(mean)
|
| 779 |
+
if not isinstance(std, torch.Tensor):
|
| 780 |
+
if isinstance(std, np.ndarray):
|
| 781 |
+
std = torch.from_numpy(std)
|
| 782 |
+
else:
|
| 783 |
+
std = torch.tensor(std)
|
| 784 |
+
|
| 785 |
+
if image.ndim == 3 and image.shape[0] in [1, 3]:
|
| 786 |
+
return (image - mean[:, None, None]) / std[:, None, None]
|
| 787 |
+
else:
|
| 788 |
+
return (image - mean) / std
|
| 789 |
+
|
| 790 |
+
def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
|
| 791 |
+
"""
|
| 792 |
+
Resizes `image`. Enforces conversion of input to PIL.Image.
|
| 793 |
+
|
| 794 |
+
Args:
|
| 795 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 796 |
+
The image to resize.
|
| 797 |
+
size (`int` or `tuple[int, int]`):
|
| 798 |
+
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
|
| 799 |
+
matched to this.
|
| 800 |
+
|
| 801 |
+
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
|
| 802 |
+
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
|
| 803 |
+
this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
|
| 804 |
+
resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 805 |
+
The filter to user for resampling.
|
| 806 |
+
default_to_square (`bool`, *optional*, defaults to `True`):
|
| 807 |
+
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
|
| 808 |
+
square (`size`,`size`). If set to `False`, will replicate
|
| 809 |
+
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
|
| 810 |
+
with support for resizing only the smallest edge and providing an optional `max_size`.
|
| 811 |
+
max_size (`int`, *optional*, defaults to `None`):
|
| 812 |
+
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
|
| 813 |
+
greater than `max_size` after being resized according to `size`, then the image is resized again so
|
| 814 |
+
that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
|
| 815 |
+
edge may be shorter than `size`. Only used if `default_to_square` is `False`.
|
| 816 |
+
|
| 817 |
+
Returns:
|
| 818 |
+
image: A resized `PIL.Image.Image`.
|
| 819 |
+
"""
|
| 820 |
+
resample = resample if resample is not None else PILImageResampling.BILINEAR
|
| 821 |
+
|
| 822 |
+
self._ensure_format_supported(image)
|
| 823 |
+
|
| 824 |
+
if not isinstance(image, PIL.Image.Image):
|
| 825 |
+
image = self.to_pil_image(image)
|
| 826 |
+
|
| 827 |
+
if isinstance(size, list):
|
| 828 |
+
size = tuple(size)
|
| 829 |
+
|
| 830 |
+
if isinstance(size, int) or len(size) == 1:
|
| 831 |
+
if default_to_square:
|
| 832 |
+
size = (size, size) if isinstance(size, int) else (size[0], size[0])
|
| 833 |
+
else:
|
| 834 |
+
width, height = image.size
|
| 835 |
+
# specified size only for the smallest edge
|
| 836 |
+
short, long = (width, height) if width <= height else (height, width)
|
| 837 |
+
requested_new_short = size if isinstance(size, int) else size[0]
|
| 838 |
+
|
| 839 |
+
if short == requested_new_short:
|
| 840 |
+
return image
|
| 841 |
+
|
| 842 |
+
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
| 843 |
+
|
| 844 |
+
if max_size is not None:
|
| 845 |
+
if max_size <= requested_new_short:
|
| 846 |
+
raise ValueError(
|
| 847 |
+
f"max_size = {max_size} must be strictly greater than the requested "
|
| 848 |
+
f"size for the smaller edge size = {size}"
|
| 849 |
+
)
|
| 850 |
+
if new_long > max_size:
|
| 851 |
+
new_short, new_long = int(max_size * new_short / new_long), max_size
|
| 852 |
+
|
| 853 |
+
size = (new_short, new_long) if width <= height else (new_long, new_short)
|
| 854 |
+
|
| 855 |
+
return image.resize(size, resample=resample)
|
| 856 |
+
|
| 857 |
+
def center_crop(self, image, size):
|
| 858 |
+
"""
|
| 859 |
+
Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
|
| 860 |
+
size given, it will be padded (so the returned result has the size asked).
|
| 861 |
+
|
| 862 |
+
Args:
|
| 863 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
|
| 864 |
+
The image to resize.
|
| 865 |
+
size (`int` or `tuple[int, int]`):
|
| 866 |
+
The size to which crop the image.
|
| 867 |
+
|
| 868 |
+
Returns:
|
| 869 |
+
new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
|
| 870 |
+
height, width).
|
| 871 |
+
"""
|
| 872 |
+
self._ensure_format_supported(image)
|
| 873 |
+
|
| 874 |
+
if not isinstance(size, tuple):
|
| 875 |
+
size = (size, size)
|
| 876 |
+
|
| 877 |
+
# PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
|
| 878 |
+
if is_torch_tensor(image) or isinstance(image, np.ndarray):
|
| 879 |
+
if image.ndim == 2:
|
| 880 |
+
image = self.expand_dims(image)
|
| 881 |
+
image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
|
| 882 |
+
else:
|
| 883 |
+
image_shape = (image.size[1], image.size[0])
|
| 884 |
+
|
| 885 |
+
top = (image_shape[0] - size[0]) // 2
|
| 886 |
+
bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
|
| 887 |
+
left = (image_shape[1] - size[1]) // 2
|
| 888 |
+
right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
|
| 889 |
+
|
| 890 |
+
# For PIL Images we have a method to crop directly.
|
| 891 |
+
if isinstance(image, PIL.Image.Image):
|
| 892 |
+
return image.crop((left, top, right, bottom))
|
| 893 |
+
|
| 894 |
+
# Check if image is in (n_channels, height, width) or (height, width, n_channels) format
|
| 895 |
+
channel_first = image.shape[0] in [1, 3]
|
| 896 |
+
|
| 897 |
+
# Transpose (height, width, n_channels) format images
|
| 898 |
+
if not channel_first:
|
| 899 |
+
if isinstance(image, np.ndarray):
|
| 900 |
+
image = image.transpose(2, 0, 1)
|
| 901 |
+
if is_torch_tensor(image):
|
| 902 |
+
image = image.permute(2, 0, 1)
|
| 903 |
+
|
| 904 |
+
# Check if cropped area is within image boundaries
|
| 905 |
+
if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
|
| 906 |
+
return image[..., top:bottom, left:right]
|
| 907 |
+
|
| 908 |
+
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
| 909 |
+
new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
|
| 910 |
+
if isinstance(image, np.ndarray):
|
| 911 |
+
new_image = np.zeros_like(image, shape=new_shape)
|
| 912 |
+
elif is_torch_tensor(image):
|
| 913 |
+
new_image = image.new_zeros(new_shape)
|
| 914 |
+
|
| 915 |
+
top_pad = (new_shape[-2] - image_shape[0]) // 2
|
| 916 |
+
bottom_pad = top_pad + image_shape[0]
|
| 917 |
+
left_pad = (new_shape[-1] - image_shape[1]) // 2
|
| 918 |
+
right_pad = left_pad + image_shape[1]
|
| 919 |
+
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
|
| 920 |
+
|
| 921 |
+
top += top_pad
|
| 922 |
+
bottom += top_pad
|
| 923 |
+
left += left_pad
|
| 924 |
+
right += left_pad
|
| 925 |
+
|
| 926 |
+
new_image = new_image[
|
| 927 |
+
..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
|
| 928 |
+
]
|
| 929 |
+
|
| 930 |
+
return new_image
|
| 931 |
+
|
| 932 |
+
def flip_channel_order(self, image):
|
| 933 |
+
"""
|
| 934 |
+
Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
|
| 935 |
+
`image` to a NumPy array if it's a PIL Image.
|
| 936 |
+
|
| 937 |
+
Args:
|
| 938 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 939 |
+
The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
|
| 940 |
+
be first.
|
| 941 |
+
"""
|
| 942 |
+
self._ensure_format_supported(image)
|
| 943 |
+
|
| 944 |
+
if isinstance(image, PIL.Image.Image):
|
| 945 |
+
image = self.to_numpy_array(image)
|
| 946 |
+
|
| 947 |
+
return image[::-1, :, :]
|
| 948 |
+
|
| 949 |
+
def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
|
| 950 |
+
"""
|
| 951 |
+
Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
|
| 952 |
+
counter clockwise around its centre.
|
| 953 |
+
|
| 954 |
+
Args:
|
| 955 |
+
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
| 956 |
+
The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
|
| 957 |
+
rotating.
|
| 958 |
+
|
| 959 |
+
Returns:
|
| 960 |
+
image: A rotated `PIL.Image.Image`.
|
| 961 |
+
"""
|
| 962 |
+
resample = resample if resample is not None else PIL.Image.NEAREST
|
| 963 |
+
|
| 964 |
+
self._ensure_format_supported(image)
|
| 965 |
+
|
| 966 |
+
if not isinstance(image, PIL.Image.Image):
|
| 967 |
+
image = self.to_pil_image(image)
|
| 968 |
+
|
| 969 |
+
return image.rotate(
|
| 970 |
+
angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
def validate_annotations(
|
| 975 |
+
annotation_format: AnnotationFormat,
|
| 976 |
+
supported_annotation_formats: tuple[AnnotationFormat, ...],
|
| 977 |
+
annotations: list[dict],
|
| 978 |
+
) -> None:
|
| 979 |
+
if annotation_format not in supported_annotation_formats:
|
| 980 |
+
raise ValueError(f"Unsupported annotation format: {format} must be one of {supported_annotation_formats}")
|
| 981 |
+
|
| 982 |
+
if annotation_format is AnnotationFormat.COCO_DETECTION:
|
| 983 |
+
if not valid_coco_detection_annotations(annotations):
|
| 984 |
+
raise ValueError(
|
| 985 |
+
"Invalid COCO detection annotations. Annotations must a dict (single image) or list of dicts "
|
| 986 |
+
"(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
|
| 987 |
+
"being a list of annotations in the COCO format."
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
if annotation_format is AnnotationFormat.COCO_PANOPTIC:
|
| 991 |
+
if not valid_coco_panoptic_annotations(annotations):
|
| 992 |
+
raise ValueError(
|
| 993 |
+
"Invalid COCO panoptic annotations. Annotations must a dict (single image) or list of dicts "
|
| 994 |
+
"(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
|
| 995 |
+
"the latter being a list of annotations in the COCO format."
|
| 996 |
+
)
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def validate_kwargs(valid_processor_keys: list[str], captured_kwargs: list[str]):
|
| 1000 |
+
unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
|
| 1001 |
+
if unused_keys:
|
| 1002 |
+
unused_key_str = ", ".join(unused_keys)
|
| 1003 |
+
# TODO raise a warning here instead of simply logging?
|
| 1004 |
+
logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
@dataclass()
|
| 1008 |
+
class SizeDict:
|
| 1009 |
+
"""
|
| 1010 |
+
Hashable dictionary to store image size information.
|
| 1011 |
+
"""
|
| 1012 |
+
|
| 1013 |
+
height: int | None = None
|
| 1014 |
+
width: int | None = None
|
| 1015 |
+
longest_edge: int | None = None
|
| 1016 |
+
shortest_edge: int | None = None
|
| 1017 |
+
max_height: int | None = None
|
| 1018 |
+
max_width: int | None = None
|
| 1019 |
+
|
| 1020 |
+
def __getitem__(self, key):
|
| 1021 |
+
if hasattr(self, key):
|
| 1022 |
+
return getattr(self, key)
|
| 1023 |
+
raise KeyError(f"Key {key} not found in SizeDict.")
|
| 1024 |
+
|
| 1025 |
+
def get(self, key, default=None):
|
| 1026 |
+
if hasattr(self, key) and getattr(self, key) is not None:
|
| 1027 |
+
return getattr(self, key)
|
| 1028 |
+
return default
|
| 1029 |
+
|
| 1030 |
+
def __iter__(self):
|
| 1031 |
+
# Yield only non-None (key, value) pairs so dict(self) excludes missing values.
|
| 1032 |
+
for f in fields(self):
|
| 1033 |
+
val = getattr(self, f.name)
|
| 1034 |
+
if val is not None:
|
| 1035 |
+
yield f.name, val
|
| 1036 |
+
|
| 1037 |
+
def __hash__(self):
|
| 1038 |
+
return hash((self.height, self.width, self.longest_edge, self.shortest_edge, self.max_height, self.max_width))
|
| 1039 |
+
|
| 1040 |
+
def __contains__(self, key):
|
| 1041 |
+
return hasattr(self, key) and getattr(self, key) is not None
|
| 1042 |
+
|
| 1043 |
+
def __setitem__(self, key, value):
|
| 1044 |
+
if not hasattr(self, key):
|
| 1045 |
+
raise KeyError(f"Key {key} is not a valid field of SizeDict.")
|
| 1046 |
+
object.__setattr__(self, key, value)
|
| 1047 |
+
|
| 1048 |
+
def __eq__(self, other):
|
| 1049 |
+
if isinstance(other, dict):
|
| 1050 |
+
return dict(self) == other
|
| 1051 |
+
if isinstance(other, SizeDict):
|
| 1052 |
+
return tuple(getattr(self, f.name) for f in fields(self)) == tuple(
|
| 1053 |
+
getattr(other, f.name) for f in fields(self)
|
| 1054 |
+
)
|
| 1055 |
+
return NotImplemented
|
| 1056 |
+
|
| 1057 |
+
def __or__(self, other) -> "SizeDict":
|
| 1058 |
+
if isinstance(other, dict | SizeDict):
|
| 1059 |
+
merged = dict(self)
|
| 1060 |
+
merged.update(dict(other))
|
| 1061 |
+
return SizeDict(**merged)
|
| 1062 |
+
return NotImplemented
|
| 1063 |
+
|
| 1064 |
+
def __ror__(self, other) -> dict:
|
| 1065 |
+
if isinstance(other, dict):
|
| 1066 |
+
merged = dict(other)
|
| 1067 |
+
merged.update(dict(self))
|
| 1068 |
+
return merged
|
| 1069 |
+
return NotImplemented
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/masking_utils.py
ADDED
|
@@ -0,0 +1,1514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from collections.abc import Callable
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
from .cache_utils import Cache
|
| 20 |
+
from .configuration_utils import PreTrainedConfig
|
| 21 |
+
from .utils import is_torch_xpu_available, logging
|
| 22 |
+
from .utils.generic import GeneralInterface, is_flash_attention_requested
|
| 23 |
+
from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_tracing
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_torch_flex_attn_available():
|
| 27 |
+
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
|
| 28 |
+
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
|
| 29 |
+
else:
|
| 30 |
+
# Register a fake type to avoid crashing for annotations and `isinstance` checks
|
| 31 |
+
BlockMask = torch.Tensor
|
| 32 |
+
|
| 33 |
+
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
|
| 34 |
+
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
| 35 |
+
_is_torch_xpu_available = is_torch_xpu_available()
|
| 36 |
+
|
| 37 |
+
if _is_torch_greater_or_equal_than_2_6:
|
| 38 |
+
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def and_masks(*mask_functions: Callable) -> Callable:
|
| 45 |
+
"""Returns a mask function that is the intersection of provided mask functions"""
|
| 46 |
+
if not all(callable(arg) for arg in mask_functions):
|
| 47 |
+
raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")
|
| 48 |
+
|
| 49 |
+
def and_mask(batch_idx, head_idx, q_idx, kv_idx):
|
| 50 |
+
result = q_idx.new_ones((), dtype=torch.bool)
|
| 51 |
+
for mask in mask_functions:
|
| 52 |
+
result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
|
| 53 |
+
return result
|
| 54 |
+
|
| 55 |
+
return and_mask
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def or_masks(*mask_functions: Callable) -> Callable:
|
| 59 |
+
"""Returns a mask function that is the union of provided mask functions"""
|
| 60 |
+
if not all(callable(arg) for arg in mask_functions):
|
| 61 |
+
raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")
|
| 62 |
+
|
| 63 |
+
def or_mask(batch_idx, head_idx, q_idx, kv_idx):
|
| 64 |
+
result = q_idx.new_zeros((), dtype=torch.bool)
|
| 65 |
+
for mask in mask_functions:
|
| 66 |
+
result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
|
| 67 |
+
return result
|
| 68 |
+
|
| 69 |
+
return or_mask
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 73 |
+
"""
|
| 74 |
+
This creates a basic lower-diagonal causal mask.
|
| 75 |
+
"""
|
| 76 |
+
return kv_idx <= q_idx
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 80 |
+
"""
|
| 81 |
+
This creates a full bidirectional mask.
|
| 82 |
+
|
| 83 |
+
NOTE: It is important to keep an index-based version for non-vmap expansion.
|
| 84 |
+
"""
|
| 85 |
+
return q_idx >= 0
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def sliding_window_overlay(sliding_window: int) -> Callable:
|
| 89 |
+
"""
|
| 90 |
+
This is an overlay depicting a sliding window pattern. Add it on top of a causal mask for a proper sliding
|
| 91 |
+
window mask.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 95 |
+
return kv_idx > q_idx - sliding_window
|
| 96 |
+
|
| 97 |
+
return inner_mask
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
|
| 101 |
+
"""
|
| 102 |
+
This is an overlay depicting a chunked attention pattern. Add it on top of a causal mask for a proper chunked
|
| 103 |
+
attention mask.
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 107 |
+
return (kv_idx - left_padding[batch_idx]) // chunk_size == (q_idx - left_padding[batch_idx]) // chunk_size
|
| 108 |
+
|
| 109 |
+
return inner_mask
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def blockwise_overlay(block_sequence_ids: torch.Tensor) -> Callable:
|
| 113 |
+
"""
|
| 114 |
+
This is an overlay depicting a blockwise masking pattern. Instead of a single
|
| 115 |
+
token, each block consists of arbitrary length tokens. In causal setup, each block
|
| 116 |
+
can attend to prev block causally and can't attend to future blocks. Within one block
|
| 117 |
+
the attention is always bidirectional.
|
| 118 |
+
Mostly used in MLLMs when non-text data attends bidirectionally to itself.
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 122 |
+
# Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
|
| 123 |
+
q_group = block_sequence_ids[batch_idx, q_idx]
|
| 124 |
+
kv_group = block_sequence_ids[batch_idx, kv_idx]
|
| 125 |
+
return (q_group == kv_group) & (q_group >= 0)
|
| 126 |
+
|
| 127 |
+
return inner_mask
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
|
| 131 |
+
"""
|
| 132 |
+
This return the mask_function function to create a sliding window mask.
|
| 133 |
+
"""
|
| 134 |
+
return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def sliding_window_bidirectional_overlay(sliding_window: int) -> Callable:
|
| 138 |
+
"""
|
| 139 |
+
This is an overlay depicting a bidirectional sliding window pattern.
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 143 |
+
"""A token can attend to any other token if their absolute distance is within
|
| 144 |
+
the (inclusive) sliding window size (distance <= sliding_window)."""
|
| 145 |
+
return abs(q_idx - kv_idx) <= sliding_window
|
| 146 |
+
|
| 147 |
+
return inner_mask
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def sliding_window_bidirectional_mask_function(sliding_window: int) -> Callable:
|
| 151 |
+
"""
|
| 152 |
+
This return the mask_function function to create a bidirectional sliding window mask.
|
| 153 |
+
"""
|
| 154 |
+
return and_masks(sliding_window_bidirectional_overlay(sliding_window), bidirectional_mask_function)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> Callable:
|
| 158 |
+
"""
|
| 159 |
+
This return the mask_function function to create a chunked attention mask.
|
| 160 |
+
"""
|
| 161 |
+
return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
|
| 165 |
+
"""
|
| 166 |
+
This return the mask_function function corresponding to a 2D padding mask.
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 170 |
+
# Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
|
| 171 |
+
# we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
|
| 172 |
+
# vectorizable on accelerator devices
|
| 173 |
+
return padding_mask[batch_idx, kv_idx]
|
| 174 |
+
|
| 175 |
+
return inner_mask
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
|
| 179 |
+
"""
|
| 180 |
+
This return the mask_function function corresponding to a 2D packed sequence mask.
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 184 |
+
return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
|
| 185 |
+
|
| 186 |
+
return inner_mask
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
|
| 190 |
+
"""
|
| 191 |
+
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
| 192 |
+
not start and end indices.
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
| 196 |
+
return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset)
|
| 197 |
+
|
| 198 |
+
return inner_mask
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def prepare_padding_mask(attention_mask: torch.Tensor | None, kv_length: int, kv_offset: int) -> torch.Tensor | None:
|
| 202 |
+
"""
|
| 203 |
+
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
|
| 204 |
+
"""
|
| 205 |
+
local_padding_mask = attention_mask
|
| 206 |
+
if attention_mask is not None:
|
| 207 |
+
# Pad it if necessary
|
| 208 |
+
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
|
| 209 |
+
local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
|
| 210 |
+
return local_padding_mask
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def maybe_pad_block_sequence_ids(
|
| 214 |
+
block_sequence_ids: torch.Tensor, attention_mask: torch.Tensor | None, kv_length: int, kv_offset: int
|
| 215 |
+
) -> torch.Tensor:
|
| 216 |
+
"""
|
| 217 |
+
Pads the `block_sequence_ids` in case the total length is less than `kv_length`.
|
| 218 |
+
Usually that happens with `StaticCache` generation or generating without cache.
|
| 219 |
+
Pads to the right with `-1`.
|
| 220 |
+
"""
|
| 221 |
+
if (padding_length := kv_length + kv_offset - block_sequence_ids.shape[-1]) > 0:
|
| 222 |
+
block_sequence_ids = F.pad(block_sequence_ids, pad=(0, padding_length), value=-1)
|
| 223 |
+
return block_sequence_ids
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _can_skip_causal_mask_xpu(
|
| 227 |
+
padding_mask: torch.Tensor | None,
|
| 228 |
+
query_length: int,
|
| 229 |
+
kv_length: int,
|
| 230 |
+
local_attention_size: int | None,
|
| 231 |
+
) -> bool:
|
| 232 |
+
"""
|
| 233 |
+
XPU-specific logic for determining if we can skip causal mask creation.
|
| 234 |
+
|
| 235 |
+
For XPU devices, we have special handling:
|
| 236 |
+
- Single query tokens (query_length == 1) use the same logic as CUDA
|
| 237 |
+
- Multi-query tokens can skip if padding_mask is provided and correctly structured
|
| 238 |
+
The mask must have all True values in the query window and all False after
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
if is_tracing(padding_mask):
|
| 242 |
+
return False
|
| 243 |
+
|
| 244 |
+
# Check local attention constraint (same as CUDA)
|
| 245 |
+
if local_attention_size is not None and kv_length >= local_attention_size:
|
| 246 |
+
return False
|
| 247 |
+
|
| 248 |
+
if padding_mask is None:
|
| 249 |
+
# Without padding mask, can skip if single query token or full causal attention
|
| 250 |
+
return query_length == 1 or kv_length == query_length
|
| 251 |
+
|
| 252 |
+
# XPU allows skipping under additional conditions when padding_mask is provided
|
| 253 |
+
if query_length == 1:
|
| 254 |
+
# Single query token: skip only if no padding tokens present
|
| 255 |
+
return padding_mask.all()
|
| 256 |
+
|
| 257 |
+
# XPU-specific: check if query window is all True and rest is all False
|
| 258 |
+
# This allows XPU to optimize the 1st token in static cache
|
| 259 |
+
return padding_mask[:, :query_length].all() and not padding_mask[:, query_length:].any()
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _ignore_causal_mask_sdpa(
|
| 263 |
+
padding_mask: torch.Tensor | None,
|
| 264 |
+
query_length: int,
|
| 265 |
+
kv_length: int,
|
| 266 |
+
kv_offset: int,
|
| 267 |
+
local_attention_size: int | None = None,
|
| 268 |
+
) -> bool:
|
| 269 |
+
"""
|
| 270 |
+
Detects whether the causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
|
| 271 |
+
|
| 272 |
+
In case no token is masked in the 2D `padding_mask` argument, if `query_length == 1` or
|
| 273 |
+
`key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
|
| 274 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
|
| 275 |
+
passed).
|
| 276 |
+
"""
|
| 277 |
+
if padding_mask is not None and padding_mask.shape[-1] > kv_length:
|
| 278 |
+
mask_indices = torch.arange(kv_length, device=padding_mask.device)
|
| 279 |
+
mask_indices += kv_offset
|
| 280 |
+
padding_mask = padding_mask[:, mask_indices]
|
| 281 |
+
|
| 282 |
+
if _is_torch_xpu_available:
|
| 283 |
+
# XPU devices have special handling for mask skipping:
|
| 284 |
+
# - Single query tokens use the same logic as CUDA
|
| 285 |
+
# - Multi-query tokens can skip if padding_mask is provided and correctly structured
|
| 286 |
+
# (all True in query window, all False after)
|
| 287 |
+
return _can_skip_causal_mask_xpu(padding_mask, query_length, kv_length, local_attention_size)
|
| 288 |
+
# When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
|
| 289 |
+
# hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
|
| 290 |
+
# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
|
| 291 |
+
# `ignore_causal_mask = True` if we are not tracing
|
| 292 |
+
if (
|
| 293 |
+
not is_tracing(padding_mask)
|
| 294 |
+
# only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
|
| 295 |
+
and (query_length == 1 or kv_length == query_length)
|
| 296 |
+
# in this case we need to add special patterns to the mask so cannot be skipped otherwise
|
| 297 |
+
and (local_attention_size is None or kv_length < local_attention_size)
|
| 298 |
+
# In this case, we need to add padding to the mask, so cannot be skipped otherwise
|
| 299 |
+
and (padding_mask is None or padding_mask.all())
|
| 300 |
+
):
|
| 301 |
+
return True
|
| 302 |
+
|
| 303 |
+
return False
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _can_skip_bidirectional_mask_xpu(
|
| 307 |
+
padding_mask: torch.Tensor | None,
|
| 308 |
+
kv_length: int,
|
| 309 |
+
local_attention_size: int | None,
|
| 310 |
+
) -> bool:
|
| 311 |
+
"""
|
| 312 |
+
XPU-specific logic for determining if we can skip bidirectional mask creation.
|
| 313 |
+
|
| 314 |
+
For XPU devices, we have special handling:
|
| 315 |
+
- Skip if no padding and no local attention constraint
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
if is_tracing(padding_mask):
|
| 319 |
+
return False
|
| 320 |
+
|
| 321 |
+
# Check local attention constraint (same as CUDA)
|
| 322 |
+
if local_attention_size is not None and kv_length >= local_attention_size:
|
| 323 |
+
return False
|
| 324 |
+
|
| 325 |
+
if padding_mask is None:
|
| 326 |
+
# Without padding mask, can always skip for full bidirectional attention
|
| 327 |
+
return True
|
| 328 |
+
|
| 329 |
+
# Skip only if no padding tokens present
|
| 330 |
+
return padding_mask.all()
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _ignore_bidirectional_mask_sdpa(
|
| 334 |
+
padding_mask: torch.Tensor | None,
|
| 335 |
+
kv_length: int,
|
| 336 |
+
local_attention_size: int | None = None,
|
| 337 |
+
) -> bool:
|
| 338 |
+
"""
|
| 339 |
+
Detects whether the bidirectional mask can be ignored in case PyTorch's SDPA is used.
|
| 340 |
+
|
| 341 |
+
In case no token is masked in the 2D `padding_mask` argument and no local attention constraint applies
|
| 342 |
+
(i.e. `local_attention_size` is None or `kv_length < local_attention_size`), we skip mask creation,
|
| 343 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
|
| 344 |
+
passed).
|
| 345 |
+
"""
|
| 346 |
+
if _is_torch_xpu_available:
|
| 347 |
+
# XPU devices have special handling for mask skipping:
|
| 348 |
+
# - Skip if no padding and no local attention constraint
|
| 349 |
+
return _can_skip_bidirectional_mask_xpu(padding_mask, kv_length, local_attention_size)
|
| 350 |
+
|
| 351 |
+
# When using `torch.export` or `torch.onnx.dynamo_export`, we need to avoid to check the contents of the mask;
|
| 352 |
+
# otherwise, we will encounter dynamic control flows
|
| 353 |
+
if (
|
| 354 |
+
not is_tracing(padding_mask)
|
| 355 |
+
and (padding_mask is None or padding_mask.all())
|
| 356 |
+
# in this case we need to add special patterns to the mask so cannot be skipped otherwise
|
| 357 |
+
and (local_attention_size is None or kv_length < local_attention_size)
|
| 358 |
+
):
|
| 359 |
+
return True
|
| 360 |
+
|
| 361 |
+
return False
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
|
| 365 |
+
"""
|
| 366 |
+
Used to vmap our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
|
| 367 |
+
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
|
| 368 |
+
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
|
| 369 |
+
"""
|
| 370 |
+
# We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
|
| 371 |
+
dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
|
| 372 |
+
for dims in dimensions:
|
| 373 |
+
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
|
| 374 |
+
return mask_function
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def _non_vmap_expansion_sdpa(
|
| 378 |
+
batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
|
| 379 |
+
):
|
| 380 |
+
"""
|
| 381 |
+
Used to broadcast our mask_functions over the all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
|
| 382 |
+
Allows the usage of any index-based mask function without relying on vmap.
|
| 383 |
+
|
| 384 |
+
NOTE: This is limited to index based functions only and is not guaranteed to work otherwise.
|
| 385 |
+
|
| 386 |
+
Reference:
|
| 387 |
+
- https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
|
| 388 |
+
"""
|
| 389 |
+
batch_indices = batch_indices[:, None, None, None]
|
| 390 |
+
head_indices = head_indices[None, :, None, None]
|
| 391 |
+
q_indices = q_indices[None, None, :, None]
|
| 392 |
+
kv_indices = kv_indices[None, None, None, :]
|
| 393 |
+
return batch_indices, head_indices, q_indices, kv_indices
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def sdpa_mask(
|
| 397 |
+
batch_size: int,
|
| 398 |
+
q_length: int,
|
| 399 |
+
kv_length: int,
|
| 400 |
+
q_offset: int = 0,
|
| 401 |
+
kv_offset: int = 0,
|
| 402 |
+
mask_function: Callable = causal_mask_function,
|
| 403 |
+
attention_mask: torch.Tensor | None = None,
|
| 404 |
+
local_size: int | None = None,
|
| 405 |
+
allow_is_causal_skip: bool = True,
|
| 406 |
+
allow_is_bidirectional_skip: bool = False,
|
| 407 |
+
allow_torch_fix: bool = True,
|
| 408 |
+
use_vmap: bool = False,
|
| 409 |
+
device: torch.device | str = "cpu",
|
| 410 |
+
**kwargs,
|
| 411 |
+
) -> torch.Tensor | None:
|
| 412 |
+
"""
|
| 413 |
+
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
|
| 414 |
+
the element should take part in the attention computation, and False that it should not.
|
| 415 |
+
This function can only be used with torch>=2.5, as the context manager is otherwise not available.
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
batch_size (`int`):
|
| 419 |
+
The batch size of the input sequence.
|
| 420 |
+
q_length (`int`):
|
| 421 |
+
The size that the query states will have during the attention computation.
|
| 422 |
+
kv_length (`int`):
|
| 423 |
+
The size that the key and value states will have during the attention computation.
|
| 424 |
+
kv_offset (`int`, optional):
|
| 425 |
+
An optional offset to indicate at which first position the key and values states will refer to.
|
| 426 |
+
q_offset (`int`, optional):
|
| 427 |
+
An optional offset to indicate at which first position the query states will refer to.
|
| 428 |
+
mask_function (`Callable`):
|
| 429 |
+
The mask factory function describing the mask pattern.
|
| 430 |
+
attention_mask (`torch.Tensor`, optional):
|
| 431 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
| 432 |
+
local_size (`int`, optional):
|
| 433 |
+
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
|
| 434 |
+
to try to skip mask creation if possible.
|
| 435 |
+
allow_is_causal_skip (`bool`, optional):
|
| 436 |
+
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
|
| 437 |
+
`torch.sdpa` instead. Default to `True`.
|
| 438 |
+
allow_is_bidirectional_skip (`bool`, optional):
|
| 439 |
+
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
|
| 440 |
+
i.e. full attention without any padding. Default to `False`.
|
| 441 |
+
allow_torch_fix (`bool`, optional):
|
| 442 |
+
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
|
| 443 |
+
versions. We need an arg to skip it when using eager. By default `True`.
|
| 444 |
+
use_vmap (`bool`, optional):
|
| 445 |
+
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
|
| 446 |
+
index-based (for the cost of speed performance). By default `False`.
|
| 447 |
+
device (`torch.device` or `str`, optional):
|
| 448 |
+
An optional device to create the mask on.
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
## Creating a simple causal mask:
|
| 452 |
+
|
| 453 |
+
To create the following causal mask:
|
| 454 |
+
|
| 455 |
+
0 ■ ⬚ ⬚ ⬚ ⬚
|
| 456 |
+
1 ■ ■ ⬚ ⬚ ⬚
|
| 457 |
+
2 ■ ■ ■ ⬚ ⬚
|
| 458 |
+
3 ■ ■ ■ ■ ⬚
|
| 459 |
+
4 ■ ■ ■ ■ ■
|
| 460 |
+
|
| 461 |
+
You can do
|
| 462 |
+
|
| 463 |
+
```python
|
| 464 |
+
>>> sdpa_mask(batch_size=1, q_length=5, kv_length=5)
|
| 465 |
+
>>> tensor([[[[ True, False, False, False, False],
|
| 466 |
+
[ True, True, False, False, False],
|
| 467 |
+
[ True, True, True, False, False],
|
| 468 |
+
[ True, True, True, True, False],
|
| 469 |
+
[ True, True, True, True, True]]]])
|
| 470 |
+
```
|
| 471 |
+
|
| 472 |
+
## Creating a sliding window mask:
|
| 473 |
+
|
| 474 |
+
To create the following sliding window mask (`sliding_window=3`):
|
| 475 |
+
|
| 476 |
+
0 ■ ⬚ ⬚ ⬚ ⬚
|
| 477 |
+
1 ■ ■ ⬚ ⬚ ⬚
|
| 478 |
+
2 ■ ■ ■ ⬚ ⬚
|
| 479 |
+
3 ⬚ ■ ■ ■ ⬚
|
| 480 |
+
4 ⬚ ⬚ ■ ■ ■
|
| 481 |
+
|
| 482 |
+
You can do
|
| 483 |
+
|
| 484 |
+
```python
|
| 485 |
+
>>> sdpa_mask(batch_size=1, q_length=5, kv_length=5, mask_function=sliding_window_causal_mask_function(3))
|
| 486 |
+
>>> tensor([[[[ True, False, False, False, False],
|
| 487 |
+
[ True, True, False, False, False],
|
| 488 |
+
[ True, True, True, False, False],
|
| 489 |
+
[False, True, True, True, False],
|
| 490 |
+
[False, False, True, True, True]]]])
|
| 491 |
+
```
|
| 492 |
+
|
| 493 |
+
## Creating a chunked attention mask
|
| 494 |
+
|
| 495 |
+
To create the following chunked attention mask (`chunk_size=3`):
|
| 496 |
+
|
| 497 |
+
0 ■ ⬚ ⬚ ⬚ ⬚
|
| 498 |
+
1 ■ ■ ⬚ ⬚ ⬚
|
| 499 |
+
2 ■ ■ ■ ⬚ ⬚
|
| 500 |
+
3 ⬚ ⬚ ⬚ ■ ⬚
|
| 501 |
+
4 ⬚ ⬚ ⬚ ■ ■
|
| 502 |
+
|
| 503 |
+
You can do
|
| 504 |
+
|
| 505 |
+
```python
|
| 506 |
+
>>> sdpa_mask(batch_size=1, q_length=5, kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
|
| 507 |
+
>>> tensor([[[[ True, False, False, False, False],
|
| 508 |
+
[ True, True, False, False, False],
|
| 509 |
+
[ True, True, True, False, False],
|
| 510 |
+
[False, False, False, True, False],
|
| 511 |
+
[False, False, False, True, True]]]])
|
| 512 |
+
```
|
| 513 |
+
|
| 514 |
+
"""
|
| 515 |
+
# Potentially pad the 2D mask
|
| 516 |
+
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
| 517 |
+
|
| 518 |
+
# Under specific conditions, we can avoid materializing the mask
|
| 519 |
+
# 1. Causal masks can rely on the `is_causal` argument
|
| 520 |
+
# 2. Bidirectional do not need any further processing (no bias)
|
| 521 |
+
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
|
| 522 |
+
return None
|
| 523 |
+
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, local_size):
|
| 524 |
+
return None
|
| 525 |
+
|
| 526 |
+
# Potentially add the padding 2D mask
|
| 527 |
+
if padding_mask is not None:
|
| 528 |
+
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
| 529 |
+
|
| 530 |
+
batch_arange = torch.arange(batch_size, device=device)
|
| 531 |
+
head_arange = torch.arange(1, device=device)
|
| 532 |
+
q_arange = torch.arange(q_length, device=device) + q_offset
|
| 533 |
+
kv_arange = torch.arange(kv_length, device=device) + kv_offset
|
| 534 |
+
|
| 535 |
+
# Actual mask creation
|
| 536 |
+
# Option 1: Fast non-vmap mask creation (default)
|
| 537 |
+
if not use_vmap:
|
| 538 |
+
# Apply mask function element-wise through broadcasting
|
| 539 |
+
attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange))
|
| 540 |
+
# Expand the mask to match batch size and query length if they weren't used in the mask function
|
| 541 |
+
attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
|
| 542 |
+
|
| 543 |
+
# Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
|
| 544 |
+
elif _is_torch_greater_or_equal_than_2_6:
|
| 545 |
+
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
|
| 546 |
+
# scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
|
| 547 |
+
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
|
| 548 |
+
with TransformGetItemToIndex():
|
| 549 |
+
attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, q_arange, kv_arange)
|
| 550 |
+
|
| 551 |
+
# Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
|
| 552 |
+
else:
|
| 553 |
+
raise ValueError(
|
| 554 |
+
"The vmap functionality for mask creation is only supported from torch>=2.6. "
|
| 555 |
+
"Please update your torch version or use `use_vmap=False` with index-based masks."
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
|
| 559 |
+
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
|
| 560 |
+
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
|
| 561 |
+
attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
|
| 562 |
+
|
| 563 |
+
return attention_mask
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def eager_mask(
|
| 567 |
+
batch_size: int,
|
| 568 |
+
q_length: int,
|
| 569 |
+
kv_length: int,
|
| 570 |
+
q_offset: int = 0,
|
| 571 |
+
kv_offset: int = 0,
|
| 572 |
+
mask_function: Callable = causal_mask_function,
|
| 573 |
+
attention_mask: torch.Tensor | None = None,
|
| 574 |
+
dtype: torch.dtype = torch.float32,
|
| 575 |
+
allow_is_bidirectional_skip: bool = False,
|
| 576 |
+
use_vmap: bool = False,
|
| 577 |
+
device: torch.device | str = "cpu",
|
| 578 |
+
**kwargs,
|
| 579 |
+
) -> torch.Tensor:
|
| 580 |
+
"""
|
| 581 |
+
Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
|
| 582 |
+
the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
|
| 583 |
+
it should not.
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
batch_size (`int`):
|
| 587 |
+
The batch size of the input sequence.
|
| 588 |
+
q_length (`int`):
|
| 589 |
+
The size that the query states will have during the attention computation.
|
| 590 |
+
kv_length (`int`):
|
| 591 |
+
The size that the key and value states will have during the attention computation.
|
| 592 |
+
q_offset (`int`, optional):
|
| 593 |
+
An optional offset to indicate at which first position the query states will refer to.
|
| 594 |
+
kv_offset (`int`, optional):
|
| 595 |
+
An optional offset to indicate at which first position the key and values states will refer to.
|
| 596 |
+
mask_function (`Callable`):
|
| 597 |
+
The mask factory function describing the mask pattern.
|
| 598 |
+
attention_mask (`torch.Tensor`, optional):
|
| 599 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
| 600 |
+
dtype (`torch.dtype`, optional):
|
| 601 |
+
The dtype to use for the mask. By default, `torch.float32`.
|
| 602 |
+
allow_is_bidirectional_skip (`bool`, optional):
|
| 603 |
+
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
|
| 604 |
+
i.e. full attention without any padding. Default to `False`.
|
| 605 |
+
use_vmap (`bool`, optional):
|
| 606 |
+
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
|
| 607 |
+
index-based (for the cost of speed performance). By default `False`.
|
| 608 |
+
device (`torch.device` or `str`, optional):
|
| 609 |
+
An optional device to create the mask on.
|
| 610 |
+
"""
|
| 611 |
+
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
|
| 612 |
+
_ = kwargs.pop("allow_is_causal_skip", None)
|
| 613 |
+
_ = kwargs.pop("allow_torch_fix", None)
|
| 614 |
+
mask = sdpa_mask(
|
| 615 |
+
batch_size=batch_size,
|
| 616 |
+
q_length=q_length,
|
| 617 |
+
kv_length=kv_length,
|
| 618 |
+
q_offset=q_offset,
|
| 619 |
+
kv_offset=kv_offset,
|
| 620 |
+
mask_function=mask_function,
|
| 621 |
+
attention_mask=attention_mask,
|
| 622 |
+
allow_is_causal_skip=False,
|
| 623 |
+
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
|
| 624 |
+
allow_torch_fix=False,
|
| 625 |
+
use_vmap=use_vmap,
|
| 626 |
+
device=device,
|
| 627 |
+
**kwargs,
|
| 628 |
+
)
|
| 629 |
+
# only bidirectional masks can be skipped, otherwise we convert bool -> float
|
| 630 |
+
if mask is not None:
|
| 631 |
+
min_dtype = torch.finfo(dtype).min
|
| 632 |
+
# we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
|
| 633 |
+
mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
|
| 634 |
+
return mask
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def flash_attention_mask(
|
| 638 |
+
batch_size: int,
|
| 639 |
+
q_length: int,
|
| 640 |
+
kv_length: int,
|
| 641 |
+
q_offset: int = 0,
|
| 642 |
+
kv_offset: int = 0,
|
| 643 |
+
mask_function: Callable = causal_mask_function,
|
| 644 |
+
attention_mask: torch.Tensor | None = None,
|
| 645 |
+
**kwargs,
|
| 646 |
+
):
|
| 647 |
+
"""
|
| 648 |
+
Create the attention mask necessary to use FA2. Since FA2 is un-padded by definition, here we simply return
|
| 649 |
+
`None` if the mask is fully causal, or we return the 2D mask which will then be used to extract the seq_lens.
|
| 650 |
+
We just slice it in case of sliding window.
|
| 651 |
+
|
| 652 |
+
Args:
|
| 653 |
+
batch_size (`int`):
|
| 654 |
+
The batch size of the input sequence.
|
| 655 |
+
q_length (`int`):
|
| 656 |
+
The size that the query states will have during the attention computation.
|
| 657 |
+
kv_length (`int`):
|
| 658 |
+
The size that the key and value states will have during the attention computation.
|
| 659 |
+
q_offset (`int`, optional):
|
| 660 |
+
An optional offset to indicate at which first position the query states will refer to.
|
| 661 |
+
kv_offset (`int`, optional):
|
| 662 |
+
An optional offset to indicate at which first position the key and values states will refer to.
|
| 663 |
+
mask_function (`Callable`):
|
| 664 |
+
The mask factory function describing the mask pattern.
|
| 665 |
+
attention_mask (`torch.Tensor`, optional):
|
| 666 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
| 667 |
+
"""
|
| 668 |
+
if attention_mask is not None:
|
| 669 |
+
# Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing)
|
| 670 |
+
attention_mask = attention_mask[:, -kv_length:]
|
| 671 |
+
# We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2
|
| 672 |
+
# (note that the attention_mask is a boolean dtype here)
|
| 673 |
+
if attention_mask.all():
|
| 674 |
+
attention_mask = None
|
| 675 |
+
|
| 676 |
+
return attention_mask
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
def flex_attention_mask(
|
| 680 |
+
batch_size: int,
|
| 681 |
+
q_length: int,
|
| 682 |
+
kv_length: int,
|
| 683 |
+
q_offset: int = 0,
|
| 684 |
+
kv_offset: int = 0,
|
| 685 |
+
mask_function: Callable = causal_mask_function,
|
| 686 |
+
attention_mask: torch.Tensor | None = None,
|
| 687 |
+
device: torch.device | str = "cpu",
|
| 688 |
+
**kwargs,
|
| 689 |
+
) -> BlockMask:
|
| 690 |
+
"""
|
| 691 |
+
Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential
|
| 692 |
+
for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/
|
| 693 |
+
|
| 694 |
+
Args:
|
| 695 |
+
batch_size (`int`):
|
| 696 |
+
The batch size of the input sequence.
|
| 697 |
+
q_length (`int`):
|
| 698 |
+
The size that the query states will have during the attention computation.
|
| 699 |
+
kv_length (`int`):
|
| 700 |
+
The size that the key and value states will have during the attention computation.
|
| 701 |
+
q_offset (`int`, optional):
|
| 702 |
+
An optional offset to indicate at which first position the query states will refer to.
|
| 703 |
+
kv_offset (`int`, optional):
|
| 704 |
+
An optional offset to indicate at which first position the key and values states will refer to.
|
| 705 |
+
mask_function (`Callable`):
|
| 706 |
+
The mask factory function describing the mask pattern.
|
| 707 |
+
attention_mask (`torch.Tensor`, optional):
|
| 708 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
|
| 709 |
+
device (`torch.device` or `str`, optional):
|
| 710 |
+
An optional device to create the mask on.
|
| 711 |
+
"""
|
| 712 |
+
# Potentially add the padding 2D mask
|
| 713 |
+
if attention_mask is not None:
|
| 714 |
+
# Older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size)
|
| 715 |
+
# Hence we pad to multiples of this as a minimum to ensure this
|
| 716 |
+
pad_len = ((attention_mask.shape[1] // flex_default_block_size) + 1) * flex_default_block_size
|
| 717 |
+
pad_len = pad_len - attention_mask.shape[1]
|
| 718 |
+
if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
|
| 719 |
+
attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
|
| 720 |
+
|
| 721 |
+
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
|
| 722 |
+
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
|
| 723 |
+
|
| 724 |
+
# Add the offsets on top (because flex interface only allows length, not start and end indices)
|
| 725 |
+
mask_function = add_offsets_to_mask_function(mask_function, q_offset, kv_offset)
|
| 726 |
+
|
| 727 |
+
# Finally create the block mask
|
| 728 |
+
block_mask = create_block_mask(
|
| 729 |
+
mask_mod=mask_function,
|
| 730 |
+
B=batch_size,
|
| 731 |
+
H=None,
|
| 732 |
+
Q_LEN=q_length,
|
| 733 |
+
KV_LEN=kv_length,
|
| 734 |
+
device=device,
|
| 735 |
+
_compile=_is_torch_greater_or_equal_than_2_6,
|
| 736 |
+
)
|
| 737 |
+
return block_mask
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
class AttentionMaskInterface(GeneralInterface):
|
| 741 |
+
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
| 742 |
+
# a new instance is created (in order to locally override a given function)
|
| 743 |
+
_global_mapping = {
|
| 744 |
+
"sdpa": sdpa_mask,
|
| 745 |
+
"eager": eager_mask,
|
| 746 |
+
"flash_attention_2": flash_attention_mask,
|
| 747 |
+
"flash_attention_3": flash_attention_mask,
|
| 748 |
+
"flash_attention_4": flash_attention_mask,
|
| 749 |
+
"flex_attention": flex_attention_mask,
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones
|
| 754 |
+
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor | None:
|
| 758 |
+
"""
|
| 759 |
+
Find the indices of the sequence to which each new query token in the sequence belongs when using packed
|
| 760 |
+
tensor format (i.e. several sequences packed in the same batch dimension).
|
| 761 |
+
|
| 762 |
+
Args:
|
| 763 |
+
position_ids (`torch.Tensor`)
|
| 764 |
+
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
|
| 765 |
+
|
| 766 |
+
Returns:
|
| 767 |
+
A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
|
| 768 |
+
pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
|
| 769 |
+
|
| 770 |
+
If the there is only one sequence in each batch item (and we don't compile), then we return `None` indicating
|
| 771 |
+
no packed sequences. This is the same as [[0, 0, 0, 0, 0, 0]] for the example above.
|
| 772 |
+
"""
|
| 773 |
+
# What separate different sequences is when 2 consecutive positions_ids are separated by more than 1. So
|
| 774 |
+
# taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
|
| 775 |
+
# gives exactly the sequence indices
|
| 776 |
+
# Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
|
| 777 |
+
# cannot be part of the end of the first batch dim and the start of the 2nd one for example
|
| 778 |
+
first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
|
| 779 |
+
position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
|
| 780 |
+
packed_sequence_mask = (position_diff != 1).cumsum(-1)
|
| 781 |
+
|
| 782 |
+
# Sadly this is a dynamic control flow, so we cannot enable this check on anything compile related
|
| 783 |
+
if not is_tracing(packed_sequence_mask) and (packed_sequence_mask[:, -1] == 0).all():
|
| 784 |
+
return None
|
| 785 |
+
|
| 786 |
+
return packed_sequence_mask
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
def _preprocess_mask_arguments(
|
| 790 |
+
config: PreTrainedConfig,
|
| 791 |
+
inputs_embeds: torch.Tensor,
|
| 792 |
+
attention_mask: torch.Tensor | BlockMask | None,
|
| 793 |
+
past_key_values: Cache | None,
|
| 794 |
+
position_ids: torch.Tensor | None,
|
| 795 |
+
layer_idx: int | None,
|
| 796 |
+
encoder_hidden_states: torch.Tensor | None = None,
|
| 797 |
+
) -> tuple[bool, torch.Tensor | BlockMask | None, int, int]:
|
| 798 |
+
"""
|
| 799 |
+
Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the
|
| 800 |
+
key-value length and offsets, and if we should early exit or not.
|
| 801 |
+
|
| 802 |
+
Args:
|
| 803 |
+
config (`PreTrainedConfig`):
|
| 804 |
+
The model config.
|
| 805 |
+
inputs_embeds (`torch.Tensor`):
|
| 806 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
| 807 |
+
batch size, query length and dtype.
|
| 808 |
+
attention_mask (`torch.Tensor`, optional):
|
| 809 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
| 810 |
+
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
| 811 |
+
past_key_values (`Cache`, optional):
|
| 812 |
+
The past key values, if we use a cache.
|
| 813 |
+
position_ids (`torch.Tensor`, optional)
|
| 814 |
+
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
|
| 815 |
+
layer_idx (`int`, optional):
|
| 816 |
+
If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
|
| 817 |
+
length and offset. Indeed, for hybrid caches, different layers may return different lengths.
|
| 818 |
+
encoder_hidden_states (`torch.Tensor`, optional):
|
| 819 |
+
The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
|
| 820 |
+
`inputs_embeds` to infer the kv length.
|
| 821 |
+
|
| 822 |
+
Returns:
|
| 823 |
+
early_exit (`bool`):
|
| 824 |
+
Whether we should early exit mask creation, and return the mask as-is.
|
| 825 |
+
attention_mask (`torch.Tensor` or `BlockMask` or `None`):
|
| 826 |
+
The attention mask to either return immediately, or to use in downstream mask creation.
|
| 827 |
+
packed_sequence_mask (`torch.Tensor`, optional):
|
| 828 |
+
In case we detected packed sequence format, this is a tensor where each similar integer indicates that
|
| 829 |
+
the tokens belong to the same sequence.
|
| 830 |
+
q_length (`int`):
|
| 831 |
+
The size that the query states will have during the attention computation.
|
| 832 |
+
kv_length (`int`):
|
| 833 |
+
The size that the key and value states will have during the attention computation.
|
| 834 |
+
q_offset (`int`, optional):
|
| 835 |
+
An optional offset to indicate at which first position the query states will refer to.
|
| 836 |
+
kv_offset (`int`):
|
| 837 |
+
An offset to indicate at which first position the key and values states will refer to.
|
| 838 |
+
"""
|
| 839 |
+
# If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
|
| 840 |
+
if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
|
| 841 |
+
return True, attention_mask, None, None, None, None, None
|
| 842 |
+
|
| 843 |
+
# For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
|
| 844 |
+
# Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
|
| 845 |
+
# full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
|
| 846 |
+
# with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
|
| 847 |
+
# according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
|
| 848 |
+
if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
|
| 849 |
+
return True, None, None, None, None, None, None
|
| 850 |
+
|
| 851 |
+
# Move the mask to correct device, and potentially switch dtype for efficiency
|
| 852 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
| 853 |
+
attention_mask = attention_mask.to(device=inputs_embeds.device, dtype=torch.bool)
|
| 854 |
+
|
| 855 |
+
q_length = inputs_embeds.shape[1]
|
| 856 |
+
# If using a cache, it can give all information about mask sizes based on seen tokens
|
| 857 |
+
if past_key_values is not None:
|
| 858 |
+
q_offset = past_key_values.get_seq_length()
|
| 859 |
+
# To avoid graph breaks, StaticLayer return a tensor instead of int -> this has no impact on the ops, but we
|
| 860 |
+
# need the correct device
|
| 861 |
+
q_offset = q_offset.to(inputs_embeds.device) if isinstance(q_offset, torch.Tensor) else q_offset
|
| 862 |
+
kv_length, kv_offset = past_key_values.get_mask_sizes(q_length, layer_idx)
|
| 863 |
+
# Otherwise, we infer based on our input
|
| 864 |
+
else:
|
| 865 |
+
q_offset = 0
|
| 866 |
+
# 1. Rely on input directly
|
| 867 |
+
if attention_mask is None:
|
| 868 |
+
# For encoder-decoders, use encoder_hidden_states to infer kv_length if provided
|
| 869 |
+
kv_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else q_length
|
| 870 |
+
kv_offset = 0
|
| 871 |
+
# 2. Rely on the mask instead - needed for special cases like prefix tuning in PEFT
|
| 872 |
+
#
|
| 873 |
+
# This is a very unique and special case where an encoder utilizes a cache and expects its length
|
| 874 |
+
# to be accounted for (usually, they should never use a cache). In general, the mask should always
|
| 875 |
+
# match with the input sizes nonetheless (i.e. it does not affect others).
|
| 876 |
+
# Conclusion: "prefix tuning is evil"
|
| 877 |
+
else:
|
| 878 |
+
kv_length, kv_offset = attention_mask.shape[-1], 0
|
| 879 |
+
|
| 880 |
+
# We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
|
| 881 |
+
# and we don't have past_key_values, i.e. generally a training setup)
|
| 882 |
+
packed_sequence_mask = None
|
| 883 |
+
if position_ids is not None and attention_mask is None and past_key_values is None:
|
| 884 |
+
batch_size = inputs_embeds.shape[0]
|
| 885 |
+
# The position ids are sometimes just unsqueezed, without being expanded
|
| 886 |
+
if batch_size != position_ids.shape[0]:
|
| 887 |
+
position_ids = position_ids.expand(batch_size, -1)
|
| 888 |
+
packed_sequence_mask = find_packed_sequence_indices(position_ids)
|
| 889 |
+
|
| 890 |
+
return False, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def create_causal_mask(
|
| 894 |
+
config: PreTrainedConfig,
|
| 895 |
+
inputs_embeds: torch.Tensor,
|
| 896 |
+
attention_mask: torch.Tensor | None,
|
| 897 |
+
past_key_values: Cache | None,
|
| 898 |
+
position_ids: torch.Tensor | None = None,
|
| 899 |
+
or_mask_function: Callable | None = None,
|
| 900 |
+
and_mask_function: Callable | None = None,
|
| 901 |
+
block_sequence_ids: torch.Tensor | None = None,
|
| 902 |
+
) -> torch.Tensor | BlockMask | None:
|
| 903 |
+
"""
|
| 904 |
+
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
| 905 |
+
has an hybrid cache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
| 906 |
+
to what is needed in the `modeling_xxx.py` files).
|
| 907 |
+
|
| 908 |
+
Args:
|
| 909 |
+
config (`PreTrainedConfig`):
|
| 910 |
+
The model config.
|
| 911 |
+
inputs_embeds (`torch.Tensor`):
|
| 912 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
| 913 |
+
batch size, query length and dtype.
|
| 914 |
+
attention_mask (`torch.Tensor`, optional):
|
| 915 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
| 916 |
+
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
| 917 |
+
cache_position (`torch.Tensor`):
|
| 918 |
+
Deprecated and unused.
|
| 919 |
+
past_key_values (`Cache`, optional):
|
| 920 |
+
The past key values, if we use a cache.
|
| 921 |
+
position_ids (`torch.Tensor`, optional)
|
| 922 |
+
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
|
| 923 |
+
or_mask_function (`Callable`, optional):
|
| 924 |
+
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
| 925 |
+
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
| 926 |
+
and_mask_function (`Callable`, optional):
|
| 927 |
+
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
| 928 |
+
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
| 929 |
+
block_sequence_ids (`torch.Tensor`, *optional*):
|
| 930 |
+
A tensor of same shape as input IDs indicating to which block or group each token belongs to. Tokens from
|
| 931 |
+
the same block will keep a bidirectional mask within the block, attending causally to the past. Index `-1`
|
| 932 |
+
can be used for blocks that have to keep complete causality within itself.
|
| 933 |
+
"""
|
| 934 |
+
# Power feature: if `is_causal` is False, then fallback to bi-directional mask for bi-directional attention.
|
| 935 |
+
# It allows to use decoder-only models with bi-directional attention as well
|
| 936 |
+
if not getattr(config, "is_causal", True):
|
| 937 |
+
return create_bidirectional_mask(
|
| 938 |
+
config,
|
| 939 |
+
inputs_embeds,
|
| 940 |
+
attention_mask,
|
| 941 |
+
past_key_values=past_key_values,
|
| 942 |
+
or_mask_function=or_mask_function,
|
| 943 |
+
and_mask_function=and_mask_function,
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
# If we have an hybrid cache structure, here we want to create the mask for the full layers
|
| 947 |
+
if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
|
| 948 |
+
layer_idx = past_key_values.is_sliding.index(False)
|
| 949 |
+
else:
|
| 950 |
+
layer_idx = 0
|
| 951 |
+
|
| 952 |
+
early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
|
| 953 |
+
_preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
|
| 954 |
+
)
|
| 955 |
+
if early_exit:
|
| 956 |
+
return attention_mask
|
| 957 |
+
|
| 958 |
+
batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
|
| 959 |
+
mask_factory_function = causal_mask_function
|
| 960 |
+
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
| 961 |
+
|
| 962 |
+
# Defaulting to using non-vmap based mask creations except when detecting
|
| 963 |
+
# users passing custom mask functions (as we cannot guarantee that they
|
| 964 |
+
# are properly index-based as required by our implementation).
|
| 965 |
+
use_vmap = False
|
| 966 |
+
|
| 967 |
+
# Do not allow skip if we are compiling (this is to match BC)
|
| 968 |
+
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
| 969 |
+
if _is_torch_xpu_available:
|
| 970 |
+
# Do not allow skip if we are compiling for decoding, but for prefill, we still allow skip to optimization the perf of 1st token generation
|
| 971 |
+
allow_is_causal_skip = not (getattr(past_key_values, "is_compileable", False) and q_length == 1)
|
| 972 |
+
else:
|
| 973 |
+
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
| 974 |
+
|
| 975 |
+
# Allow slight deviations from causal mask
|
| 976 |
+
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
|
| 977 |
+
# padding mask, etc) as the resulting mask may otherwise not be correct!
|
| 978 |
+
if or_mask_function is not None:
|
| 979 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 980 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 981 |
+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
| 982 |
+
allow_is_causal_skip = False
|
| 983 |
+
use_vmap = True
|
| 984 |
+
if and_mask_function is not None:
|
| 985 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 986 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 987 |
+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
| 988 |
+
allow_is_causal_skip = False
|
| 989 |
+
use_vmap = True
|
| 990 |
+
|
| 991 |
+
# If we detected packing format or blockwise overlay
|
| 992 |
+
if packed_sequence_mask is not None:
|
| 993 |
+
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
| 994 |
+
allow_is_causal_skip = False
|
| 995 |
+
if block_sequence_ids is not None:
|
| 996 |
+
block_sequence_ids = maybe_pad_block_sequence_ids(block_sequence_ids, attention_mask, kv_length, kv_offset)
|
| 997 |
+
mask_factory_function = or_masks(mask_factory_function, blockwise_overlay(block_sequence_ids))
|
| 998 |
+
allow_is_causal_skip = False
|
| 999 |
+
|
| 1000 |
+
# We now create the mask
|
| 1001 |
+
causal_mask = mask_interface(
|
| 1002 |
+
batch_size=batch_size,
|
| 1003 |
+
q_length=q_length,
|
| 1004 |
+
kv_length=kv_length,
|
| 1005 |
+
q_offset=q_offset,
|
| 1006 |
+
kv_offset=kv_offset,
|
| 1007 |
+
mask_function=mask_factory_function,
|
| 1008 |
+
attention_mask=attention_mask,
|
| 1009 |
+
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
| 1010 |
+
dtype=dtype, # Additional kwarg for eager
|
| 1011 |
+
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
| 1012 |
+
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
| 1013 |
+
device=device,
|
| 1014 |
+
)
|
| 1015 |
+
return causal_mask
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
def create_bidirectional_mask(
|
| 1019 |
+
config: PreTrainedConfig,
|
| 1020 |
+
inputs_embeds: torch.Tensor,
|
| 1021 |
+
attention_mask: torch.Tensor | None,
|
| 1022 |
+
encoder_hidden_states: torch.Tensor | None = None,
|
| 1023 |
+
past_key_values: Cache | None = None,
|
| 1024 |
+
or_mask_function: Callable | None = None,
|
| 1025 |
+
and_mask_function: Callable | None = None,
|
| 1026 |
+
**kwargs,
|
| 1027 |
+
) -> torch.Tensor | BlockMask | None:
|
| 1028 |
+
"""
|
| 1029 |
+
Create a standard bidirectional mask based on the attention implementation used (stored in the config).
|
| 1030 |
+
|
| 1031 |
+
Args:
|
| 1032 |
+
config (`PreTrainedConfig`):
|
| 1033 |
+
The model config.
|
| 1034 |
+
inputs_embeds (`torch.Tensor`):
|
| 1035 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is only used to infer metadata
|
| 1036 |
+
such as the batch size, query length, dtype, and device.
|
| 1037 |
+
attention_mask (`torch.Tensor`, optional):
|
| 1038 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, kv_length).
|
| 1039 |
+
It can also be an already prepared 4D mask of shape (batch_size, 1, query_length, kv_length),
|
| 1040 |
+
in which case it is returned as-is.
|
| 1041 |
+
encoder_hidden_states (`torch.Tensor`, optional):
|
| 1042 |
+
The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
|
| 1043 |
+
`inputs_embeds` to infer the batch size, kv length and dtype.
|
| 1044 |
+
past_key_values (`Cache`, optional):
|
| 1045 |
+
The past key values, if we use a cache.
|
| 1046 |
+
or_mask_function (`Callable`, optional):
|
| 1047 |
+
An optional mask function to combine with the base mask function (by doing the union of both). This is
|
| 1048 |
+
useful to easily overlay another mask on top, for example for image tokens handling.
|
| 1049 |
+
and_mask_function (`Callable`, optional):
|
| 1050 |
+
An optional mask function to combine with the base mask function (by doing the intersection of both). This is
|
| 1051 |
+
useful to easily overlay another mask on top, for example for image tokens handling.
|
| 1052 |
+
"""
|
| 1053 |
+
# We ignore a few irrelevant arguments at the end as we do not have a (growing) cache here
|
| 1054 |
+
early_exit, attention_mask, _, q_length, kv_length, q_offset, kv_offset = _preprocess_mask_arguments(
|
| 1055 |
+
config, inputs_embeds, attention_mask, past_key_values, None, 0, encoder_hidden_states
|
| 1056 |
+
)
|
| 1057 |
+
if early_exit:
|
| 1058 |
+
return attention_mask
|
| 1059 |
+
|
| 1060 |
+
embeds = encoder_hidden_states if encoder_hidden_states is not None else inputs_embeds
|
| 1061 |
+
batch_size, dtype, device = embeds.shape[0], embeds.dtype, embeds.device
|
| 1062 |
+
mask_factory_function = bidirectional_mask_function
|
| 1063 |
+
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
| 1064 |
+
|
| 1065 |
+
# Allow skipping the mask creation except we have additional masking operators (and/or masks)
|
| 1066 |
+
allow_is_bidirectional_skip = True
|
| 1067 |
+
# Defaulting to using non-vmap based mask creations except when detecting
|
| 1068 |
+
# users passing custom mask functions (as we cannot guarantee that they
|
| 1069 |
+
# are properly index-based as required by our implementation).
|
| 1070 |
+
use_vmap = False
|
| 1071 |
+
|
| 1072 |
+
# Allow slight deviations from the base mask
|
| 1073 |
+
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
|
| 1074 |
+
# padding mask, etc) as the resulting mask may otherwise not be correct!
|
| 1075 |
+
if or_mask_function is not None:
|
| 1076 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1077 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1078 |
+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
| 1079 |
+
allow_is_bidirectional_skip = False
|
| 1080 |
+
use_vmap = True
|
| 1081 |
+
if and_mask_function is not None:
|
| 1082 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1083 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1084 |
+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
| 1085 |
+
allow_is_bidirectional_skip = False
|
| 1086 |
+
use_vmap = True
|
| 1087 |
+
|
| 1088 |
+
# We now create the mask
|
| 1089 |
+
attention_mask = mask_interface(
|
| 1090 |
+
batch_size=batch_size,
|
| 1091 |
+
q_length=q_length,
|
| 1092 |
+
kv_length=kv_length,
|
| 1093 |
+
q_offset=q_offset,
|
| 1094 |
+
kv_offset=kv_offset,
|
| 1095 |
+
mask_function=mask_factory_function,
|
| 1096 |
+
attention_mask=attention_mask,
|
| 1097 |
+
# Additional kwargs for sdpa
|
| 1098 |
+
allow_is_causal_skip=False,
|
| 1099 |
+
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
|
| 1100 |
+
dtype=dtype, # Additional kwarg for eager
|
| 1101 |
+
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
| 1102 |
+
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
| 1103 |
+
device=device,
|
| 1104 |
+
)
|
| 1105 |
+
return attention_mask
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+
def create_sliding_window_causal_mask(
|
| 1109 |
+
config: PreTrainedConfig,
|
| 1110 |
+
inputs_embeds: torch.Tensor,
|
| 1111 |
+
attention_mask: torch.Tensor | None,
|
| 1112 |
+
past_key_values: Cache | None,
|
| 1113 |
+
position_ids: torch.Tensor | None = None,
|
| 1114 |
+
or_mask_function: Callable | None = None,
|
| 1115 |
+
and_mask_function: Callable | None = None,
|
| 1116 |
+
block_sequence_ids: torch.Tensor | None = None,
|
| 1117 |
+
) -> torch.Tensor | BlockMask | None:
|
| 1118 |
+
"""
|
| 1119 |
+
Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
|
| 1120 |
+
of attention pattern was mostly democratized by Mistral. If `past_key_values` has an hybrid cache structure, this
|
| 1121 |
+
function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the
|
| 1122 |
+
`modeling_xxx.py` files).
|
| 1123 |
+
|
| 1124 |
+
Args:
|
| 1125 |
+
config (`PreTrainedConfig`):
|
| 1126 |
+
The model config.
|
| 1127 |
+
inputs_embeds (`torch.Tensor`):
|
| 1128 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
| 1129 |
+
batch size, query length and dtype.
|
| 1130 |
+
attention_mask (`torch.Tensor`, optional):
|
| 1131 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
| 1132 |
+
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
| 1133 |
+
cache_position (`torch.Tensor`):
|
| 1134 |
+
Deprecated and unused.
|
| 1135 |
+
past_key_values (`Cache`, optional):
|
| 1136 |
+
The past key values, if we use a cache.
|
| 1137 |
+
position_ids (`torch.Tensor`, optional)
|
| 1138 |
+
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
|
| 1139 |
+
or_mask_function (`Callable`, optional):
|
| 1140 |
+
An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
|
| 1141 |
+
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
|
| 1142 |
+
and_mask_function (`Callable`, optional):
|
| 1143 |
+
An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
|
| 1144 |
+
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
|
| 1145 |
+
block_sequence_ids (`torch.Tensor`, *optional*):
|
| 1146 |
+
A tensor of same shape as input IDs indicating to which block or group each token belongs to. Tokens from
|
| 1147 |
+
the same block will keep a bidirectional mask within the block, attending causally to the past. Index `-1`
|
| 1148 |
+
can be used for blocks that have to keep complete causality within itself.
|
| 1149 |
+
"""
|
| 1150 |
+
# Power feature: if `is_causal` is False, then fallback to bi-directional mask for bi-directional attention
|
| 1151 |
+
# It allows to use decoder-only models with bi-directional attention as well
|
| 1152 |
+
if not getattr(config, "is_causal", True):
|
| 1153 |
+
return create_bidirectional_sliding_window_mask(
|
| 1154 |
+
config,
|
| 1155 |
+
inputs_embeds,
|
| 1156 |
+
attention_mask,
|
| 1157 |
+
past_key_values=past_key_values,
|
| 1158 |
+
or_mask_function=or_mask_function,
|
| 1159 |
+
and_mask_function=and_mask_function,
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
# If we have an hybrid cache structure, here we want to create the mask for the sliding layers
|
| 1163 |
+
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
|
| 1164 |
+
layer_idx = past_key_values.is_sliding.index(True)
|
| 1165 |
+
else:
|
| 1166 |
+
layer_idx = 0
|
| 1167 |
+
|
| 1168 |
+
early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
|
| 1169 |
+
_preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
|
| 1170 |
+
)
|
| 1171 |
+
if early_exit:
|
| 1172 |
+
return attention_mask
|
| 1173 |
+
|
| 1174 |
+
sliding_window = getattr(config, "sliding_window", None)
|
| 1175 |
+
if sliding_window is None:
|
| 1176 |
+
raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")
|
| 1177 |
+
|
| 1178 |
+
batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
|
| 1179 |
+
mask_factory_function = sliding_window_causal_mask_function(sliding_window)
|
| 1180 |
+
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
| 1181 |
+
|
| 1182 |
+
# Defaulting to using non-vmap based mask creations except when detecting
|
| 1183 |
+
# users passing custom mask functions (as we cannot guarantee that they
|
| 1184 |
+
# are properly index-based as required by our implementation).
|
| 1185 |
+
use_vmap = False
|
| 1186 |
+
# Do not allow skip if we are compiling (this is to match BC)
|
| 1187 |
+
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
| 1188 |
+
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
| 1189 |
+
|
| 1190 |
+
# Allow slight deviations from causal mask
|
| 1191 |
+
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
|
| 1192 |
+
# padding mask, etc) as the resulting mask may otherwise not be correct!
|
| 1193 |
+
if or_mask_function is not None:
|
| 1194 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1195 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1196 |
+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
| 1197 |
+
allow_is_causal_skip = False
|
| 1198 |
+
use_vmap = True
|
| 1199 |
+
if and_mask_function is not None:
|
| 1200 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1201 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1202 |
+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
| 1203 |
+
allow_is_causal_skip = False
|
| 1204 |
+
use_vmap = True
|
| 1205 |
+
|
| 1206 |
+
# If we detected packing format or blockwise overlay
|
| 1207 |
+
if packed_sequence_mask is not None:
|
| 1208 |
+
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
| 1209 |
+
allow_is_causal_skip = False
|
| 1210 |
+
if block_sequence_ids is not None:
|
| 1211 |
+
block_sequence_ids = maybe_pad_block_sequence_ids(block_sequence_ids, attention_mask, kv_length, kv_offset)
|
| 1212 |
+
mask_factory_function = or_masks(mask_factory_function, blockwise_overlay(block_sequence_ids))
|
| 1213 |
+
allow_is_causal_skip = False
|
| 1214 |
+
|
| 1215 |
+
# We now create the mask
|
| 1216 |
+
causal_mask = mask_interface(
|
| 1217 |
+
batch_size=batch_size,
|
| 1218 |
+
q_length=q_length,
|
| 1219 |
+
kv_length=kv_length,
|
| 1220 |
+
q_offset=q_offset,
|
| 1221 |
+
kv_offset=kv_offset,
|
| 1222 |
+
mask_function=mask_factory_function,
|
| 1223 |
+
attention_mask=attention_mask,
|
| 1224 |
+
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
| 1225 |
+
local_size=sliding_window, # Additional kwarg for sdpa
|
| 1226 |
+
dtype=dtype, # Additional kwarg for eager
|
| 1227 |
+
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
| 1228 |
+
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
| 1229 |
+
device=device,
|
| 1230 |
+
)
|
| 1231 |
+
return causal_mask
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
def create_bidirectional_sliding_window_mask(
|
| 1235 |
+
config: PreTrainedConfig,
|
| 1236 |
+
inputs_embeds: torch.Tensor,
|
| 1237 |
+
attention_mask: torch.Tensor | None,
|
| 1238 |
+
encoder_hidden_states: torch.Tensor | None = None,
|
| 1239 |
+
past_key_values: Cache | None = None,
|
| 1240 |
+
or_mask_function: Callable | None = None,
|
| 1241 |
+
and_mask_function: Callable | None = None,
|
| 1242 |
+
**kwargs,
|
| 1243 |
+
) -> torch.Tensor | BlockMask | None:
|
| 1244 |
+
"""
|
| 1245 |
+
Create a standard bidirectional sliding window mask based on the attention implementation used (stored in the config).
|
| 1246 |
+
|
| 1247 |
+
Args:
|
| 1248 |
+
config (`PreTrainedConfig`):
|
| 1249 |
+
The model config.
|
| 1250 |
+
inputs_embeds (`torch.Tensor`):
|
| 1251 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is only used to infer metadata
|
| 1252 |
+
such as the batch size, query length, dtype, and device.
|
| 1253 |
+
attention_mask (`torch.Tensor`, optional):
|
| 1254 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, kv_length).
|
| 1255 |
+
It can also be an already prepared 4D mask of shape (batch_size, 1, query_length, kv_length),
|
| 1256 |
+
in which case it is returned as-is.
|
| 1257 |
+
encoder_hidden_states (`torch.Tensor`, optional):
|
| 1258 |
+
The input embeddings of shape (batch_size, kv_length, hidden_dim). If provided, it is used instead of
|
| 1259 |
+
`inputs_embeds` to infer the batch size, kv length and dtype.
|
| 1260 |
+
past_key_values (`Cache`, optional):
|
| 1261 |
+
The past key values, if we use a cache.
|
| 1262 |
+
or_mask_function (`Callable`, optional):
|
| 1263 |
+
An optional mask function to combine with the base mask function (by doing the union of both). This is
|
| 1264 |
+
useful to easily overlay another mask on top, for example for image tokens handling.
|
| 1265 |
+
and_mask_function (`Callable`, optional):
|
| 1266 |
+
An optional mask function to combine with the base mask function (by doing the intersection of both). This is
|
| 1267 |
+
useful to easily overlay another mask on top, for example for image tokens handling.
|
| 1268 |
+
"""
|
| 1269 |
+
# We ignore a few irrelevant arguments at the end as we do not have a (growing) cache here
|
| 1270 |
+
early_exit, attention_mask, _, q_length, kv_length, q_offset, kv_offset = _preprocess_mask_arguments(
|
| 1271 |
+
config, inputs_embeds, attention_mask, past_key_values, None, 0, encoder_hidden_states
|
| 1272 |
+
)
|
| 1273 |
+
if early_exit:
|
| 1274 |
+
return attention_mask
|
| 1275 |
+
|
| 1276 |
+
sliding_window = getattr(config, "sliding_window", None)
|
| 1277 |
+
if sliding_window is None:
|
| 1278 |
+
raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")
|
| 1279 |
+
|
| 1280 |
+
embeds = encoder_hidden_states if encoder_hidden_states is not None else inputs_embeds
|
| 1281 |
+
batch_size, dtype, device = embeds.shape[0], embeds.dtype, embeds.device
|
| 1282 |
+
mask_factory_function = sliding_window_bidirectional_mask_function(sliding_window)
|
| 1283 |
+
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
| 1284 |
+
|
| 1285 |
+
use_vmap = False
|
| 1286 |
+
allow_is_bidirectional_skip = True
|
| 1287 |
+
|
| 1288 |
+
if or_mask_function is not None:
|
| 1289 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1290 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1291 |
+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
| 1292 |
+
allow_is_bidirectional_skip = False
|
| 1293 |
+
use_vmap = True
|
| 1294 |
+
if and_mask_function is not None:
|
| 1295 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1296 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1297 |
+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
| 1298 |
+
allow_is_bidirectional_skip = False
|
| 1299 |
+
use_vmap = True
|
| 1300 |
+
|
| 1301 |
+
attention_mask = mask_interface(
|
| 1302 |
+
batch_size=batch_size,
|
| 1303 |
+
q_length=q_length,
|
| 1304 |
+
kv_length=kv_length,
|
| 1305 |
+
q_offset=q_offset,
|
| 1306 |
+
kv_offset=kv_offset,
|
| 1307 |
+
mask_function=mask_factory_function,
|
| 1308 |
+
attention_mask=attention_mask,
|
| 1309 |
+
allow_is_causal_skip=False,
|
| 1310 |
+
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
|
| 1311 |
+
local_size=sliding_window, # Additional kwarg for sdpa
|
| 1312 |
+
dtype=dtype, # Additional kwarg for eager
|
| 1313 |
+
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
| 1314 |
+
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
| 1315 |
+
device=device,
|
| 1316 |
+
)
|
| 1317 |
+
return attention_mask
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
def create_chunked_causal_mask(
|
| 1321 |
+
config: PreTrainedConfig,
|
| 1322 |
+
inputs_embeds: torch.Tensor,
|
| 1323 |
+
attention_mask: torch.Tensor | None,
|
| 1324 |
+
past_key_values: Cache | None,
|
| 1325 |
+
position_ids: torch.Tensor | None = None,
|
| 1326 |
+
or_mask_function: Callable | None = None,
|
| 1327 |
+
and_mask_function: Callable | None = None,
|
| 1328 |
+
) -> torch.Tensor | BlockMask | None:
|
| 1329 |
+
"""
|
| 1330 |
+
Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
|
| 1331 |
+
of attention pattern was mostly democratized by Llama4. If `past_key_values` has an hybrid cache structure, this
|
| 1332 |
+
function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
|
| 1333 |
+
`modeling_xxx.py` files).
|
| 1334 |
+
|
| 1335 |
+
Args:
|
| 1336 |
+
config (`PreTrainedConfig`):
|
| 1337 |
+
The model config.
|
| 1338 |
+
inputs_embeds (`torch.Tensor`):
|
| 1339 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
| 1340 |
+
batch size, query length and dtype.
|
| 1341 |
+
attention_mask (`torch.Tensor`, optional):
|
| 1342 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
| 1343 |
+
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
| 1344 |
+
cache_position (`torch.Tensor`):
|
| 1345 |
+
Deprecated and unused.
|
| 1346 |
+
past_key_values (`Cache`, optional):
|
| 1347 |
+
The past key values, if we use a cache.
|
| 1348 |
+
position_ids (`torch.Tensor`, optional)
|
| 1349 |
+
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
|
| 1350 |
+
or_mask_function (`Callable`, optional):
|
| 1351 |
+
An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
|
| 1352 |
+
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
|
| 1353 |
+
and_mask_function (`Callable`, optional):
|
| 1354 |
+
An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
|
| 1355 |
+
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
|
| 1356 |
+
"""
|
| 1357 |
+
# If we have an hybrid cache structure, here we want to create the mask for the sliding layers
|
| 1358 |
+
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
|
| 1359 |
+
layer_idx = past_key_values.is_sliding.index(True)
|
| 1360 |
+
else:
|
| 1361 |
+
layer_idx = 0
|
| 1362 |
+
|
| 1363 |
+
early_exit, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset = (
|
| 1364 |
+
_preprocess_mask_arguments(config, inputs_embeds, attention_mask, past_key_values, position_ids, layer_idx)
|
| 1365 |
+
)
|
| 1366 |
+
if early_exit:
|
| 1367 |
+
return attention_mask
|
| 1368 |
+
|
| 1369 |
+
chunk_size = getattr(config, "attention_chunk_size", None)
|
| 1370 |
+
if chunk_size is None:
|
| 1371 |
+
raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")
|
| 1372 |
+
|
| 1373 |
+
# Raise if using chunked attention on context too large with FA
|
| 1374 |
+
if is_flash_attention_requested(config) and kv_length + kv_offset > chunk_size:
|
| 1375 |
+
raise ValueError(
|
| 1376 |
+
"Flash attention cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
|
| 1377 |
+
"chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
|
| 1378 |
+
)
|
| 1379 |
+
|
| 1380 |
+
batch_size, dtype, device = inputs_embeds.shape[0], inputs_embeds.dtype, inputs_embeds.device
|
| 1381 |
+
# For chunked attention and batched inputs, we need to take the number of left padding tokens into account
|
| 1382 |
+
# to start the chunk from the actual start of the sequence for the padded sequence
|
| 1383 |
+
if attention_mask is not None:
|
| 1384 |
+
# Only count the left padding tokens, not all of them
|
| 1385 |
+
left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
|
| 1386 |
+
else:
|
| 1387 |
+
left_padding_tokens = torch.zeros(batch_size, device=device, dtype=int)
|
| 1388 |
+
mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
|
| 1389 |
+
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
| 1390 |
+
|
| 1391 |
+
# Defaulting to using non-vmap based mask creations except when detecting
|
| 1392 |
+
# users passing custom mask functions (as we cannot guarantee that they
|
| 1393 |
+
# are properly index-based as required by our implementation).
|
| 1394 |
+
use_vmap = False
|
| 1395 |
+
# Do not allow skip if we are compiling (this is to match BC)
|
| 1396 |
+
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
|
| 1397 |
+
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
|
| 1398 |
+
|
| 1399 |
+
# Allow slight deviations from causal mask
|
| 1400 |
+
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
|
| 1401 |
+
# padding mask, etc) as the resulting mask may otherwise not be correct!
|
| 1402 |
+
if or_mask_function is not None:
|
| 1403 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1404 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1405 |
+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
| 1406 |
+
allow_is_causal_skip = False
|
| 1407 |
+
use_vmap = True
|
| 1408 |
+
if and_mask_function is not None:
|
| 1409 |
+
if not _is_torch_greater_or_equal_than_2_6:
|
| 1410 |
+
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
|
| 1411 |
+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
| 1412 |
+
allow_is_causal_skip = False
|
| 1413 |
+
use_vmap = True
|
| 1414 |
+
|
| 1415 |
+
# If we detected packing format
|
| 1416 |
+
if packed_sequence_mask is not None:
|
| 1417 |
+
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
|
| 1418 |
+
allow_is_causal_skip = False
|
| 1419 |
+
|
| 1420 |
+
# We now create the mask
|
| 1421 |
+
causal_mask = mask_interface(
|
| 1422 |
+
batch_size=batch_size,
|
| 1423 |
+
q_length=q_length,
|
| 1424 |
+
kv_length=kv_length,
|
| 1425 |
+
q_offset=q_offset,
|
| 1426 |
+
kv_offset=kv_offset,
|
| 1427 |
+
mask_function=mask_factory_function,
|
| 1428 |
+
attention_mask=attention_mask,
|
| 1429 |
+
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
| 1430 |
+
local_size=chunk_size, # Additional kwarg for sdpa
|
| 1431 |
+
dtype=dtype, # Additional kwarg for eager
|
| 1432 |
+
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
| 1433 |
+
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
|
| 1434 |
+
device=device,
|
| 1435 |
+
)
|
| 1436 |
+
return causal_mask
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
|
| 1440 |
+
"full_attention": create_causal_mask,
|
| 1441 |
+
"sliding_attention": create_sliding_window_causal_mask,
|
| 1442 |
+
"chunked_attention": create_chunked_causal_mask,
|
| 1443 |
+
"compressed_sparse_attention": create_sliding_window_causal_mask,
|
| 1444 |
+
"heavily_compressed_attention": create_sliding_window_causal_mask,
|
| 1445 |
+
}
|
| 1446 |
+
|
| 1447 |
+
|
| 1448 |
+
def create_masks_for_generate(
|
| 1449 |
+
config: PreTrainedConfig,
|
| 1450 |
+
inputs_embeds: torch.Tensor,
|
| 1451 |
+
attention_mask: torch.Tensor | None,
|
| 1452 |
+
past_key_values: Cache | None,
|
| 1453 |
+
position_ids: torch.Tensor | None = None,
|
| 1454 |
+
or_mask_function: Callable | None = None,
|
| 1455 |
+
and_mask_function: Callable | None = None,
|
| 1456 |
+
block_sequence_ids: torch.Tensor | None = None,
|
| 1457 |
+
**kwargs,
|
| 1458 |
+
):
|
| 1459 |
+
"""
|
| 1460 |
+
This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in places like `generate`
|
| 1461 |
+
in order to easily create the masks in advance, when we compile the forwards with Static caches.
|
| 1462 |
+
|
| 1463 |
+
Args:
|
| 1464 |
+
config (`PreTrainedConfig`):
|
| 1465 |
+
The model config.
|
| 1466 |
+
inputs_embeds (`torch.Tensor`):
|
| 1467 |
+
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
| 1468 |
+
batch size, query length and dtype.
|
| 1469 |
+
attention_mask (`torch.Tensor`, optional):
|
| 1470 |
+
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
| 1471 |
+
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
| 1472 |
+
past_key_values (`Cache`, optional):
|
| 1473 |
+
The past key values, if we use a cache.
|
| 1474 |
+
position_ids (`torch.Tensor`, optional)
|
| 1475 |
+
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
|
| 1476 |
+
or_mask_function (`Callable`, optional):
|
| 1477 |
+
An optional mask function to combine with the other mask function (by doing the union of both). This is
|
| 1478 |
+
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
| 1479 |
+
and_mask_function (`Callable`, optional):
|
| 1480 |
+
An optional mask function to combine with the other mask function (by doing the intersection of both). This is
|
| 1481 |
+
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
| 1482 |
+
block_sequence_ids (`torch.Tensor`, *optional*):
|
| 1483 |
+
A tensor of same shape as input IDs indicating to which block or group each token belongs to. Tokens from
|
| 1484 |
+
the same block will keep a bidirectional mask within the block, attending causally to the past. Index `-1`
|
| 1485 |
+
can be used for blocks that have to keep complete causality within itself.
|
| 1486 |
+
"""
|
| 1487 |
+
# The attribute reside in the text config for composite models
|
| 1488 |
+
effective_config = config.get_text_config()
|
| 1489 |
+
# Prepare the mask args
|
| 1490 |
+
mask_kwargs = {
|
| 1491 |
+
"config": effective_config,
|
| 1492 |
+
"inputs_embeds": inputs_embeds,
|
| 1493 |
+
"attention_mask": attention_mask,
|
| 1494 |
+
"past_key_values": past_key_values,
|
| 1495 |
+
"position_ids": position_ids,
|
| 1496 |
+
"or_mask_function": or_mask_function,
|
| 1497 |
+
"and_mask_function": and_mask_function,
|
| 1498 |
+
"block_sequence_ids": block_sequence_ids,
|
| 1499 |
+
}
|
| 1500 |
+
|
| 1501 |
+
# If the attribute exist, we need several masks
|
| 1502 |
+
if hasattr(effective_config, "layer_types"):
|
| 1503 |
+
causal_masks = {}
|
| 1504 |
+
for layer_pattern in set(effective_config.layer_types):
|
| 1505 |
+
causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)
|
| 1506 |
+
return causal_masks
|
| 1507 |
+
# In this case, all layers are sliding
|
| 1508 |
+
elif getattr(effective_config, "sliding_window", None) is not None:
|
| 1509 |
+
return create_sliding_window_causal_mask(**mask_kwargs)
|
| 1510 |
+
# In this case, all layers are chunked
|
| 1511 |
+
elif getattr(effective_config, "attention_chunk_size", None) is not None:
|
| 1512 |
+
return create_chunked_causal_mask(**mask_kwargs)
|
| 1513 |
+
# All layers use standard causal attention
|
| 1514 |
+
return create_causal_mask(**mask_kwargs)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
|
| 16 |
+
`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
|
| 17 |
+
and will be removed in the future.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import warnings
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import Union
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
from .utils.import_utils import is_torchdynamo_compiling, is_tracing
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
DEPRECATION_MESSAGE = (
|
| 30 |
+
"The attention mask API under `transformers.modeling_attn_mask_utils` (`AttentionMaskConverter`) "
|
| 31 |
+
"is deprecated and will be removed in Transformers v5.10. Please use the new API in `transformers.masking_utils`."
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class AttentionMaskConverter:
|
| 37 |
+
"""
|
| 38 |
+
A utility attention mask class that allows one to:
|
| 39 |
+
- Create a causal 4d mask
|
| 40 |
+
- Create a causal 4d mask with slided window
|
| 41 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
| 42 |
+
key_value_length) that can be multiplied with attention scores
|
| 43 |
+
|
| 44 |
+
Examples:
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
>>> import torch
|
| 48 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 49 |
+
|
| 50 |
+
>>> converter = AttentionMaskConverter(True)
|
| 51 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
| 52 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 53 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 54 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
| 55 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
| 56 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Parameters:
|
| 60 |
+
is_causal (`bool`):
|
| 61 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
| 62 |
+
|
| 63 |
+
sliding_window (`int`, *optional*):
|
| 64 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
is_causal: bool
|
| 68 |
+
sliding_window: int
|
| 69 |
+
|
| 70 |
+
def __init__(self, is_causal: bool, sliding_window: int | None = None):
|
| 71 |
+
warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
|
| 72 |
+
|
| 73 |
+
self.is_causal = is_causal
|
| 74 |
+
self.sliding_window = sliding_window
|
| 75 |
+
|
| 76 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def to_causal_4d(
|
| 82 |
+
self,
|
| 83 |
+
batch_size: int,
|
| 84 |
+
query_length: int,
|
| 85 |
+
key_value_length: int,
|
| 86 |
+
dtype: torch.dtype,
|
| 87 |
+
device: Union[torch.device, "str"] = "cpu",
|
| 88 |
+
) -> torch.Tensor | None:
|
| 89 |
+
"""
|
| 90 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
| 91 |
+
bias to upper right hand triangular matrix (causal mask).
|
| 92 |
+
"""
|
| 93 |
+
if not self.is_causal:
|
| 94 |
+
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
|
| 95 |
+
|
| 96 |
+
# If shape is not cached, create a new causal mask and cache it
|
| 97 |
+
input_shape = (batch_size, query_length)
|
| 98 |
+
past_key_values_length = key_value_length - query_length
|
| 99 |
+
|
| 100 |
+
# create causal mask
|
| 101 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 102 |
+
causal_4d_mask = None
|
| 103 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
| 104 |
+
causal_4d_mask = self._make_causal_mask(
|
| 105 |
+
input_shape,
|
| 106 |
+
dtype,
|
| 107 |
+
device=device,
|
| 108 |
+
past_key_values_length=past_key_values_length,
|
| 109 |
+
sliding_window=self.sliding_window,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
return causal_4d_mask
|
| 113 |
+
|
| 114 |
+
def to_4d(
|
| 115 |
+
self,
|
| 116 |
+
attention_mask_2d: torch.Tensor,
|
| 117 |
+
query_length: int,
|
| 118 |
+
dtype: torch.dtype,
|
| 119 |
+
key_value_length: int | None = None,
|
| 120 |
+
) -> torch.Tensor:
|
| 121 |
+
"""
|
| 122 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
| 123 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
| 124 |
+
causal, a causal mask will be added.
|
| 125 |
+
"""
|
| 126 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
| 127 |
+
|
| 128 |
+
# create causal mask
|
| 129 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 130 |
+
causal_4d_mask = None
|
| 131 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
| 132 |
+
if key_value_length is None:
|
| 133 |
+
raise ValueError(
|
| 134 |
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
past_key_values_length = key_value_length - query_length
|
| 138 |
+
causal_4d_mask = self._make_causal_mask(
|
| 139 |
+
input_shape,
|
| 140 |
+
dtype,
|
| 141 |
+
device=attention_mask_2d.device,
|
| 142 |
+
past_key_values_length=past_key_values_length,
|
| 143 |
+
sliding_window=self.sliding_window,
|
| 144 |
+
)
|
| 145 |
+
elif self.sliding_window is not None:
|
| 146 |
+
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
|
| 147 |
+
|
| 148 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 149 |
+
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
|
| 150 |
+
attention_mask_2d.device
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if causal_4d_mask is not None:
|
| 154 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
|
| 155 |
+
|
| 156 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
| 157 |
+
expanded_4d_mask = expanded_attn_mask
|
| 158 |
+
|
| 159 |
+
return expanded_4d_mask
|
| 160 |
+
|
| 161 |
+
@staticmethod
|
| 162 |
+
def _make_causal_mask(
|
| 163 |
+
input_ids_shape: torch.Size,
|
| 164 |
+
dtype: torch.dtype,
|
| 165 |
+
device: torch.device,
|
| 166 |
+
past_key_values_length: int = 0,
|
| 167 |
+
sliding_window: int | None = None,
|
| 168 |
+
):
|
| 169 |
+
"""
|
| 170 |
+
Make causal mask used for bi-directional self-attention.
|
| 171 |
+
"""
|
| 172 |
+
warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
|
| 173 |
+
|
| 174 |
+
bsz, tgt_len = input_ids_shape
|
| 175 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 176 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 177 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 178 |
+
|
| 179 |
+
mask = mask.to(dtype)
|
| 180 |
+
|
| 181 |
+
if past_key_values_length > 0:
|
| 182 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
| 183 |
+
|
| 184 |
+
# add lower triangular sliding window mask if necessary
|
| 185 |
+
if sliding_window is not None:
|
| 186 |
+
diagonal = past_key_values_length - sliding_window - 1
|
| 187 |
+
|
| 188 |
+
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
|
| 189 |
+
# Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
|
| 190 |
+
# See https://github.com/pytorch/pytorch/issues/127571
|
| 191 |
+
if is_torchdynamo_compiling():
|
| 192 |
+
mask = mask.clone()
|
| 193 |
+
mask.masked_fill_(context_mask, torch.finfo(dtype).min)
|
| 194 |
+
|
| 195 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
|
| 199 |
+
"""
|
| 200 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 201 |
+
"""
|
| 202 |
+
warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
|
| 203 |
+
|
| 204 |
+
bsz, src_len = mask.size()
|
| 205 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 206 |
+
|
| 207 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 208 |
+
|
| 209 |
+
inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask
|
| 210 |
+
|
| 211 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
| 212 |
+
|
| 213 |
+
@staticmethod
|
| 214 |
+
def _unmask_unattended(
|
| 215 |
+
expanded_mask: torch.FloatTensor,
|
| 216 |
+
min_dtype: float,
|
| 217 |
+
):
|
| 218 |
+
# fmt: off
|
| 219 |
+
"""
|
| 220 |
+
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
|
| 221 |
+
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 222 |
+
Details: https://github.com/pytorch/pytorch/issues/110213
|
| 223 |
+
|
| 224 |
+
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
|
| 225 |
+
`attention_mask` is [bsz, src_seq_len].
|
| 226 |
+
|
| 227 |
+
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
|
| 228 |
+
|
| 229 |
+
For example, if `expanded_mask` is (e.g. here left-padding case)
|
| 230 |
+
```
|
| 231 |
+
[[[[0, 0, 0],
|
| 232 |
+
[0, 0, 0],
|
| 233 |
+
[0, 0, 1]]],
|
| 234 |
+
[[[1, 0, 0],
|
| 235 |
+
[1, 1, 0],
|
| 236 |
+
[1, 1, 1]]],
|
| 237 |
+
[[[0, 0, 0],
|
| 238 |
+
[0, 1, 0],
|
| 239 |
+
[0, 1, 1]]]]
|
| 240 |
+
```
|
| 241 |
+
then the modified `expanded_mask` will be
|
| 242 |
+
```
|
| 243 |
+
[[[[1, 1, 1], <-- modified
|
| 244 |
+
[1, 1, 1], <-- modified
|
| 245 |
+
[0, 0, 1]]],
|
| 246 |
+
[[[1, 0, 0],
|
| 247 |
+
[1, 1, 0],
|
| 248 |
+
[1, 1, 1]]],
|
| 249 |
+
[[[1, 1, 1], <-- modified
|
| 250 |
+
[0, 1, 0],
|
| 251 |
+
[0, 1, 1]]]]
|
| 252 |
+
```
|
| 253 |
+
"""
|
| 254 |
+
warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
|
| 255 |
+
|
| 256 |
+
# fmt: on
|
| 257 |
+
if expanded_mask.dtype == torch.bool:
|
| 258 |
+
raise ValueError(
|
| 259 |
+
"AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
|
| 263 |
+
|
| 264 |
+
@staticmethod
|
| 265 |
+
def _ignore_causal_mask_sdpa(
|
| 266 |
+
attention_mask: torch.Tensor | None,
|
| 267 |
+
inputs_embeds: torch.Tensor,
|
| 268 |
+
past_key_values_length: int,
|
| 269 |
+
sliding_window: int | None = None,
|
| 270 |
+
is_training: bool = False,
|
| 271 |
+
) -> bool:
|
| 272 |
+
"""
|
| 273 |
+
Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
|
| 274 |
+
ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
|
| 275 |
+
|
| 276 |
+
In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
|
| 277 |
+
`key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
|
| 278 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
|
| 279 |
+
passed).
|
| 280 |
+
"""
|
| 281 |
+
warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
|
| 282 |
+
|
| 283 |
+
_, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
|
| 284 |
+
key_value_length = query_length + past_key_values_length
|
| 285 |
+
|
| 286 |
+
is_tracing_ = is_tracing(inputs_embeds)
|
| 287 |
+
|
| 288 |
+
ignore_causal_mask = False
|
| 289 |
+
|
| 290 |
+
if attention_mask is None:
|
| 291 |
+
# TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
|
| 292 |
+
# shape, thus SDPA's `is_causal` argument is rightfully updated
|
| 293 |
+
# (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
|
| 294 |
+
# `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
|
| 295 |
+
# hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
|
| 296 |
+
# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
|
| 297 |
+
# Thus, we only set `ignore_causal_mask = True` if the model is set to training.
|
| 298 |
+
#
|
| 299 |
+
# Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
|
| 300 |
+
# ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
|
| 301 |
+
if (
|
| 302 |
+
(is_training or not is_tracing_)
|
| 303 |
+
and (query_length == 1 or key_value_length == query_length)
|
| 304 |
+
and (sliding_window is None or key_value_length < sliding_window)
|
| 305 |
+
):
|
| 306 |
+
ignore_causal_mask = True
|
| 307 |
+
elif sliding_window is None or key_value_length < sliding_window:
|
| 308 |
+
if len(attention_mask.shape) == 4:
|
| 309 |
+
return False
|
| 310 |
+
elif not is_tracing_ and torch.all(attention_mask == 1):
|
| 311 |
+
if query_length == 1 or key_value_length == query_length:
|
| 312 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
| 313 |
+
ignore_causal_mask = True
|
| 314 |
+
|
| 315 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
|
| 316 |
+
# the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
|
| 317 |
+
# SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
| 318 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
| 319 |
+
# TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
|
| 320 |
+
|
| 321 |
+
return ignore_causal_mask
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _prepare_4d_causal_attention_mask(
|
| 325 |
+
attention_mask: torch.Tensor | None,
|
| 326 |
+
input_shape: torch.Size | tuple | list,
|
| 327 |
+
inputs_embeds: torch.Tensor,
|
| 328 |
+
past_key_values_length: int,
|
| 329 |
+
sliding_window: int | None = None,
|
| 330 |
+
):
|
| 331 |
+
"""
|
| 332 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 333 |
+
`(batch_size, key_value_length)`
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
attention_mask (`torch.Tensor` or `None`):
|
| 337 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
| 338 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
| 339 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
| 340 |
+
inputs_embeds (`torch.Tensor`):
|
| 341 |
+
The embedded inputs as a torch Tensor.
|
| 342 |
+
past_key_values_length (`int`):
|
| 343 |
+
The length of the key value cache.
|
| 344 |
+
sliding_window (`int`, *optional*):
|
| 345 |
+
If the model uses windowed attention, a sliding window should be passed.
|
| 346 |
+
"""
|
| 347 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
| 348 |
+
|
| 349 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
| 350 |
+
|
| 351 |
+
# 4d mask is passed through the layers
|
| 352 |
+
if attention_mask is not None and len(attention_mask.shape) == 2:
|
| 353 |
+
attention_mask = attn_mask_converter.to_4d(
|
| 354 |
+
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
| 355 |
+
)
|
| 356 |
+
elif attention_mask is not None and len(attention_mask.shape) == 4:
|
| 357 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
| 358 |
+
if tuple(attention_mask.shape) != expected_shape:
|
| 359 |
+
raise ValueError(
|
| 360 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
| 364 |
+
inverted_mask = 1.0 - attention_mask
|
| 365 |
+
attention_mask = inverted_mask.masked_fill(
|
| 366 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
| 367 |
+
)
|
| 368 |
+
else:
|
| 369 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
| 370 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
return attention_mask
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# Adapted from _prepare_4d_causal_attention_mask
|
| 377 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
|
| 378 |
+
attention_mask: torch.Tensor | None,
|
| 379 |
+
input_shape: torch.Size | tuple | list,
|
| 380 |
+
inputs_embeds: torch.Tensor,
|
| 381 |
+
past_key_values_length: int,
|
| 382 |
+
sliding_window: int | None = None,
|
| 383 |
+
):
|
| 384 |
+
"""
|
| 385 |
+
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
|
| 386 |
+
|
| 387 |
+
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
|
| 388 |
+
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
|
| 389 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
| 390 |
+
"""
|
| 391 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
| 392 |
+
|
| 393 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
| 394 |
+
|
| 395 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
| 396 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
| 397 |
+
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
| 398 |
+
is_tracing_ = is_tracing(inputs_embeds)
|
| 399 |
+
|
| 400 |
+
ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
|
| 401 |
+
attention_mask=attention_mask,
|
| 402 |
+
inputs_embeds=inputs_embeds,
|
| 403 |
+
past_key_values_length=past_key_values_length,
|
| 404 |
+
sliding_window=sliding_window,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
if ignore_causal_mask:
|
| 408 |
+
expanded_4d_mask = None
|
| 409 |
+
elif attention_mask is None:
|
| 410 |
+
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
| 411 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
| 412 |
+
)
|
| 413 |
+
else:
|
| 414 |
+
if attention_mask.dim() == 4:
|
| 415 |
+
expanded_4d_mask = attention_mask
|
| 416 |
+
else:
|
| 417 |
+
expanded_4d_mask = attn_mask_converter.to_4d(
|
| 418 |
+
attention_mask,
|
| 419 |
+
input_shape[-1],
|
| 420 |
+
dtype=inputs_embeds.dtype,
|
| 421 |
+
key_value_length=key_value_length,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
|
| 425 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
| 426 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
| 427 |
+
if not is_tracing_ and expanded_4d_mask.device.type in ["cuda", "xpu"]:
|
| 428 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
| 429 |
+
expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
return expanded_4d_mask
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
|
| 436 |
+
"""
|
| 437 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 438 |
+
`(batch_size, key_value_length)`
|
| 439 |
+
|
| 440 |
+
Args:
|
| 441 |
+
mask (`torch.Tensor`):
|
| 442 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
| 443 |
+
dtype (`torch.dtype`):
|
| 444 |
+
The torch dtype the created mask shall have.
|
| 445 |
+
tgt_len (`int`):
|
| 446 |
+
The target length or query length the created mask shall have.
|
| 447 |
+
"""
|
| 448 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
|
| 452 |
+
"""
|
| 453 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 454 |
+
`(batch_size, key_value_length)`
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
mask (`torch.Tensor`):
|
| 458 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
| 459 |
+
dtype (`torch.dtype`):
|
| 460 |
+
The torch dtype the created mask shall have.
|
| 461 |
+
tgt_len (`int`):
|
| 462 |
+
The target length or query length the created mask shall have.
|
| 463 |
+
"""
|
| 464 |
+
warnings.warn(DEPRECATION_MESSAGE, FutureWarning)
|
| 465 |
+
|
| 466 |
+
_, key_value_length = mask.shape
|
| 467 |
+
tgt_len = tgt_len if tgt_len is not None else key_value_length
|
| 468 |
+
|
| 469 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
|
| 470 |
+
if not is_tracing(mask) and torch.all(mask == 1):
|
| 471 |
+
return None
|
| 472 |
+
else:
|
| 473 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def _create_4d_causal_attention_mask(
|
| 477 |
+
input_shape: torch.Size | tuple | list,
|
| 478 |
+
dtype: torch.dtype,
|
| 479 |
+
device: torch.device,
|
| 480 |
+
past_key_values_length: int = 0,
|
| 481 |
+
sliding_window: int | None = None,
|
| 482 |
+
) -> torch.Tensor | None:
|
| 483 |
+
"""
|
| 484 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
| 488 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
| 489 |
+
dtype (`torch.dtype`):
|
| 490 |
+
The torch dtype the created mask shall have.
|
| 491 |
+
device (`int`):
|
| 492 |
+
The torch device the created mask shall have.
|
| 493 |
+
sliding_window (`int`, *optional*):
|
| 494 |
+
If the model uses windowed attention, a sliding window should be passed.
|
| 495 |
+
"""
|
| 496 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
| 497 |
+
|
| 498 |
+
key_value_length = past_key_values_length + input_shape[-1]
|
| 499 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
| 500 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
return attention_mask
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/blt/configuration_blt.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Blt model configuration"""
|
| 15 |
+
|
| 16 |
+
from huggingface_hub.dataclasses import strict
|
| 17 |
+
|
| 18 |
+
from ...configuration_utils import PreTrainedConfig
|
| 19 |
+
from ...modeling_rope_utils import RopeParameters
|
| 20 |
+
from ...utils import auto_docstring, logging
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@auto_docstring(checkpoint="itazap/blt-1b-hf")
|
| 27 |
+
@strict
|
| 28 |
+
class BltLocalEncoderConfig(PreTrainedConfig):
|
| 29 |
+
r"""
|
| 30 |
+
cross_attn_all_layers (`bool`, *optional*, defaults to `True`):
|
| 31 |
+
Whether all attention layers have cross attention.
|
| 32 |
+
cross_attn_k (`int`, *optional*, defaults to 2):
|
| 33 |
+
Number of cross-attention heads used in the model.
|
| 34 |
+
hidden_size_global (`int`, *int*, defaults to 2048):
|
| 35 |
+
Hidden size of the global transformer layer.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
model_type = "blt_local_encoder"
|
| 39 |
+
default_theta = 500000.0
|
| 40 |
+
|
| 41 |
+
vocab_size: int = 260
|
| 42 |
+
cross_attn_all_layers: bool | None = False
|
| 43 |
+
cross_attn_k: int | None = 2
|
| 44 |
+
hidden_size_global: int | None = 2048
|
| 45 |
+
hidden_size: int = 1024
|
| 46 |
+
num_attention_heads: int = 16
|
| 47 |
+
num_key_value_heads: int | None = None
|
| 48 |
+
num_hidden_layers: int = 1
|
| 49 |
+
rms_norm_eps: float = 1e-5
|
| 50 |
+
dropout: float | int | None = 0.0
|
| 51 |
+
max_position_embeddings: int = 24576
|
| 52 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 53 |
+
hidden_act: str = "silu"
|
| 54 |
+
intermediate_size: int | None = None
|
| 55 |
+
initializer_range: float = 0.02
|
| 56 |
+
|
| 57 |
+
def __post_init__(self, **kwargs):
|
| 58 |
+
self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
|
| 59 |
+
self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
|
| 60 |
+
self.tie_word_embeddings = False
|
| 61 |
+
super().__post_init__(**kwargs)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@auto_docstring(checkpoint="itazap/blt-1b-hf")
|
| 65 |
+
@strict
|
| 66 |
+
class BltLocalDecoderConfig(PreTrainedConfig):
|
| 67 |
+
r"""
|
| 68 |
+
cross_attn_all_layers (`bool`, *optional*, defaults to `True`):
|
| 69 |
+
Whether all attention layers have cross attention.
|
| 70 |
+
cross_attn_k (`int`, *optional*, defaults to 2):
|
| 71 |
+
Number of cross-attention heads used in the model.
|
| 72 |
+
hidden_size_global (`int`, *int*, defaults to 2048):
|
| 73 |
+
Hidden size of the global transformer layer.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
model_type = "blt_local_decoder"
|
| 77 |
+
default_theta = 500000.0
|
| 78 |
+
|
| 79 |
+
vocab_size: int = 260
|
| 80 |
+
cross_attn_all_layers: bool | None = True
|
| 81 |
+
cross_attn_k: int | None = 2
|
| 82 |
+
hidden_size_global: int | None = 2048
|
| 83 |
+
hidden_size: int = 1024
|
| 84 |
+
num_attention_heads: int = 16
|
| 85 |
+
num_key_value_heads: int | None = None
|
| 86 |
+
num_hidden_layers: int = 9
|
| 87 |
+
rms_norm_eps: float = 1e-5
|
| 88 |
+
dropout: float | int | None = 0.0
|
| 89 |
+
max_position_embeddings: int = 24576
|
| 90 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 91 |
+
hidden_act: str = "silu"
|
| 92 |
+
intermediate_size: int = 2816
|
| 93 |
+
initializer_range: float = 0.02
|
| 94 |
+
pad_token_id: int | None = None
|
| 95 |
+
bos_token_id: int | None = None
|
| 96 |
+
eos_token_id: int | list[int] | None = None
|
| 97 |
+
tie_word_embeddings: bool = False
|
| 98 |
+
|
| 99 |
+
def __post_init__(self, **kwargs):
|
| 100 |
+
self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
|
| 101 |
+
self.head_dim = self.hidden_size // self.num_attention_heads
|
| 102 |
+
self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
|
| 103 |
+
self.tie_word_embeddings = False # Force-set to False for BC
|
| 104 |
+
super().__post_init__(**kwargs)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@auto_docstring(checkpoint="itazap/blt-1b-hf")
|
| 108 |
+
@strict
|
| 109 |
+
class BltGlobalTransformerConfig(PreTrainedConfig):
|
| 110 |
+
model_type = "blt_global_transformer"
|
| 111 |
+
default_theta = 500000.0
|
| 112 |
+
|
| 113 |
+
hidden_size: int = 2048
|
| 114 |
+
num_attention_heads: int = 16
|
| 115 |
+
num_key_value_heads: int | None = None
|
| 116 |
+
num_hidden_layers: int = 25
|
| 117 |
+
rms_norm_eps: float = 1e-5
|
| 118 |
+
dropout: float | int | None = 0.0
|
| 119 |
+
max_position_embeddings: int = 4096
|
| 120 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 121 |
+
hidden_act: str = "silu"
|
| 122 |
+
intermediate_size: int = 5632
|
| 123 |
+
initializer_range: float = 0.02
|
| 124 |
+
tie_word_embeddings: bool = False
|
| 125 |
+
|
| 126 |
+
def __post_init__(self, **kwargs):
|
| 127 |
+
self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
|
| 128 |
+
self.head_dim = self.hidden_size // self.num_attention_heads
|
| 129 |
+
self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
|
| 130 |
+
self.tie_word_embeddings = False
|
| 131 |
+
|
| 132 |
+
super().__post_init__(**kwargs)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@auto_docstring(checkpoint="itazap/blt-1b-hf")
|
| 136 |
+
@strict
|
| 137 |
+
class BltPatcherConfig(PreTrainedConfig):
|
| 138 |
+
model_type = "blt_patcher"
|
| 139 |
+
|
| 140 |
+
vocab_size: int = 260
|
| 141 |
+
hidden_size: int = 768
|
| 142 |
+
num_hidden_layers: int = 14
|
| 143 |
+
num_attention_heads: int = 12
|
| 144 |
+
num_key_value_heads: int | None = None
|
| 145 |
+
max_position_embeddings: int = 8192
|
| 146 |
+
rms_norm_eps: float = 1e-5
|
| 147 |
+
dropout: float | int | None = 0.0
|
| 148 |
+
intermediate_size: int = 2048
|
| 149 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 150 |
+
initializer_range: float = 0.02
|
| 151 |
+
tie_word_embeddings: bool = False
|
| 152 |
+
|
| 153 |
+
def __post_init__(self, **kwargs):
|
| 154 |
+
self.num_key_value_heads = self.num_key_value_heads or self.num_attention_heads
|
| 155 |
+
self.head_dim = self.hidden_size // self.num_attention_heads
|
| 156 |
+
self.intermediate_size = self.intermediate_size or int(8 * self.hidden_size / 3)
|
| 157 |
+
self.tie_word_embeddings = False
|
| 158 |
+
self.hidden_act = "silu" # Blt uses silu activation
|
| 159 |
+
|
| 160 |
+
super().__post_init__(**kwargs)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@auto_docstring(checkpoint="itazap/blt-1b-hf")
|
| 164 |
+
@strict
|
| 165 |
+
class BltConfig(PreTrainedConfig):
|
| 166 |
+
r"""
|
| 167 |
+
patch_in_forward (`bool`, *optional*, defaults to `True`):
|
| 168 |
+
Whether to perform patching during the forward pass.
|
| 169 |
+
patch_size (`int`, *optional*, defaults to 4):
|
| 170 |
+
Size of the patches used in the patching mechanism.
|
| 171 |
+
patching_mode (`str`, *optional*, defaults to `"entropy"`):
|
| 172 |
+
The mode used for patching, such as entropy-based patching.
|
| 173 |
+
patching_threshold (`float`, *optional*, defaults to 1.34):
|
| 174 |
+
Threshold value used for determining when to apply patches.
|
| 175 |
+
patching_batch_size (`int`, *optional*, defaults to 1):
|
| 176 |
+
Batch size used during the patching process.
|
| 177 |
+
max_patch_length (`int`, *optional*):
|
| 178 |
+
Maximum length of patches that can be generated.
|
| 179 |
+
cross_attn_k (`int`, *optional*, defaults to 2):
|
| 180 |
+
Number of cross-attention heads used in the model.
|
| 181 |
+
encoder_hash_byte_group_size (`list`, *optional*):
|
| 182 |
+
List of byte group sizes used in the encoder hash function.
|
| 183 |
+
encoder_hash_byte_group_vocab (`int`, *optional*, defaults to 500002):
|
| 184 |
+
Vocabulary size for the encoder hash byte groups.
|
| 185 |
+
encoder_hash_byte_group_nb_functions (`int`, *optional*, defaults to 1):
|
| 186 |
+
Number of hash functions used in the encoder byte grouping.
|
| 187 |
+
patcher_config (`BltPatcherConfig`, *optional*):
|
| 188 |
+
Configuration for the patcher component of the model.
|
| 189 |
+
global_config (`BltGlobalTransformerConfig`, *optional*):
|
| 190 |
+
Configuration for the global transformer component of the model.
|
| 191 |
+
|
| 192 |
+
Example:
|
| 193 |
+
```python
|
| 194 |
+
>>> from transformers import BltModel, BltConfig
|
| 195 |
+
|
| 196 |
+
>>> # Initializing a Blt configuration
|
| 197 |
+
>>> configuration = BltConfig()
|
| 198 |
+
|
| 199 |
+
>>> # Initializing a model from the configuration
|
| 200 |
+
>>> model = BltModel(configuration)
|
| 201 |
+
|
| 202 |
+
>>> # Accessing the model configuration
|
| 203 |
+
>>> configuration = model.config
|
| 204 |
+
```"""
|
| 205 |
+
|
| 206 |
+
model_type = "blt"
|
| 207 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 208 |
+
default_theta = 500000.0
|
| 209 |
+
sub_configs = {
|
| 210 |
+
"patcher_config": BltPatcherConfig,
|
| 211 |
+
"encoder_config": BltLocalEncoderConfig,
|
| 212 |
+
"decoder_config": BltLocalDecoderConfig,
|
| 213 |
+
"global_config": BltGlobalTransformerConfig,
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
vocab_size: int = 260
|
| 217 |
+
max_position_embeddings: int = 4096
|
| 218 |
+
patch_in_forward: bool | None = True
|
| 219 |
+
patch_size: int | None = 4
|
| 220 |
+
patching_mode: str | None = "entropy"
|
| 221 |
+
patching_threshold: float | None = 1.335442066192627
|
| 222 |
+
patching_batch_size: int | None = 1
|
| 223 |
+
max_patch_length: int | None = None
|
| 224 |
+
cross_attn_k: int | None = 2
|
| 225 |
+
encoder_hash_byte_group_size: list[int] | None = None
|
| 226 |
+
encoder_hash_byte_group_vocab: int | None = 500002
|
| 227 |
+
encoder_hash_byte_group_nb_functions: int | None = 1
|
| 228 |
+
patcher_config: dict | PreTrainedConfig | None = None
|
| 229 |
+
encoder_config: dict | PreTrainedConfig | None = None
|
| 230 |
+
decoder_config: dict | PreTrainedConfig | None = None
|
| 231 |
+
global_config: dict | PreTrainedConfig | None = None
|
| 232 |
+
tie_word_embeddings: bool = False
|
| 233 |
+
pad_token_id: int | None = None
|
| 234 |
+
bos_token_id: int | None = None
|
| 235 |
+
eos_token_id: int | list[int] | None = None
|
| 236 |
+
initializer_range: float = 0.02
|
| 237 |
+
rope_parameters: RopeParameters | dict | None = None
|
| 238 |
+
|
| 239 |
+
def __post_init__(self, **kwargs):
|
| 240 |
+
self.encoder_hash_byte_group_size = self.encoder_hash_byte_group_size or [3, 4, 5, 6, 7, 8]
|
| 241 |
+
|
| 242 |
+
# Initialize component configurations
|
| 243 |
+
if self.patcher_config is None:
|
| 244 |
+
self.patcher_config = BltPatcherConfig(initializer_range=self.initializer_range)
|
| 245 |
+
logger.info("patcher_config is None, using default Blt patcher config")
|
| 246 |
+
elif isinstance(self.patcher_config, dict):
|
| 247 |
+
self.patcher_config.setdefault("initializer_range", self.initializer_range)
|
| 248 |
+
self.patcher_config = BltPatcherConfig(**self.patcher_config)
|
| 249 |
+
|
| 250 |
+
if self.encoder_config is None:
|
| 251 |
+
self.encoder_config = BltLocalEncoderConfig(initializer_range=self.initializer_range)
|
| 252 |
+
logger.info("encoder_config is None, using default Blt encoder config")
|
| 253 |
+
elif isinstance(self.encoder_config, dict):
|
| 254 |
+
self.encoder_config.setdefault("initializer_range", self.initializer_range)
|
| 255 |
+
self.encoder_config = BltLocalEncoderConfig(**self.encoder_config)
|
| 256 |
+
|
| 257 |
+
if self.decoder_config is None:
|
| 258 |
+
self.decoder_config = BltLocalDecoderConfig(initializer_range=self.initializer_range)
|
| 259 |
+
logger.info("decoder_config is None, using default Blt decoder config")
|
| 260 |
+
elif isinstance(self.decoder_config, dict):
|
| 261 |
+
self.decoder_config.setdefault("initializer_range", self.initializer_range)
|
| 262 |
+
self.decoder_config = BltLocalDecoderConfig(**self.decoder_config)
|
| 263 |
+
|
| 264 |
+
if self.global_config is None:
|
| 265 |
+
self.global_config = BltGlobalTransformerConfig(initializer_range=self.initializer_range)
|
| 266 |
+
logger.info("global_config is None, using default Blt global config")
|
| 267 |
+
elif isinstance(self.global_config, dict):
|
| 268 |
+
self.global_config.setdefault("initializer_range", self.initializer_range)
|
| 269 |
+
self.global_config = BltGlobalTransformerConfig(**self.global_config)
|
| 270 |
+
|
| 271 |
+
# Determine if token embedding projection is needed based on dimension mismatch (7b)
|
| 272 |
+
encoder_cross_output_size = self.encoder_config.hidden_size * self.cross_attn_k
|
| 273 |
+
self.global_config.encoder_cross_output_size = (
|
| 274 |
+
encoder_cross_output_size if encoder_cross_output_size != self.global_config.hidden_size else None
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
super().__post_init__(**kwargs)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
__all__ = [
|
| 281 |
+
"BltConfig",
|
| 282 |
+
"BltPatcherConfig",
|
| 283 |
+
"BltLocalEncoderConfig",
|
| 284 |
+
"BltLocalDecoderConfig",
|
| 285 |
+
"BltGlobalTransformerConfig",
|
| 286 |
+
]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_jetmoe import *
|
| 22 |
+
from .modeling_jetmoe import *
|
| 23 |
+
else:
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
_file = globals()["__file__"]
|
| 27 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/jetmoe/modeling_jetmoe.py
ADDED
|
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 2 |
+
# This file was automatically generated from src/transformers/models/jetmoe/modular_jetmoe.py.
|
| 3 |
+
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
| 4 |
+
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
+
# modular_jetmoe.py file directly. One of our CI enforces this.
|
| 6 |
+
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
|
| 8 |
+
#
|
| 9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 10 |
+
# you may not use this file except in compliance with the License.
|
| 11 |
+
# You may obtain a copy of the License at
|
| 12 |
+
#
|
| 13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 14 |
+
#
|
| 15 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 18 |
+
# See the License for the specific language governing permissions and
|
| 19 |
+
# limitations under the License.
|
| 20 |
+
|
| 21 |
+
from collections.abc import Callable
|
| 22 |
+
from typing import Optional
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
from torch import nn
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
|
| 28 |
+
from ... import initialization as init
|
| 29 |
+
from ...activations import ACT2FN
|
| 30 |
+
from ...cache_utils import Cache, DynamicCache
|
| 31 |
+
from ...generation import GenerationMixin
|
| 32 |
+
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
|
| 33 |
+
from ...masking_utils import create_causal_mask
|
| 34 |
+
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
|
| 35 |
+
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
|
| 36 |
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 37 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 38 |
+
from ...processing_utils import Unpack
|
| 39 |
+
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 40 |
+
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
| 41 |
+
from ...utils.output_capturing import OutputRecorder, capture_outputs
|
| 42 |
+
from .configuration_jetmoe import JetMoeConfig
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 49 |
+
class JetMoeRMSNorm(nn.Module):
|
| 50 |
+
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
| 51 |
+
"""
|
| 52 |
+
JetMoeRMSNorm is equivalent to T5LayerNorm
|
| 53 |
+
"""
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 56 |
+
self.variance_epsilon = eps
|
| 57 |
+
|
| 58 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 59 |
+
input_dtype = hidden_states.dtype
|
| 60 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 61 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 62 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 63 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 64 |
+
|
| 65 |
+
def extra_repr(self):
|
| 66 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class JetMoeRotaryEmbedding(nn.Module):
|
| 70 |
+
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: JetMoeConfig, device=None):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 75 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 76 |
+
|
| 77 |
+
self.config = config
|
| 78 |
+
|
| 79 |
+
self.rope_type = self.config.rope_parameters["rope_type"]
|
| 80 |
+
rope_init_fn: Callable = self.compute_default_rope_parameters
|
| 81 |
+
if self.rope_type != "default":
|
| 82 |
+
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 83 |
+
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
| 84 |
+
|
| 85 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 86 |
+
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def compute_default_rope_parameters(
|
| 90 |
+
config: JetMoeConfig | None = None,
|
| 91 |
+
device: Optional["torch.device"] = None,
|
| 92 |
+
seq_len: int | None = None,
|
| 93 |
+
) -> tuple["torch.Tensor", float]:
|
| 94 |
+
"""
|
| 95 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
| 96 |
+
Args:
|
| 97 |
+
config ([`~transformers.PreTrainedConfig`]):
|
| 98 |
+
The model configuration.
|
| 99 |
+
device (`torch.device`):
|
| 100 |
+
The device to use for initialization of the inverse frequencies.
|
| 101 |
+
seq_len (`int`, *optional*):
|
| 102 |
+
The current sequence length. Unused for this type of RoPE.
|
| 103 |
+
Returns:
|
| 104 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
| 105 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
| 106 |
+
"""
|
| 107 |
+
base = config.rope_parameters["rope_theta"]
|
| 108 |
+
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 109 |
+
|
| 110 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
| 111 |
+
|
| 112 |
+
# Compute the inverse frequencies
|
| 113 |
+
inv_freq = 1.0 / (
|
| 114 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 115 |
+
)
|
| 116 |
+
return inv_freq, attention_factor
|
| 117 |
+
|
| 118 |
+
@torch.no_grad()
|
| 119 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 120 |
+
def forward(self, x, position_ids):
|
| 121 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 122 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 123 |
+
|
| 124 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 125 |
+
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
| 126 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 127 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 128 |
+
cos = emb.cos() * self.attention_scaling
|
| 129 |
+
sin = emb.sin() * self.attention_scaling
|
| 130 |
+
|
| 131 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class JetMoeParallelExperts(nn.Module):
|
| 135 |
+
def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
|
| 136 |
+
"""
|
| 137 |
+
Initialize the JetMoeParallelExperts module.
|
| 138 |
+
The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
|
| 139 |
+
many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
|
| 140 |
+
[ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
|
| 141 |
+
[MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
|
| 142 |
+
used in vllm.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
num_experts (int):
|
| 146 |
+
Number of experts.
|
| 147 |
+
input_size (int):
|
| 148 |
+
Size of the input.
|
| 149 |
+
output_size (int):
|
| 150 |
+
Size of the output.
|
| 151 |
+
"""
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
|
| 154 |
+
self.num_experts = num_experts
|
| 155 |
+
self.input_size = input_size
|
| 156 |
+
self.output_size = output_size
|
| 157 |
+
|
| 158 |
+
def forward(self, inputs, expert_size):
|
| 159 |
+
"""
|
| 160 |
+
Forward pass of the JetMoeParallelExperts module.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
inputs (Tensor):
|
| 164 |
+
Input tensor.
|
| 165 |
+
expert_size:
|
| 166 |
+
Expert size information.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Tensor: Output tensor.
|
| 170 |
+
"""
|
| 171 |
+
input_list = inputs.split(expert_size, dim=0)
|
| 172 |
+
output_list = []
|
| 173 |
+
for i in range(self.num_experts):
|
| 174 |
+
output_list.append(F.linear(input_list[i], self.weight[i]))
|
| 175 |
+
results = torch.cat(output_list, dim=0)
|
| 176 |
+
return results
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class JetMoeTopKGating(nn.Module):
|
| 180 |
+
def __init__(self, input_size: int, num_experts: int, top_k: int):
|
| 181 |
+
"""
|
| 182 |
+
Initialize the top-k gating mechanism.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
input_size (`int`):
|
| 186 |
+
Size of the input.
|
| 187 |
+
num_experts (`int`):
|
| 188 |
+
Number of experts.
|
| 189 |
+
top_k (`int`):
|
| 190 |
+
Number of top experts to select.
|
| 191 |
+
"""
|
| 192 |
+
super().__init__()
|
| 193 |
+
|
| 194 |
+
self.num_experts = num_experts
|
| 195 |
+
self.input_size = input_size
|
| 196 |
+
self.top_k = top_k
|
| 197 |
+
|
| 198 |
+
self.layer = nn.Linear(input_size, num_experts, bias=False)
|
| 199 |
+
|
| 200 |
+
def forward(self, hidden_states):
|
| 201 |
+
# compute the top_k routing decision
|
| 202 |
+
logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts]
|
| 203 |
+
top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k]
|
| 204 |
+
top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k]
|
| 205 |
+
|
| 206 |
+
# compute number of input given to each expert
|
| 207 |
+
zeros = torch.zeros(
|
| 208 |
+
[top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
|
| 209 |
+
) # [num_tokens, num_experts]
|
| 210 |
+
gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
|
| 211 |
+
expert_size = gates.long().sum(0) # [num_experts,]
|
| 212 |
+
# (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
|
| 213 |
+
# (and `DataDependentOutputException`)
|
| 214 |
+
expert_size = expert_size.tolist()
|
| 215 |
+
|
| 216 |
+
# sort and group input tokens according to expert assignment
|
| 217 |
+
top_k_experts = top_k_indices.flatten() # [num_tokens * top_k]
|
| 218 |
+
_, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k]
|
| 219 |
+
batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k]
|
| 220 |
+
|
| 221 |
+
# gather the gate values for grouped input tokens
|
| 222 |
+
top_k_gates = top_k_gates.flatten() # [num_tokens * top_k]
|
| 223 |
+
batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k]
|
| 224 |
+
|
| 225 |
+
return index_sorted_experts, batch_index, batch_gates, expert_size, logits
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class JetMoeMoE(nn.Module):
|
| 229 |
+
"""
|
| 230 |
+
A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
config:
|
| 234 |
+
Configuration object with model hyperparameters.
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
def __init__(self, config: JetMoeConfig):
|
| 238 |
+
super().__init__()
|
| 239 |
+
|
| 240 |
+
self.input_size = config.hidden_size
|
| 241 |
+
self.hidden_size = config.intermediate_size
|
| 242 |
+
self.activation = ACT2FN[config.activation_function]
|
| 243 |
+
self.bias = torch.nn.Parameter(torch.empty(self.input_size))
|
| 244 |
+
self.input_linear = JetMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2)
|
| 245 |
+
self.output_linear = JetMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size)
|
| 246 |
+
|
| 247 |
+
self.router = JetMoeTopKGating(
|
| 248 |
+
input_size=self.input_size,
|
| 249 |
+
num_experts=config.num_local_experts,
|
| 250 |
+
top_k=config.num_experts_per_tok,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
def forward(self, layer_input):
|
| 254 |
+
"""
|
| 255 |
+
Forward pass of the mixture of experts layer.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
layer_input (Tensor):
|
| 259 |
+
Input tensor.
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
Tensor:
|
| 263 |
+
Output tensor.
|
| 264 |
+
Tensor:
|
| 265 |
+
Router logits.
|
| 266 |
+
"""
|
| 267 |
+
bsz, length, emb_size = layer_input.size()
|
| 268 |
+
layer_input = layer_input.reshape(-1, emb_size)
|
| 269 |
+
_, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
|
| 270 |
+
|
| 271 |
+
expert_inputs = layer_input[batch_index]
|
| 272 |
+
hidden_states = self.input_linear(expert_inputs, expert_size)
|
| 273 |
+
chunked_hidden_states = hidden_states.chunk(2, dim=-1)
|
| 274 |
+
hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
|
| 275 |
+
expert_outputs = self.output_linear(hidden_states, expert_size)
|
| 276 |
+
|
| 277 |
+
expert_outputs = expert_outputs * batch_gates[:, None]
|
| 278 |
+
|
| 279 |
+
zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
|
| 280 |
+
layer_output = zeros.index_add(0, batch_index, expert_outputs)
|
| 281 |
+
layer_output = layer_output.view(bsz, length, self.input_size)
|
| 282 |
+
layer_output = layer_output + self.bias
|
| 283 |
+
return layer_output
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class JetMoeMoA(nn.Module):
|
| 287 |
+
"""
|
| 288 |
+
A Sparsely gated mixture of attention layer with pairs of query- and output-projections as experts.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
config:
|
| 292 |
+
Configuration object with model hyperparameters.
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
def __init__(self, config: JetMoeConfig):
|
| 296 |
+
super().__init__()
|
| 297 |
+
|
| 298 |
+
self.num_experts = config.num_local_experts
|
| 299 |
+
self.input_size = config.hidden_size
|
| 300 |
+
self.hidden_size = config.kv_channels * config.num_key_value_heads
|
| 301 |
+
self.top_k = config.num_experts_per_tok
|
| 302 |
+
self.bias = torch.nn.Parameter(torch.empty(self.input_size))
|
| 303 |
+
|
| 304 |
+
self.input_linear = JetMoeParallelExperts(self.num_experts, self.input_size, self.hidden_size)
|
| 305 |
+
self.output_linear = JetMoeParallelExperts(self.num_experts, self.hidden_size, self.input_size)
|
| 306 |
+
|
| 307 |
+
self.router = JetMoeTopKGating(
|
| 308 |
+
input_size=self.input_size,
|
| 309 |
+
num_experts=self.num_experts,
|
| 310 |
+
top_k=self.top_k,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def map(self, layer_input):
|
| 314 |
+
"""
|
| 315 |
+
Map inputs to attention experts according to routing decision and compute query projection inside each experts.
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
# Compute gating topology
|
| 319 |
+
bsz, length, emb_size = layer_input.size()
|
| 320 |
+
layer_input = layer_input.reshape(-1, emb_size) # [bsz * length, emb_size]
|
| 321 |
+
index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
|
| 322 |
+
topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size)
|
| 323 |
+
|
| 324 |
+
# Group inputs according to topology and compute query projection
|
| 325 |
+
expert_inputs = layer_input[batch_index] # [bsz * length * top_k, emb_size]
|
| 326 |
+
expert_outputs = self.input_linear(expert_inputs, expert_size) # [bsz * length * top_k, hidden_size]
|
| 327 |
+
|
| 328 |
+
# Ungroup queries back to original order
|
| 329 |
+
zeros = torch.zeros(
|
| 330 |
+
(bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device
|
| 331 |
+
)
|
| 332 |
+
layer_output = zeros.index_add(0, index_sorted_experts, expert_outputs)
|
| 333 |
+
layer_output = layer_output.view(bsz, length, self.top_k, -1) # [bsz, length, top_k, hidden_size]
|
| 334 |
+
return layer_output, router_logits, topo_info
|
| 335 |
+
|
| 336 |
+
def reduce(self, layer_input, topo_info):
|
| 337 |
+
"""
|
| 338 |
+
Compute output projection inside each attention experts and merge the outputs of different experts.
|
| 339 |
+
"""
|
| 340 |
+
bsz, length, k, hidden_size = layer_input.size()
|
| 341 |
+
layer_input = layer_input.reshape(-1, hidden_size) # [bsz * length * k, hidden_size]
|
| 342 |
+
index_sorted_experts, batch_index, batch_gates, expert_size = topo_info
|
| 343 |
+
|
| 344 |
+
# Group inputs according to topology and compute output projection
|
| 345 |
+
expert_inputs = layer_input[index_sorted_experts] # [bsz * length * top_k, hidden_size]
|
| 346 |
+
expert_outputs = self.output_linear(expert_inputs, expert_size) # [bsz * length * top_k, emb_size]
|
| 347 |
+
|
| 348 |
+
# Apply gates to attention expert outputs
|
| 349 |
+
expert_outputs = expert_outputs * batch_gates[:, None]
|
| 350 |
+
|
| 351 |
+
# Ungroup and merge outputs to original order
|
| 352 |
+
zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
|
| 353 |
+
layer_output = zeros.index_add(0, batch_index, expert_outputs)
|
| 354 |
+
layer_output = layer_output.view(bsz, length, self.input_size)
|
| 355 |
+
layer_output = layer_output + self.bias
|
| 356 |
+
return layer_output
|
| 357 |
+
|
| 358 |
+
def forward(self, layer_input):
|
| 359 |
+
raise NotImplementedError("This module doesn't support call and forward.")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def rotate_half(x):
|
| 363 |
+
"""Rotates half the hidden dims of the input."""
|
| 364 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 365 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 366 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@use_kernel_func_from_hub("rotary_pos_emb")
|
| 370 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 371 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
q (`torch.Tensor`): The query tensor.
|
| 375 |
+
k (`torch.Tensor`): The key tensor.
|
| 376 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 377 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 378 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 379 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 380 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 381 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 382 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 383 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 384 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 385 |
+
Returns:
|
| 386 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 387 |
+
"""
|
| 388 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 389 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 390 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 391 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 392 |
+
return q_embed, k_embed
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 396 |
+
"""
|
| 397 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 398 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 399 |
+
"""
|
| 400 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 401 |
+
if n_rep == 1:
|
| 402 |
+
return hidden_states
|
| 403 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 404 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def eager_attention_forward(
|
| 408 |
+
module: nn.Module,
|
| 409 |
+
query: torch.Tensor,
|
| 410 |
+
key: torch.Tensor,
|
| 411 |
+
value: torch.Tensor,
|
| 412 |
+
attention_mask: torch.Tensor | None,
|
| 413 |
+
scaling: float,
|
| 414 |
+
dropout: float = 0.0,
|
| 415 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 416 |
+
):
|
| 417 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 418 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 419 |
+
|
| 420 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 421 |
+
if attention_mask is not None:
|
| 422 |
+
attn_weights = attn_weights + attention_mask
|
| 423 |
+
|
| 424 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 425 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 426 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 427 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 428 |
+
|
| 429 |
+
return attn_output, attn_weights
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class JetMoeAttention(nn.Module):
|
| 433 |
+
"""
|
| 434 |
+
Multi-headed attention from 'Attention Is All You Need' paper.
|
| 435 |
+
"""
|
| 436 |
+
|
| 437 |
+
def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
|
| 438 |
+
"""
|
| 439 |
+
Initialize the JetMoeAttention module.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
config:
|
| 443 |
+
Configuration object with model hyperparameters.
|
| 444 |
+
layer_idx:
|
| 445 |
+
Index of the layer in the model.
|
| 446 |
+
"""
|
| 447 |
+
super().__init__()
|
| 448 |
+
self.config = config
|
| 449 |
+
self.layer_idx = layer_idx
|
| 450 |
+
self.is_causal = True
|
| 451 |
+
if layer_idx is None:
|
| 452 |
+
logger.warning_once(
|
| 453 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
| 454 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
| 455 |
+
"when creating this class."
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
self.num_key_value_groups = 1 # We ignore this by setting it to 1 as we have different repeat patterns
|
| 459 |
+
self.top_k = config.num_experts_per_tok
|
| 460 |
+
self.attention_dropout = config.attention_dropout
|
| 461 |
+
self.kv_projection_size = config.kv_channels * config.num_key_value_heads
|
| 462 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 463 |
+
self.num_heads = config.num_attention_heads
|
| 464 |
+
self.head_dim = config.kv_channels
|
| 465 |
+
self.scaling = self.head_dim**-0.5
|
| 466 |
+
self.experts = JetMoeMoA(config)
|
| 467 |
+
|
| 468 |
+
self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False)
|
| 469 |
+
|
| 470 |
+
def forward(
|
| 471 |
+
self,
|
| 472 |
+
hidden_states: torch.Tensor,
|
| 473 |
+
attention_mask: torch.Tensor | None = None,
|
| 474 |
+
position_embeddings: torch.LongTensor | None = None,
|
| 475 |
+
past_key_values: Cache | None = None,
|
| 476 |
+
**kwargs,
|
| 477 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
| 478 |
+
input_shape = hidden_states.shape[:-1]
|
| 479 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 480 |
+
|
| 481 |
+
query_states, router_logits, topo_info = self.experts.map(hidden_states)
|
| 482 |
+
key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
|
| 483 |
+
|
| 484 |
+
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
| 485 |
+
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
| 486 |
+
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
| 487 |
+
|
| 488 |
+
cos, sin = position_embeddings
|
| 489 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 490 |
+
|
| 491 |
+
if past_key_values is not None:
|
| 492 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
| 493 |
+
|
| 494 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 495 |
+
self.config._attn_implementation, eager_attention_forward
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
# This is different from other models where we repeat k/v heads
|
| 499 |
+
# instead of repeat interleaving them
|
| 500 |
+
key_states = key_states.repeat(1, self.top_k, 1, 1)
|
| 501 |
+
value_states = value_states.repeat(1, self.top_k, 1, 1)
|
| 502 |
+
|
| 503 |
+
attn_output, attn_weights = attention_interface(
|
| 504 |
+
self,
|
| 505 |
+
query_states,
|
| 506 |
+
key_states,
|
| 507 |
+
value_states,
|
| 508 |
+
attention_mask,
|
| 509 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 510 |
+
scaling=self.scaling,
|
| 511 |
+
**kwargs,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
attn_output = attn_output.view(*input_shape, self.top_k, -1)
|
| 515 |
+
attn_output = self.experts.reduce(attn_output, topo_info)
|
| 516 |
+
attn_output = attn_output.view(*input_shape, -1)
|
| 517 |
+
return attn_output, attn_weights, router_logits
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class JetMoeDecoderLayer(GradientCheckpointingLayer):
|
| 521 |
+
def __init__(self, config: JetMoeConfig, layer_idx: int | None = None):
|
| 522 |
+
super().__init__()
|
| 523 |
+
self.hidden_size = config.hidden_size
|
| 524 |
+
self.mlp = JetMoeMoE(config)
|
| 525 |
+
self.input_layernorm = JetMoeRMSNorm(config.hidden_size)
|
| 526 |
+
self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size)
|
| 527 |
+
self.self_attention = JetMoeAttention(config, layer_idx)
|
| 528 |
+
|
| 529 |
+
def forward(
|
| 530 |
+
self,
|
| 531 |
+
hidden_states: torch.Tensor,
|
| 532 |
+
attention_mask: torch.Tensor | None = None,
|
| 533 |
+
position_ids: torch.LongTensor | None = None,
|
| 534 |
+
past_key_values: Cache | None = None,
|
| 535 |
+
use_cache: bool | None = False,
|
| 536 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 537 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 538 |
+
) -> torch.Tensor:
|
| 539 |
+
residual = hidden_states
|
| 540 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 541 |
+
# Self Attention
|
| 542 |
+
hidden_states, _, _ = self.self_attention(
|
| 543 |
+
hidden_states=hidden_states,
|
| 544 |
+
attention_mask=attention_mask,
|
| 545 |
+
position_ids=position_ids,
|
| 546 |
+
past_key_values=past_key_values,
|
| 547 |
+
use_cache=use_cache,
|
| 548 |
+
position_embeddings=position_embeddings,
|
| 549 |
+
**kwargs,
|
| 550 |
+
)
|
| 551 |
+
hidden_states = residual + hidden_states
|
| 552 |
+
|
| 553 |
+
# Fully Connected
|
| 554 |
+
residual = hidden_states
|
| 555 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 556 |
+
hidden_states = self.mlp(hidden_states)
|
| 557 |
+
hidden_states = residual + hidden_states
|
| 558 |
+
return hidden_states
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
@auto_docstring
|
| 562 |
+
class JetMoePreTrainedModel(PreTrainedModel):
|
| 563 |
+
config: JetMoeConfig
|
| 564 |
+
base_model_prefix = "model"
|
| 565 |
+
supports_gradient_checkpointing = False
|
| 566 |
+
_no_split_modules = ["JetMoeDecoderLayer"]
|
| 567 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 568 |
+
_supports_flash_attn = True
|
| 569 |
+
_supports_sdpa = True
|
| 570 |
+
_supports_flex_attn = True
|
| 571 |
+
_can_compile_fullgraph = False # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"
|
| 572 |
+
_supports_attention_backend = True
|
| 573 |
+
_can_record_outputs = {
|
| 574 |
+
"router_logits": [OutputRecorder(JetMoeAttention, index=2), OutputRecorder(JetMoeTopKGating, index=4)],
|
| 575 |
+
"hidden_states": JetMoeDecoderLayer,
|
| 576 |
+
"attentions": OutputRecorder(JetMoeAttention, index=1),
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
@torch.no_grad()
|
| 580 |
+
def _init_weights(self, module):
|
| 581 |
+
"""Initialize the weights."""
|
| 582 |
+
super()._init_weights(module)
|
| 583 |
+
if isinstance(module, JetMoeParallelExperts):
|
| 584 |
+
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
| 585 |
+
elif isinstance(module, JetMoeMoA | JetMoeMoE):
|
| 586 |
+
init.zeros_(module.bias)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@auto_docstring
|
| 590 |
+
class JetMoeModel(JetMoePreTrainedModel):
|
| 591 |
+
def __init__(self, config: JetMoeConfig):
|
| 592 |
+
super().__init__(config)
|
| 593 |
+
self.padding_idx = config.pad_token_id
|
| 594 |
+
self.vocab_size = config.vocab_size
|
| 595 |
+
|
| 596 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 597 |
+
self.layers = nn.ModuleList(
|
| 598 |
+
[JetMoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 599 |
+
)
|
| 600 |
+
self.norm = JetMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 601 |
+
self.rotary_emb = JetMoeRotaryEmbedding(config=config)
|
| 602 |
+
self.gradient_checkpointing = False
|
| 603 |
+
self._attn_implementation = config._attn_implementation
|
| 604 |
+
|
| 605 |
+
# Initialize weights and apply final processing
|
| 606 |
+
self.post_init()
|
| 607 |
+
|
| 608 |
+
@merge_with_config_defaults
|
| 609 |
+
@capture_outputs
|
| 610 |
+
@auto_docstring
|
| 611 |
+
def forward(
|
| 612 |
+
self,
|
| 613 |
+
input_ids: torch.LongTensor | None = None,
|
| 614 |
+
attention_mask: torch.Tensor | None = None,
|
| 615 |
+
position_ids: torch.LongTensor | None = None,
|
| 616 |
+
past_key_values: Cache | None = None,
|
| 617 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 618 |
+
use_cache: bool | None = None,
|
| 619 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 620 |
+
) -> MoeModelOutputWithPast:
|
| 621 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 622 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 623 |
+
|
| 624 |
+
if use_cache and past_key_values is None:
|
| 625 |
+
past_key_values = DynamicCache(config=self.config)
|
| 626 |
+
|
| 627 |
+
if inputs_embeds is None:
|
| 628 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 629 |
+
|
| 630 |
+
if position_ids is None:
|
| 631 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 632 |
+
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
| 633 |
+
position_ids = position_ids.unsqueeze(0)
|
| 634 |
+
|
| 635 |
+
causal_mask = create_causal_mask(
|
| 636 |
+
config=self.config,
|
| 637 |
+
inputs_embeds=inputs_embeds,
|
| 638 |
+
attention_mask=attention_mask,
|
| 639 |
+
past_key_values=past_key_values,
|
| 640 |
+
position_ids=position_ids,
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
hidden_states = inputs_embeds
|
| 644 |
+
|
| 645 |
+
# create position embeddings to be shared across the decoder layers
|
| 646 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 647 |
+
|
| 648 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 649 |
+
hidden_states = decoder_layer(
|
| 650 |
+
hidden_states,
|
| 651 |
+
position_embeddings=position_embeddings,
|
| 652 |
+
attention_mask=causal_mask,
|
| 653 |
+
past_key_values=past_key_values,
|
| 654 |
+
use_cache=use_cache,
|
| 655 |
+
position_ids=position_ids,
|
| 656 |
+
**kwargs,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
hidden_states = self.norm(hidden_states)
|
| 660 |
+
|
| 661 |
+
return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
|
| 662 |
+
last_hidden_state=hidden_states,
|
| 663 |
+
past_key_values=past_key_values,
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def load_balancing_loss_func(
|
| 668 |
+
gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
|
| 669 |
+
num_experts: int | None = None,
|
| 670 |
+
top_k=2,
|
| 671 |
+
attention_mask: torch.Tensor | None = None,
|
| 672 |
+
) -> torch.Tensor | int:
|
| 673 |
+
r"""
|
| 674 |
+
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
|
| 675 |
+
|
| 676 |
+
See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
|
| 677 |
+
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
|
| 678 |
+
experts is too unbalanced.
|
| 679 |
+
|
| 680 |
+
Args:
|
| 681 |
+
gate_logits:
|
| 682 |
+
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
|
| 683 |
+
shape [batch_size X sequence_length, num_experts].
|
| 684 |
+
num_experts:
|
| 685 |
+
Number of experts
|
| 686 |
+
top_k:
|
| 687 |
+
The number of experts to route per-token, can be also interpreted as the `top-k` routing
|
| 688 |
+
parameter.
|
| 689 |
+
attention_mask (`torch.Tensor`, *optional*):
|
| 690 |
+
The attention_mask used in forward function
|
| 691 |
+
shape [batch_size X sequence_length] if not None.
|
| 692 |
+
|
| 693 |
+
Returns:
|
| 694 |
+
The auxiliary loss.
|
| 695 |
+
"""
|
| 696 |
+
if gate_logits is None or not isinstance(gate_logits, tuple):
|
| 697 |
+
return 0
|
| 698 |
+
|
| 699 |
+
if isinstance(gate_logits, tuple):
|
| 700 |
+
compute_device = gate_logits[0].device
|
| 701 |
+
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
|
| 702 |
+
|
| 703 |
+
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
|
| 704 |
+
|
| 705 |
+
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
| 706 |
+
|
| 707 |
+
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
|
| 708 |
+
|
| 709 |
+
if attention_mask is None:
|
| 710 |
+
# Compute the percentage of tokens routed to each experts
|
| 711 |
+
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
|
| 712 |
+
|
| 713 |
+
# Compute the average probability of routing to these experts
|
| 714 |
+
router_prob_per_expert = torch.mean(routing_weights, dim=0)
|
| 715 |
+
else:
|
| 716 |
+
batch_size, sequence_length = attention_mask.shape
|
| 717 |
+
num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
|
| 718 |
+
|
| 719 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
|
| 720 |
+
expert_attention_mask = (
|
| 721 |
+
attention_mask[None, :, :, None, None]
|
| 722 |
+
.expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
|
| 723 |
+
.reshape(-1, top_k, num_experts)
|
| 724 |
+
.to(compute_device)
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
# Compute the percentage of tokens routed to each experts
|
| 728 |
+
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
|
| 729 |
+
expert_attention_mask, dim=0
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
# Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
|
| 733 |
+
router_per_expert_attention_mask = (
|
| 734 |
+
attention_mask[None, :, :, None]
|
| 735 |
+
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
|
| 736 |
+
.reshape(-1, num_experts)
|
| 737 |
+
.to(compute_device)
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
# Compute the average probability of routing to these experts
|
| 741 |
+
router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
|
| 742 |
+
router_per_expert_attention_mask, dim=0
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
|
| 746 |
+
return overall_loss * num_experts
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin):
|
| 750 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 751 |
+
|
| 752 |
+
def __init__(self, config):
|
| 753 |
+
super().__init__(config)
|
| 754 |
+
self.model = JetMoeModel(config)
|
| 755 |
+
self.vocab_size = config.vocab_size
|
| 756 |
+
self.aux_loss_coef = config.aux_loss_coef
|
| 757 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 758 |
+
self.tie_word_embeddings = config.tie_word_embeddings
|
| 759 |
+
self.num_experts = config.num_local_experts
|
| 760 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 761 |
+
|
| 762 |
+
# Initialize weights and apply final processing
|
| 763 |
+
self.post_init()
|
| 764 |
+
|
| 765 |
+
@can_return_tuple
|
| 766 |
+
@auto_docstring
|
| 767 |
+
def forward(
|
| 768 |
+
self,
|
| 769 |
+
input_ids: torch.LongTensor | None = None,
|
| 770 |
+
attention_mask: torch.Tensor | None = None,
|
| 771 |
+
position_ids: torch.LongTensor | None = None,
|
| 772 |
+
past_key_values: Cache | None = None,
|
| 773 |
+
inputs_embeds: torch.FloatTensor | None = None,
|
| 774 |
+
labels: torch.LongTensor | None = None,
|
| 775 |
+
use_cache: bool | None = None,
|
| 776 |
+
logits_to_keep: int | torch.Tensor = 0,
|
| 777 |
+
output_router_logits: bool | None = False,
|
| 778 |
+
**kwargs,
|
| 779 |
+
) -> MoeCausalLMOutputWithPast:
|
| 780 |
+
outputs: MoeModelOutputWithPast = self.model(
|
| 781 |
+
input_ids=input_ids,
|
| 782 |
+
attention_mask=attention_mask,
|
| 783 |
+
position_ids=position_ids,
|
| 784 |
+
past_key_values=past_key_values,
|
| 785 |
+
inputs_embeds=inputs_embeds,
|
| 786 |
+
use_cache=use_cache,
|
| 787 |
+
output_router_logits=output_router_logits,
|
| 788 |
+
**kwargs,
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
hidden_states = outputs.last_hidden_state
|
| 792 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 793 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 794 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 795 |
+
|
| 796 |
+
loss = None
|
| 797 |
+
if labels is not None:
|
| 798 |
+
loss = self.loss_function(
|
| 799 |
+
logits,
|
| 800 |
+
labels,
|
| 801 |
+
vocab_size=self.config.vocab_size,
|
| 802 |
+
**kwargs,
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
aux_loss = None
|
| 806 |
+
if output_router_logits:
|
| 807 |
+
aux_loss = load_balancing_loss_func(
|
| 808 |
+
outputs.router_logits,
|
| 809 |
+
self.num_experts,
|
| 810 |
+
self.num_experts_per_tok,
|
| 811 |
+
attention_mask,
|
| 812 |
+
)
|
| 813 |
+
if labels is not None:
|
| 814 |
+
loss += self.aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
|
| 815 |
+
|
| 816 |
+
return MoeCausalLMOutputWithPast(
|
| 817 |
+
loss=loss,
|
| 818 |
+
aux_loss=aux_loss,
|
| 819 |
+
logits=logits,
|
| 820 |
+
past_key_values=outputs.past_key_values,
|
| 821 |
+
hidden_states=outputs.hidden_states,
|
| 822 |
+
attentions=outputs.attentions,
|
| 823 |
+
router_logits=outputs.router_logits,
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
class JetMoeForSequenceClassification(GenericForSequenceClassification, JetMoePreTrainedModel): ...
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
__all__ = ["JetMoeForCausalLM", "JetMoeModel", "JetMoePreTrainedModel", "JetMoeForSequenceClassification"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_vitmatte import *
|
| 22 |
+
from .image_processing_pil_vitmatte import *
|
| 23 |
+
from .image_processing_vitmatte import *
|
| 24 |
+
from .modeling_vitmatte import *
|
| 25 |
+
else:
|
| 26 |
+
import sys
|
| 27 |
+
|
| 28 |
+
_file = globals()["__file__"]
|
| 29 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/vitmatte/image_processing_pil_vitmatte.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Image processor class for ViTMatte."""
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from ...image_processing_backends import PilBackend
|
| 19 |
+
from ...image_processing_utils import BatchFeature
|
| 20 |
+
from ...image_transforms import PaddingMode
|
| 21 |
+
from ...image_transforms import pad as np_pad
|
| 22 |
+
from ...image_utils import (
|
| 23 |
+
IMAGENET_STANDARD_MEAN,
|
| 24 |
+
IMAGENET_STANDARD_STD,
|
| 25 |
+
ChannelDimension,
|
| 26 |
+
ImageInput,
|
| 27 |
+
get_image_size,
|
| 28 |
+
)
|
| 29 |
+
from ...processing_utils import ImagesKwargs, Unpack
|
| 30 |
+
from ...utils import TensorType, auto_docstring, is_torch_available
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if is_torch_available():
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Adapted from transformers.models.vitmatte.image_processing_vitmatte.VitMatteImageProcessorKwargs
|
| 38 |
+
class VitMatteImageProcessorKwargs(ImagesKwargs, total=False):
|
| 39 |
+
r"""
|
| 40 |
+
size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
|
| 41 |
+
The width and height of the image will be padded to be divisible by this number.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
size_divisor: int
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@auto_docstring
|
| 48 |
+
class VitMatteImageProcessorPil(PilBackend):
|
| 49 |
+
do_rescale = True
|
| 50 |
+
rescale_factor = 1 / 255
|
| 51 |
+
do_normalize = True
|
| 52 |
+
image_mean = IMAGENET_STANDARD_MEAN
|
| 53 |
+
image_std = IMAGENET_STANDARD_STD
|
| 54 |
+
do_pad = True
|
| 55 |
+
size_divisor = 32
|
| 56 |
+
valid_kwargs = VitMatteImageProcessorKwargs
|
| 57 |
+
|
| 58 |
+
def __init__(self, **kwargs: Unpack[VitMatteImageProcessorKwargs]) -> None:
|
| 59 |
+
size_divisibility = kwargs.pop("size_divisibility", None)
|
| 60 |
+
if size_divisibility is not None:
|
| 61 |
+
kwargs.setdefault("size_divisor", size_divisibility)
|
| 62 |
+
super().__init__(**kwargs)
|
| 63 |
+
|
| 64 |
+
def pad_image(
|
| 65 |
+
self,
|
| 66 |
+
image: np.ndarray,
|
| 67 |
+
size_divisor: int = 32,
|
| 68 |
+
) -> np.ndarray:
|
| 69 |
+
"""
|
| 70 |
+
Args:
|
| 71 |
+
image (`np.ndarray`):
|
| 72 |
+
Image to pad.
|
| 73 |
+
size_divisor (`int`, *optional*, defaults to 32):
|
| 74 |
+
The width and height of the image will be padded to be divisible by this number.
|
| 75 |
+
"""
|
| 76 |
+
height, width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
| 77 |
+
|
| 78 |
+
pad_height = 0 if height % size_divisor == 0 else size_divisor - height % size_divisor
|
| 79 |
+
pad_width = 0 if width % size_divisor == 0 else size_divisor - width % size_divisor
|
| 80 |
+
if pad_width + pad_height > 0:
|
| 81 |
+
padding = ((0, pad_height), (0, pad_width))
|
| 82 |
+
image = np_pad(
|
| 83 |
+
image,
|
| 84 |
+
padding=padding,
|
| 85 |
+
mode=PaddingMode.CONSTANT,
|
| 86 |
+
constant_values=0,
|
| 87 |
+
data_format=ChannelDimension.FIRST,
|
| 88 |
+
input_data_format=ChannelDimension.FIRST,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
return image
|
| 92 |
+
|
| 93 |
+
@auto_docstring
|
| 94 |
+
def preprocess(
|
| 95 |
+
self,
|
| 96 |
+
images: ImageInput,
|
| 97 |
+
trimaps: ImageInput,
|
| 98 |
+
**kwargs: Unpack[VitMatteImageProcessorKwargs],
|
| 99 |
+
) -> BatchFeature:
|
| 100 |
+
r"""
|
| 101 |
+
trimaps (`ImageInput`):
|
| 102 |
+
The trimaps to preprocess.
|
| 103 |
+
"""
|
| 104 |
+
return super().preprocess(images, trimaps, **kwargs)
|
| 105 |
+
|
| 106 |
+
def _preprocess_image_like_inputs(
|
| 107 |
+
self,
|
| 108 |
+
images: ImageInput,
|
| 109 |
+
trimaps: ImageInput,
|
| 110 |
+
do_convert_rgb: bool,
|
| 111 |
+
input_data_format: ChannelDimension,
|
| 112 |
+
device: str | None = None,
|
| 113 |
+
**kwargs: Unpack[VitMatteImageProcessorKwargs],
|
| 114 |
+
) -> BatchFeature:
|
| 115 |
+
"""
|
| 116 |
+
Preprocess image-like inputs.
|
| 117 |
+
"""
|
| 118 |
+
images = self._prepare_image_like_inputs(
|
| 119 |
+
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
|
| 120 |
+
)
|
| 121 |
+
trimaps = self._prepare_image_like_inputs(images=trimaps, expected_ndims=2, device=device)
|
| 122 |
+
|
| 123 |
+
return self._preprocess(images, trimaps, **kwargs)
|
| 124 |
+
|
| 125 |
+
def _preprocess(
|
| 126 |
+
self,
|
| 127 |
+
images: list[np.ndarray],
|
| 128 |
+
trimaps: list[np.ndarray],
|
| 129 |
+
do_rescale: bool,
|
| 130 |
+
rescale_factor: float,
|
| 131 |
+
do_normalize: bool,
|
| 132 |
+
image_mean: float | list[float] | None,
|
| 133 |
+
image_std: float | list[float] | None,
|
| 134 |
+
do_pad: bool | None,
|
| 135 |
+
size_divisor: int | None,
|
| 136 |
+
return_tensors: str | TensorType | None,
|
| 137 |
+
**kwargs,
|
| 138 |
+
) -> BatchFeature:
|
| 139 |
+
processed_images = []
|
| 140 |
+
for image, trimap in zip(images, trimaps):
|
| 141 |
+
if do_rescale:
|
| 142 |
+
image = self.rescale(image, rescale_factor)
|
| 143 |
+
trimap = self.rescale(trimap, rescale_factor)
|
| 144 |
+
if do_normalize:
|
| 145 |
+
image = self.normalize(image, image_mean, image_std)
|
| 146 |
+
# Concatenate images and trimaps along channel dimension
|
| 147 |
+
# trimap is already (1, H, W) from _prepare_image_like_inputs with expected_ndims=2
|
| 148 |
+
if trimap.ndim == 3 and trimap.shape[0] == 1:
|
| 149 |
+
image = np.concatenate([image, trimap], axis=0)
|
| 150 |
+
else:
|
| 151 |
+
image = np.concatenate([image, np.expand_dims(trimap, axis=0)], axis=0)
|
| 152 |
+
if do_pad:
|
| 153 |
+
image = self.pad_image(image, size_divisor)
|
| 154 |
+
processed_images.append(image)
|
| 155 |
+
|
| 156 |
+
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
__all__ = ["VitMatteImageProcessorPil"]
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/optimization.py
ADDED
|
@@ -0,0 +1,1342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""PyTorch optimization for BERT model."""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
import warnings
|
| 20 |
+
from functools import partial
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch.optim import Optimizer
|
| 25 |
+
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
| 26 |
+
|
| 27 |
+
from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
|
| 28 |
+
from .trainer_utils import SchedulerType
|
| 29 |
+
from .utils import logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _get_constant_lambda(_=None):
|
| 36 |
+
return 1
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
| 40 |
+
"""
|
| 41 |
+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 45 |
+
The optimizer for which to schedule the learning rate.
|
| 46 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 47 |
+
The index of the last epoch when resuming training.
|
| 48 |
+
|
| 49 |
+
Return:
|
| 50 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
|
| 57 |
+
"""
|
| 58 |
+
Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 62 |
+
The optimizer for which to schedule the learning rate.
|
| 63 |
+
kwargs (`dict`, *optional*):
|
| 64 |
+
Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
|
| 65 |
+
for possible parameters.
|
| 66 |
+
|
| 67 |
+
Return:
|
| 68 |
+
`torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
return ReduceLROnPlateau(optimizer, **kwargs)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
|
| 75 |
+
if current_step < num_warmup_steps:
|
| 76 |
+
return float(current_step) / float(max(1.0, num_warmup_steps))
|
| 77 |
+
return 1.0
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
|
| 81 |
+
"""
|
| 82 |
+
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
| 83 |
+
increases linearly between 0 and the initial lr set in the optimizer.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 87 |
+
The optimizer for which to schedule the learning rate.
|
| 88 |
+
num_warmup_steps (`int`):
|
| 89 |
+
The number of steps for the warmup phase.
|
| 90 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 91 |
+
The index of the last epoch when resuming training.
|
| 92 |
+
|
| 93 |
+
Return:
|
| 94 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
|
| 98 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
|
| 102 |
+
if current_step < num_warmup_steps:
|
| 103 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 104 |
+
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
| 108 |
+
"""
|
| 109 |
+
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
| 110 |
+
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 114 |
+
The optimizer for which to schedule the learning rate.
|
| 115 |
+
num_warmup_steps (`int`):
|
| 116 |
+
The number of steps for the warmup phase.
|
| 117 |
+
num_training_steps (`int`):
|
| 118 |
+
The total number of training steps.
|
| 119 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 120 |
+
The index of the last epoch when resuming training.
|
| 121 |
+
|
| 122 |
+
Return:
|
| 123 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
lr_lambda = partial(
|
| 127 |
+
_get_linear_schedule_with_warmup_lr_lambda,
|
| 128 |
+
num_warmup_steps=num_warmup_steps,
|
| 129 |
+
num_training_steps=num_training_steps,
|
| 130 |
+
)
|
| 131 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _get_cosine_schedule_with_warmup_lr_lambda(
|
| 135 |
+
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
|
| 136 |
+
):
|
| 137 |
+
if current_step < num_warmup_steps:
|
| 138 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 139 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 140 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def get_cosine_schedule_with_warmup(
|
| 144 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
|
| 145 |
+
):
|
| 146 |
+
"""
|
| 147 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 148 |
+
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
| 149 |
+
initial lr set in the optimizer.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 153 |
+
The optimizer for which to schedule the learning rate.
|
| 154 |
+
num_warmup_steps (`int`):
|
| 155 |
+
The number of steps for the warmup phase.
|
| 156 |
+
num_training_steps (`int`):
|
| 157 |
+
The total number of training steps.
|
| 158 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
| 159 |
+
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
| 160 |
+
following a half-cosine).
|
| 161 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 162 |
+
The index of the last epoch when resuming training.
|
| 163 |
+
|
| 164 |
+
Return:
|
| 165 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
lr_lambda = partial(
|
| 169 |
+
_get_cosine_schedule_with_warmup_lr_lambda,
|
| 170 |
+
num_warmup_steps=num_warmup_steps,
|
| 171 |
+
num_training_steps=num_training_steps,
|
| 172 |
+
num_cycles=num_cycles,
|
| 173 |
+
)
|
| 174 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
|
| 178 |
+
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
|
| 179 |
+
):
|
| 180 |
+
if current_step < num_warmup_steps:
|
| 181 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 182 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 183 |
+
if progress >= 1.0:
|
| 184 |
+
return 0.0
|
| 185 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 189 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
| 190 |
+
):
|
| 191 |
+
"""
|
| 192 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 193 |
+
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
| 194 |
+
linearly between 0 and the initial lr set in the optimizer.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 198 |
+
The optimizer for which to schedule the learning rate.
|
| 199 |
+
num_warmup_steps (`int`):
|
| 200 |
+
The number of steps for the warmup phase.
|
| 201 |
+
num_training_steps (`int`):
|
| 202 |
+
The total number of training steps.
|
| 203 |
+
num_cycles (`int`, *optional*, defaults to 1):
|
| 204 |
+
The number of hard restarts to use.
|
| 205 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 206 |
+
The index of the last epoch when resuming training.
|
| 207 |
+
|
| 208 |
+
Return:
|
| 209 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
lr_lambda = partial(
|
| 213 |
+
_get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
|
| 214 |
+
num_warmup_steps=num_warmup_steps,
|
| 215 |
+
num_training_steps=num_training_steps,
|
| 216 |
+
num_cycles=num_cycles,
|
| 217 |
+
)
|
| 218 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
|
| 222 |
+
current_step: int,
|
| 223 |
+
*,
|
| 224 |
+
num_warmup_steps: int,
|
| 225 |
+
num_training_steps: int,
|
| 226 |
+
lr_end: float,
|
| 227 |
+
power: float,
|
| 228 |
+
lr_init: int,
|
| 229 |
+
):
|
| 230 |
+
if current_step < num_warmup_steps:
|
| 231 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 232 |
+
elif current_step > num_training_steps:
|
| 233 |
+
return lr_end / lr_init # as LambdaLR multiplies by lr_init
|
| 234 |
+
else:
|
| 235 |
+
lr_range = lr_init - lr_end
|
| 236 |
+
decay_steps = num_training_steps - num_warmup_steps
|
| 237 |
+
pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
|
| 238 |
+
decay = lr_range * pct_remaining**power + lr_end
|
| 239 |
+
return decay / lr_init # as LambdaLR multiplies by lr_init
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def get_polynomial_decay_schedule_with_warmup(
|
| 243 |
+
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
|
| 244 |
+
):
|
| 245 |
+
"""
|
| 246 |
+
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
|
| 247 |
+
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
|
| 248 |
+
initial lr set in the optimizer.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 252 |
+
The optimizer for which to schedule the learning rate.
|
| 253 |
+
num_warmup_steps (`int`):
|
| 254 |
+
The number of steps for the warmup phase.
|
| 255 |
+
num_training_steps (`int`):
|
| 256 |
+
The total number of training steps.
|
| 257 |
+
lr_end (`float`, *optional*, defaults to 1e-7):
|
| 258 |
+
The end LR.
|
| 259 |
+
power (`float`, *optional*, defaults to 1.0):
|
| 260 |
+
Power factor.
|
| 261 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 262 |
+
The index of the last epoch when resuming training.
|
| 263 |
+
|
| 264 |
+
Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
|
| 265 |
+
implementation at
|
| 266 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
|
| 267 |
+
|
| 268 |
+
Return:
|
| 269 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 270 |
+
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
lr_init = optimizer.defaults["lr"]
|
| 274 |
+
if not (lr_init > lr_end):
|
| 275 |
+
raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
|
| 276 |
+
|
| 277 |
+
lr_lambda = partial(
|
| 278 |
+
_get_polynomial_decay_schedule_with_warmup_lr_lambda,
|
| 279 |
+
num_warmup_steps=num_warmup_steps,
|
| 280 |
+
num_training_steps=num_training_steps,
|
| 281 |
+
lr_end=lr_end,
|
| 282 |
+
power=power,
|
| 283 |
+
lr_init=lr_init,
|
| 284 |
+
)
|
| 285 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: int | None = None):
|
| 289 |
+
if current_step < num_warmup_steps:
|
| 290 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 291 |
+
shift = timescale - num_warmup_steps
|
| 292 |
+
decay = 1.0 / math.sqrt((current_step + shift) / timescale)
|
| 293 |
+
return decay
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def get_inverse_sqrt_schedule(
|
| 297 |
+
optimizer: Optimizer, num_warmup_steps: int, timescale: int | None = None, last_epoch: int = -1
|
| 298 |
+
):
|
| 299 |
+
"""
|
| 300 |
+
Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
|
| 301 |
+
warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 305 |
+
The optimizer for which to schedule the learning rate.
|
| 306 |
+
num_warmup_steps (`int`):
|
| 307 |
+
The number of steps for the warmup phase.
|
| 308 |
+
timescale (`int`, *optional*, defaults to `num_warmup_steps`):
|
| 309 |
+
Time scale.
|
| 310 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 311 |
+
The index of the last epoch when resuming training.
|
| 312 |
+
|
| 313 |
+
Return:
|
| 314 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 315 |
+
"""
|
| 316 |
+
# Note: this implementation is adapted from
|
| 317 |
+
# https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
|
| 318 |
+
|
| 319 |
+
if timescale is None:
|
| 320 |
+
timescale = num_warmup_steps or 10_000
|
| 321 |
+
|
| 322 |
+
lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
|
| 323 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _get_cosine_schedule_with_warmup_lr_lambda(
|
| 327 |
+
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
|
| 328 |
+
):
|
| 329 |
+
if current_step < num_warmup_steps:
|
| 330 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
| 331 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
| 332 |
+
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
| 333 |
+
factor = factor * (1 - min_lr_rate) + min_lr_rate
|
| 334 |
+
return max(0, factor)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def get_cosine_with_min_lr_schedule_with_warmup(
|
| 338 |
+
optimizer: Optimizer,
|
| 339 |
+
num_warmup_steps: int,
|
| 340 |
+
num_training_steps: int,
|
| 341 |
+
num_cycles: float = 0.5,
|
| 342 |
+
last_epoch: int = -1,
|
| 343 |
+
min_lr: float | None = None,
|
| 344 |
+
min_lr_rate: float | None = None,
|
| 345 |
+
):
|
| 346 |
+
"""
|
| 347 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 348 |
+
initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
|
| 349 |
+
initial lr set in the optimizer.
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 353 |
+
The optimizer for which to schedule the learning rate.
|
| 354 |
+
num_warmup_steps (`int`):
|
| 355 |
+
The number of steps for the warmup phase.
|
| 356 |
+
num_training_steps (`int`):
|
| 357 |
+
The total number of training steps.
|
| 358 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
| 359 |
+
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
| 360 |
+
following a half-cosine).
|
| 361 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 362 |
+
The index of the last epoch when resuming training.
|
| 363 |
+
min_lr (`float`, *optional*):
|
| 364 |
+
The minimum learning rate to reach after the cosine schedule.
|
| 365 |
+
min_lr_rate (`float`, *optional*):
|
| 366 |
+
The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
|
| 367 |
+
|
| 368 |
+
Return:
|
| 369 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 370 |
+
"""
|
| 371 |
+
|
| 372 |
+
if min_lr is not None and min_lr_rate is not None:
|
| 373 |
+
raise ValueError("Only one of min_lr or min_lr_rate should be set")
|
| 374 |
+
elif min_lr is not None:
|
| 375 |
+
min_lr_rate = min_lr / optimizer.defaults["lr"]
|
| 376 |
+
elif min_lr_rate is None:
|
| 377 |
+
raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
|
| 378 |
+
|
| 379 |
+
lr_lambda = partial(
|
| 380 |
+
_get_cosine_schedule_with_warmup_lr_lambda,
|
| 381 |
+
num_warmup_steps=num_warmup_steps,
|
| 382 |
+
num_training_steps=num_training_steps,
|
| 383 |
+
num_cycles=num_cycles,
|
| 384 |
+
min_lr_rate=min_lr_rate,
|
| 385 |
+
)
|
| 386 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
|
| 390 |
+
current_step: int,
|
| 391 |
+
*,
|
| 392 |
+
num_warmup_steps: int,
|
| 393 |
+
num_training_steps: int,
|
| 394 |
+
num_cycles: float,
|
| 395 |
+
min_lr_rate: float = 0.0,
|
| 396 |
+
warmup_lr_rate: float | None = None,
|
| 397 |
+
):
|
| 398 |
+
current_step = float(current_step)
|
| 399 |
+
num_warmup_steps = float(num_warmup_steps)
|
| 400 |
+
num_training_steps = float(num_training_steps)
|
| 401 |
+
|
| 402 |
+
if current_step < num_warmup_steps:
|
| 403 |
+
if warmup_lr_rate is None:
|
| 404 |
+
return (current_step + 1.0) / max(1.0, num_warmup_steps)
|
| 405 |
+
else:
|
| 406 |
+
warmup_lr_rate = float(warmup_lr_rate)
|
| 407 |
+
return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1))
|
| 408 |
+
progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps))
|
| 409 |
+
factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
|
| 410 |
+
factor = factor * (1 - min_lr_rate) + min_lr_rate
|
| 411 |
+
return max(0, factor)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
|
| 415 |
+
optimizer: Optimizer,
|
| 416 |
+
num_warmup_steps: int,
|
| 417 |
+
num_training_steps: int,
|
| 418 |
+
num_cycles: float = 0.5,
|
| 419 |
+
last_epoch: int = -1,
|
| 420 |
+
min_lr: float | None = None,
|
| 421 |
+
min_lr_rate: float | None = None,
|
| 422 |
+
warmup_lr_rate: float | None = None,
|
| 423 |
+
):
|
| 424 |
+
"""
|
| 425 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
| 426 |
+
initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
|
| 427 |
+
initial lr set in the optimizer.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 431 |
+
The optimizer for which to schedule the learning rate.
|
| 432 |
+
num_warmup_steps (`int`):
|
| 433 |
+
The number of steps for the warmup phase.
|
| 434 |
+
num_training_steps (`int`):
|
| 435 |
+
The total number of training steps.
|
| 436 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
| 437 |
+
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
| 438 |
+
following a half-cosine).
|
| 439 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 440 |
+
The index of the last epoch when resuming training.
|
| 441 |
+
min_lr (`float`, *optional*):
|
| 442 |
+
The minimum learning rate to reach after the cosine schedule.
|
| 443 |
+
min_lr_rate (`float`, *optional*):
|
| 444 |
+
The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
|
| 445 |
+
warmup_lr_rate (`float`, *optional*):
|
| 446 |
+
The minimum learning rate as a ratio of the start learning rate. If not set, `warmup_lr_rate` will be treated as float(1/num_warmup_steps).
|
| 447 |
+
|
| 448 |
+
Return:
|
| 449 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 450 |
+
"""
|
| 451 |
+
|
| 452 |
+
if min_lr is not None and min_lr_rate is not None:
|
| 453 |
+
raise ValueError("Only one of min_lr or min_lr_rate should be set")
|
| 454 |
+
elif min_lr is not None:
|
| 455 |
+
min_lr_rate = min_lr / optimizer.defaults["lr"]
|
| 456 |
+
elif min_lr_rate is None:
|
| 457 |
+
raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
|
| 458 |
+
|
| 459 |
+
lr_lambda = partial(
|
| 460 |
+
_get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda,
|
| 461 |
+
num_warmup_steps=num_warmup_steps,
|
| 462 |
+
num_training_steps=num_training_steps,
|
| 463 |
+
num_cycles=num_cycles,
|
| 464 |
+
min_lr_rate=min_lr_rate,
|
| 465 |
+
warmup_lr_rate=warmup_lr_rate,
|
| 466 |
+
)
|
| 467 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def _get_wsd_scheduler_lambda(
|
| 471 |
+
current_step: int,
|
| 472 |
+
*,
|
| 473 |
+
num_warmup_steps: int,
|
| 474 |
+
num_stable_steps: int,
|
| 475 |
+
num_decay_steps: int,
|
| 476 |
+
warmup_type: str,
|
| 477 |
+
decay_type: str,
|
| 478 |
+
min_lr_ratio: float,
|
| 479 |
+
num_cycles: float,
|
| 480 |
+
):
|
| 481 |
+
if current_step < num_warmup_steps:
|
| 482 |
+
progress = float(current_step) / float(max(1, num_warmup_steps))
|
| 483 |
+
if warmup_type == "linear":
|
| 484 |
+
factor = progress
|
| 485 |
+
elif warmup_type == "cosine":
|
| 486 |
+
factor = 0.5 * (1.0 - math.cos(math.pi * progress))
|
| 487 |
+
elif warmup_type == "1-sqrt":
|
| 488 |
+
factor = 1.0 - math.sqrt(1.0 - progress)
|
| 489 |
+
factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
|
| 490 |
+
return max(0.0, factor)
|
| 491 |
+
|
| 492 |
+
if current_step < num_warmup_steps + num_stable_steps:
|
| 493 |
+
return 1.0
|
| 494 |
+
|
| 495 |
+
if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
|
| 496 |
+
progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
|
| 497 |
+
if decay_type == "linear":
|
| 498 |
+
factor = 1.0 - progress
|
| 499 |
+
elif decay_type == "cosine":
|
| 500 |
+
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
| 501 |
+
elif decay_type == "1-sqrt":
|
| 502 |
+
factor = 1.0 - math.sqrt(progress)
|
| 503 |
+
factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
|
| 504 |
+
return max(0.0, factor)
|
| 505 |
+
return min_lr_ratio
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def get_wsd_schedule(
|
| 509 |
+
optimizer: Optimizer,
|
| 510 |
+
num_warmup_steps: int,
|
| 511 |
+
num_decay_steps: int,
|
| 512 |
+
num_training_steps: int | None = None,
|
| 513 |
+
num_stable_steps: int | None = None,
|
| 514 |
+
warmup_type: str = "linear",
|
| 515 |
+
decay_type: str = "cosine",
|
| 516 |
+
min_lr_ratio: float = 0,
|
| 517 |
+
num_cycles: float = 0.5,
|
| 518 |
+
last_epoch: int = -1,
|
| 519 |
+
):
|
| 520 |
+
"""
|
| 521 |
+
Create a schedule with a learning rate that has three stages:
|
| 522 |
+
1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
|
| 523 |
+
2. stable: constant learning rate.
|
| 524 |
+
3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.
|
| 525 |
+
|
| 526 |
+
Args:
|
| 527 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 528 |
+
The optimizer for which to schedule the learning rate.
|
| 529 |
+
num_warmup_steps (`int`):
|
| 530 |
+
The number of steps for the warmup phase.
|
| 531 |
+
num_decay_steps (`int`):
|
| 532 |
+
The number of steps for the decay phase.
|
| 533 |
+
num_training_steps (`int`, *optional*):
|
| 534 |
+
The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
|
| 535 |
+
num_stable_steps (`int`, *optional*):
|
| 536 |
+
The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate.
|
| 537 |
+
warmup_type (`str`, *optional*, defaults to "linear"):
|
| 538 |
+
The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
|
| 539 |
+
decay_type (`str`, *optional*, defaults to "cosine"):
|
| 540 |
+
The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
|
| 541 |
+
min_lr_ratio (`float`, *optional*, defaults to 0):
|
| 542 |
+
The minimum learning rate as a ratio of the initial learning rate.
|
| 543 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
| 544 |
+
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
| 545 |
+
following a half-cosine).
|
| 546 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
| 547 |
+
The index of the last epoch when resuming training.
|
| 548 |
+
|
| 549 |
+
Return:
|
| 550 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
| 551 |
+
"""
|
| 552 |
+
|
| 553 |
+
if num_training_steps is None and num_stable_steps is None:
|
| 554 |
+
raise ValueError("Either num_training_steps or num_stable_steps must be specified.")
|
| 555 |
+
|
| 556 |
+
if num_training_steps is not None and num_stable_steps is not None:
|
| 557 |
+
warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")
|
| 558 |
+
|
| 559 |
+
if warmup_type not in ["linear", "cosine", "1-sqrt"]:
|
| 560 |
+
raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
|
| 561 |
+
|
| 562 |
+
if decay_type not in ["linear", "cosine", "1-sqrt"]:
|
| 563 |
+
raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
|
| 564 |
+
|
| 565 |
+
if num_stable_steps is None:
|
| 566 |
+
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
|
| 567 |
+
|
| 568 |
+
lr_lambda = partial(
|
| 569 |
+
_get_wsd_scheduler_lambda,
|
| 570 |
+
num_warmup_steps=num_warmup_steps,
|
| 571 |
+
num_stable_steps=num_stable_steps,
|
| 572 |
+
num_decay_steps=num_decay_steps,
|
| 573 |
+
warmup_type=warmup_type,
|
| 574 |
+
decay_type=decay_type,
|
| 575 |
+
min_lr_ratio=min_lr_ratio,
|
| 576 |
+
num_cycles=num_cycles,
|
| 577 |
+
)
|
| 578 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class StreamingAverage:
|
| 582 |
+
"""Rolling window average for smoothing metric values.
|
| 583 |
+
|
| 584 |
+
Maintains a sliding window of values and computes their average,
|
| 585 |
+
useful for smoothing noisy metric values before making learning rate decisions.
|
| 586 |
+
|
| 587 |
+
Args:
|
| 588 |
+
window_size (`int`):
|
| 589 |
+
The maximum number of values to keep in the rolling window.
|
| 590 |
+
"""
|
| 591 |
+
|
| 592 |
+
def __init__(self, window_size: int) -> None:
|
| 593 |
+
self.window_size: int = window_size
|
| 594 |
+
self.values: list[float] = []
|
| 595 |
+
self.sum: float = 0.0
|
| 596 |
+
|
| 597 |
+
def streamavg(self, value: float) -> float:
|
| 598 |
+
"""Add a value and return the current rolling average."""
|
| 599 |
+
self.values.append(value)
|
| 600 |
+
self.sum += value
|
| 601 |
+
|
| 602 |
+
if len(self.values) > self.window_size:
|
| 603 |
+
removed = self.values.pop(0)
|
| 604 |
+
self.sum -= removed
|
| 605 |
+
|
| 606 |
+
return self.sum / len(self.values)
|
| 607 |
+
|
| 608 |
+
def state_dict(self) -> dict[str, Any]:
|
| 609 |
+
return {
|
| 610 |
+
"window_size": self.window_size,
|
| 611 |
+
"values": self.values.copy(),
|
| 612 |
+
"sum": self.sum,
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
| 616 |
+
self.window_size = state_dict.get("window_size", self.window_size)
|
| 617 |
+
self.values = state_dict.get("values", []).copy()
|
| 618 |
+
self.sum = state_dict.get("sum", 0.0)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class GreedyLR:
|
| 622 |
+
"""Adaptive learning rate scheduler that responds to training metrics.
|
| 623 |
+
|
| 624 |
+
GreedyLR dynamically adjusts the learning rate based on training performance:
|
| 625 |
+
- Increases LR when metrics improve consistently (divides by factor)
|
| 626 |
+
- Decreases LR when metrics plateau (multiplies by factor)
|
| 627 |
+
|
| 628 |
+
This differs from traditional schedulers like cosine annealing by responding
|
| 629 |
+
to actual training dynamics rather than following a predetermined schedule.
|
| 630 |
+
|
| 631 |
+
Reference: `GreedyLR: A Novel Adaptive Learning Rate Scheduler <https://arxiv.org/abs/2512.14527>`_
|
| 632 |
+
|
| 633 |
+
Args:
|
| 634 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 635 |
+
The optimizer for which to schedule the learning rate.
|
| 636 |
+
mode (`str`, *optional*, defaults to `"min"`):
|
| 637 |
+
One of 'min' or 'max'. In 'min' mode, LR will be reduced when the
|
| 638 |
+
metric has stopped decreasing; in 'max' mode when it has stopped increasing.
|
| 639 |
+
factor (`float`, *optional*, defaults to 0.95):
|
| 640 |
+
Factor by which the learning rate will be adjusted. LR is multiplied by
|
| 641 |
+
factor on plateau and divided by factor on improvement. Must be < 1.0.
|
| 642 |
+
patience (`int`, *optional*, defaults to 10):
|
| 643 |
+
Number of epochs with no improvement after which learning rate will be adjusted.
|
| 644 |
+
threshold (`float`, *optional*, defaults to 1e-06):
|
| 645 |
+
Threshold for measuring the new optimum.
|
| 646 |
+
threshold_mode (`str`, *optional*, defaults to `"abs"`):
|
| 647 |
+
One of 'rel' or 'abs'.
|
| 648 |
+
cooldown (`int`, *optional*, defaults to 0):
|
| 649 |
+
Number of epochs to wait before resuming normal operation after LR has been reduced.
|
| 650 |
+
warmup (`int`, *optional*, defaults to 0):
|
| 651 |
+
Number of epochs to wait before resuming normal operation after LR has been increased.
|
| 652 |
+
min_lr (`float` or `list[float]`, *optional*, defaults to 0.001):
|
| 653 |
+
A lower bound on the learning rate.
|
| 654 |
+
max_lr (`float` or `list[float]`, *optional*, defaults to 1.0):
|
| 655 |
+
An upper bound on the learning rate.
|
| 656 |
+
eps (`float`, *optional*, defaults to 1e-08):
|
| 657 |
+
Minimal decay applied to lr.
|
| 658 |
+
verbose (`bool`, *optional*, defaults to `False`):
|
| 659 |
+
If True, prints a message to stdout for each update.
|
| 660 |
+
smooth (`bool`, *optional*, defaults to `False`):
|
| 661 |
+
If True, applies streaming average smoothing to metrics.
|
| 662 |
+
window_size (`int`, *optional*, defaults to 50):
|
| 663 |
+
The window size for the streaming average when smooth=True.
|
| 664 |
+
reset_start (`int`, *optional*, defaults to 500):
|
| 665 |
+
Number of steps to wait at min_lr before resetting to initial state.
|
| 666 |
+
|
| 667 |
+
Example:
|
| 668 |
+
```python
|
| 669 |
+
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
|
| 670 |
+
>>> scheduler = GreedyLR(optimizer, mode="min", patience=10)
|
| 671 |
+
>>> for epoch in range(100):
|
| 672 |
+
... train(...)
|
| 673 |
+
... val_loss = validate(...)
|
| 674 |
+
... scheduler.step(val_loss)
|
| 675 |
+
```
|
| 676 |
+
"""
|
| 677 |
+
|
| 678 |
+
def __init__(
|
| 679 |
+
self,
|
| 680 |
+
optimizer: Optimizer,
|
| 681 |
+
mode: str = "min",
|
| 682 |
+
factor: float = 0.95,
|
| 683 |
+
patience: int = 10,
|
| 684 |
+
threshold: float = 1e-6,
|
| 685 |
+
threshold_mode: str = "abs",
|
| 686 |
+
cooldown: int = 0,
|
| 687 |
+
warmup: int = 0,
|
| 688 |
+
min_lr: float | list[float] = 1e-3,
|
| 689 |
+
max_lr: float | list[float] = 1.0,
|
| 690 |
+
eps: float = 1e-8,
|
| 691 |
+
verbose: bool = False,
|
| 692 |
+
smooth: bool = False,
|
| 693 |
+
window_size: int = 50,
|
| 694 |
+
reset_start: int = 500,
|
| 695 |
+
) -> None:
|
| 696 |
+
if factor >= 1.0:
|
| 697 |
+
raise ValueError("Factor should be < 1.0.")
|
| 698 |
+
if not isinstance(optimizer, Optimizer):
|
| 699 |
+
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
|
| 700 |
+
|
| 701 |
+
self.optimizer = optimizer
|
| 702 |
+
self.factor = factor
|
| 703 |
+
self.patience = patience
|
| 704 |
+
self.verbose = verbose
|
| 705 |
+
self.cooldown = cooldown
|
| 706 |
+
self.warmup = warmup
|
| 707 |
+
self.cooldown_counter = 0
|
| 708 |
+
self.warmup_counter = 0
|
| 709 |
+
self.mode = mode
|
| 710 |
+
self.threshold = threshold
|
| 711 |
+
self.threshold_mode = threshold_mode
|
| 712 |
+
self.eps = eps
|
| 713 |
+
self.smooth = smooth
|
| 714 |
+
self.window_size = window_size
|
| 715 |
+
self.reset_start = reset_start
|
| 716 |
+
self.reset_start_original = reset_start
|
| 717 |
+
self.last_epoch = 0
|
| 718 |
+
|
| 719 |
+
if isinstance(min_lr, (list, tuple)):
|
| 720 |
+
if len(min_lr) != len(optimizer.param_groups):
|
| 721 |
+
raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
|
| 722 |
+
self.min_lrs = list(min_lr)
|
| 723 |
+
else:
|
| 724 |
+
self.min_lrs = [min_lr] * len(optimizer.param_groups)
|
| 725 |
+
|
| 726 |
+
if isinstance(max_lr, (list, tuple)):
|
| 727 |
+
if len(max_lr) != len(optimizer.param_groups):
|
| 728 |
+
raise ValueError(f"expected {len(optimizer.param_groups)} max_lrs, got {len(max_lr)}")
|
| 729 |
+
self.max_lrs = list(max_lr)
|
| 730 |
+
else:
|
| 731 |
+
self.max_lrs = [max_lr] * len(optimizer.param_groups)
|
| 732 |
+
|
| 733 |
+
self._init_lrs = [group["lr"] for group in optimizer.param_groups]
|
| 734 |
+
self._last_lr = self._init_lrs.copy()
|
| 735 |
+
|
| 736 |
+
self.best: float = float("inf") if mode == "min" else float("-inf")
|
| 737 |
+
self.num_bad_epochs = 0
|
| 738 |
+
self.num_good_epochs = 0
|
| 739 |
+
|
| 740 |
+
if mode not in ("min", "max"):
|
| 741 |
+
raise ValueError(f"mode {mode} is unknown!")
|
| 742 |
+
if threshold_mode not in ("rel", "abs"):
|
| 743 |
+
raise ValueError(f"threshold mode {threshold_mode} is unknown!")
|
| 744 |
+
|
| 745 |
+
self._streaming_avg: StreamingAverage | None = None
|
| 746 |
+
if smooth:
|
| 747 |
+
self._streaming_avg = StreamingAverage(window_size)
|
| 748 |
+
|
| 749 |
+
def step(self, metrics: float, epoch: int | None = None) -> None:
|
| 750 |
+
"""Perform a scheduler step based on the given metrics.
|
| 751 |
+
|
| 752 |
+
Args:
|
| 753 |
+
metrics (`float`):
|
| 754 |
+
The metric value to use for LR adjustment decisions.
|
| 755 |
+
epoch (`int`, *optional*):
|
| 756 |
+
The current epoch number. If None, uses internal counter.
|
| 757 |
+
"""
|
| 758 |
+
current = float(metrics)
|
| 759 |
+
|
| 760 |
+
if self.smooth and self._streaming_avg is not None:
|
| 761 |
+
current = self._streaming_avg.streamavg(current)
|
| 762 |
+
|
| 763 |
+
if epoch is None:
|
| 764 |
+
epoch = self.last_epoch + 1
|
| 765 |
+
self.last_epoch = epoch
|
| 766 |
+
|
| 767 |
+
if self.cooldown_counter > 0:
|
| 768 |
+
self.cooldown_counter -= 1
|
| 769 |
+
self.num_bad_epochs = 0
|
| 770 |
+
self.num_good_epochs = 0
|
| 771 |
+
elif self.warmup_counter > 0:
|
| 772 |
+
self.warmup_counter -= 1
|
| 773 |
+
self.num_bad_epochs = 0
|
| 774 |
+
self.num_good_epochs = 0
|
| 775 |
+
else:
|
| 776 |
+
if self.is_better(current, self.best):
|
| 777 |
+
self.best = current
|
| 778 |
+
self.num_bad_epochs = 0
|
| 779 |
+
self.num_good_epochs += 1
|
| 780 |
+
else:
|
| 781 |
+
self.num_bad_epochs += 1
|
| 782 |
+
self.num_good_epochs = 0
|
| 783 |
+
|
| 784 |
+
if self.num_good_epochs > self.patience:
|
| 785 |
+
self._increase_lr(epoch)
|
| 786 |
+
self.warmup_counter = self.warmup
|
| 787 |
+
self.num_good_epochs = 0
|
| 788 |
+
elif self.num_bad_epochs > self.patience:
|
| 789 |
+
self._reduce_lr(epoch)
|
| 790 |
+
self.cooldown_counter = self.cooldown
|
| 791 |
+
self.num_bad_epochs = 0
|
| 792 |
+
|
| 793 |
+
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
| 794 |
+
|
| 795 |
+
def is_better(self, current: float, best: float) -> bool:
|
| 796 |
+
if self.mode == "min":
|
| 797 |
+
if self.threshold_mode == "rel":
|
| 798 |
+
return current < best * (1.0 - self.threshold)
|
| 799 |
+
else:
|
| 800 |
+
return current < best - self.threshold
|
| 801 |
+
else:
|
| 802 |
+
if self.threshold_mode == "rel":
|
| 803 |
+
return current > best * (1.0 + self.threshold)
|
| 804 |
+
else:
|
| 805 |
+
return current > best + self.threshold
|
| 806 |
+
|
| 807 |
+
def _reduce_lr(self, epoch: int) -> None:
|
| 808 |
+
all_at_min = True
|
| 809 |
+
for i, param_group in enumerate(self.optimizer.param_groups):
|
| 810 |
+
old_lr = float(param_group["lr"])
|
| 811 |
+
new_lr = max(old_lr * self.factor, self.min_lrs[i])
|
| 812 |
+
|
| 813 |
+
if old_lr - new_lr > self.eps:
|
| 814 |
+
param_group["lr"] = new_lr
|
| 815 |
+
if self.verbose:
|
| 816 |
+
print(f"Epoch {epoch}: reducing learning rate of group {i} to {new_lr:.4e}.")
|
| 817 |
+
|
| 818 |
+
if param_group["lr"] > self.min_lrs[i]:
|
| 819 |
+
all_at_min = False
|
| 820 |
+
|
| 821 |
+
if all_at_min:
|
| 822 |
+
self.reset_start -= 1
|
| 823 |
+
if self.reset_start <= 0:
|
| 824 |
+
self._reset()
|
| 825 |
+
|
| 826 |
+
def _increase_lr(self, epoch: int) -> None:
|
| 827 |
+
for i, param_group in enumerate(self.optimizer.param_groups):
|
| 828 |
+
old_lr = float(param_group["lr"])
|
| 829 |
+
new_lr = min(old_lr / self.factor, self.max_lrs[i])
|
| 830 |
+
|
| 831 |
+
if new_lr - old_lr > self.eps:
|
| 832 |
+
param_group["lr"] = new_lr
|
| 833 |
+
if self.verbose:
|
| 834 |
+
print(f"Epoch {epoch}: increasing learning rate of group {i} to {new_lr:.4e}.")
|
| 835 |
+
|
| 836 |
+
self.reset_start = self.reset_start_original
|
| 837 |
+
|
| 838 |
+
def _reset(self) -> None:
|
| 839 |
+
for i, param_group in enumerate(self.optimizer.param_groups):
|
| 840 |
+
param_group["lr"] = self._init_lrs[i]
|
| 841 |
+
|
| 842 |
+
self.best = float("inf") if self.mode == "min" else float("-inf")
|
| 843 |
+
self.num_bad_epochs = 0
|
| 844 |
+
self.num_good_epochs = 0
|
| 845 |
+
self.cooldown_counter = 0
|
| 846 |
+
self.warmup_counter = 0
|
| 847 |
+
self.reset_start = self.reset_start_original
|
| 848 |
+
|
| 849 |
+
if self.smooth and self._streaming_avg is not None:
|
| 850 |
+
self._streaming_avg = StreamingAverage(self.window_size)
|
| 851 |
+
|
| 852 |
+
if self.verbose:
|
| 853 |
+
print("Scheduler reset to initial state.")
|
| 854 |
+
|
| 855 |
+
def get_last_lr(self) -> list[float]:
|
| 856 |
+
"""Return last computed learning rate by current scheduler."""
|
| 857 |
+
return self._last_lr
|
| 858 |
+
|
| 859 |
+
def state_dict(self) -> dict[str, Any]:
|
| 860 |
+
"""Return the state of the scheduler as a dictionary."""
|
| 861 |
+
state = {
|
| 862 |
+
"factor": self.factor,
|
| 863 |
+
"min_lrs": self.min_lrs,
|
| 864 |
+
"max_lrs": self.max_lrs,
|
| 865 |
+
"patience": self.patience,
|
| 866 |
+
"verbose": self.verbose,
|
| 867 |
+
"cooldown": self.cooldown,
|
| 868 |
+
"warmup": self.warmup,
|
| 869 |
+
"cooldown_counter": self.cooldown_counter,
|
| 870 |
+
"warmup_counter": self.warmup_counter,
|
| 871 |
+
"mode": self.mode,
|
| 872 |
+
"threshold": self.threshold,
|
| 873 |
+
"threshold_mode": self.threshold_mode,
|
| 874 |
+
"best": self.best,
|
| 875 |
+
"num_bad_epochs": self.num_bad_epochs,
|
| 876 |
+
"num_good_epochs": self.num_good_epochs,
|
| 877 |
+
"eps": self.eps,
|
| 878 |
+
"last_epoch": self.last_epoch,
|
| 879 |
+
"smooth": self.smooth,
|
| 880 |
+
"window_size": self.window_size,
|
| 881 |
+
"reset_start": self.reset_start,
|
| 882 |
+
"reset_start_original": self.reset_start_original,
|
| 883 |
+
"_last_lr": self._last_lr,
|
| 884 |
+
"_init_lrs": self._init_lrs,
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
if self.smooth and self._streaming_avg is not None:
|
| 888 |
+
state["_streaming_avg"] = self._streaming_avg.state_dict()
|
| 889 |
+
|
| 890 |
+
return state
|
| 891 |
+
|
| 892 |
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
| 893 |
+
"""Load state from a dictionary."""
|
| 894 |
+
self.factor = state_dict.get("factor", self.factor)
|
| 895 |
+
self.min_lrs = state_dict.get("min_lrs", self.min_lrs)
|
| 896 |
+
self.max_lrs = state_dict.get("max_lrs", self.max_lrs)
|
| 897 |
+
self.patience = state_dict.get("patience", self.patience)
|
| 898 |
+
self.verbose = state_dict.get("verbose", self.verbose)
|
| 899 |
+
self.cooldown = state_dict.get("cooldown", self.cooldown)
|
| 900 |
+
self.warmup = state_dict.get("warmup", self.warmup)
|
| 901 |
+
self.cooldown_counter = state_dict.get("cooldown_counter", self.cooldown_counter)
|
| 902 |
+
self.warmup_counter = state_dict.get("warmup_counter", self.warmup_counter)
|
| 903 |
+
self.mode = state_dict.get("mode", self.mode)
|
| 904 |
+
self.threshold = state_dict.get("threshold", self.threshold)
|
| 905 |
+
self.threshold_mode = state_dict.get("threshold_mode", self.threshold_mode)
|
| 906 |
+
self.best = state_dict.get("best", self.best)
|
| 907 |
+
self.num_bad_epochs = state_dict.get("num_bad_epochs", self.num_bad_epochs)
|
| 908 |
+
self.num_good_epochs = state_dict.get("num_good_epochs", self.num_good_epochs)
|
| 909 |
+
self.eps = state_dict.get("eps", self.eps)
|
| 910 |
+
self.last_epoch = state_dict.get("last_epoch", self.last_epoch)
|
| 911 |
+
self.smooth = state_dict.get("smooth", self.smooth)
|
| 912 |
+
self.window_size = state_dict.get("window_size", self.window_size)
|
| 913 |
+
self.reset_start = state_dict.get("reset_start", self.reset_start)
|
| 914 |
+
self.reset_start_original = state_dict.get("reset_start_original", self.reset_start_original)
|
| 915 |
+
self._last_lr = state_dict.get("_last_lr", self._last_lr)
|
| 916 |
+
self._init_lrs = state_dict.get("_init_lrs", self._init_lrs)
|
| 917 |
+
|
| 918 |
+
if "_streaming_avg" in state_dict:
|
| 919 |
+
if self._streaming_avg is None:
|
| 920 |
+
self._streaming_avg = StreamingAverage(self.window_size)
|
| 921 |
+
self._streaming_avg.load_state_dict(state_dict["_streaming_avg"])
|
| 922 |
+
|
| 923 |
+
if "_last_lr" in state_dict:
|
| 924 |
+
for param_group, lr in zip(self.optimizer.param_groups, self._last_lr):
|
| 925 |
+
param_group["lr"] = lr
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def get_greedy_schedule(optimizer: Optimizer, **kwargs):
|
| 929 |
+
"""
|
| 930 |
+
Create an adaptive learning rate scheduler that adjusts LR based on training metrics.
|
| 931 |
+
|
| 932 |
+
Args:
|
| 933 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 934 |
+
The optimizer for which to schedule the learning rate.
|
| 935 |
+
kwargs (`dict`, *optional*):
|
| 936 |
+
Extra parameters passed to the scheduler. See [`GreedyLR`] for possible parameters.
|
| 937 |
+
|
| 938 |
+
Return:
|
| 939 |
+
[`GreedyLR`] with the appropriate schedule.
|
| 940 |
+
"""
|
| 941 |
+
return GreedyLR(optimizer, **kwargs)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
TYPE_TO_SCHEDULER_FUNCTION = {
|
| 945 |
+
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
| 946 |
+
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
| 947 |
+
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
|
| 948 |
+
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
|
| 949 |
+
SchedulerType.CONSTANT: get_constant_schedule,
|
| 950 |
+
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
| 951 |
+
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
| 952 |
+
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
|
| 953 |
+
SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
|
| 954 |
+
SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
|
| 955 |
+
SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
|
| 956 |
+
SchedulerType.GREEDY: get_greedy_schedule,
|
| 957 |
+
}
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
def get_scheduler(
|
| 961 |
+
name: str | SchedulerType,
|
| 962 |
+
optimizer: Optimizer,
|
| 963 |
+
num_warmup_steps: int | None = None,
|
| 964 |
+
num_training_steps: int | None = None,
|
| 965 |
+
scheduler_specific_kwargs: dict | None = None,
|
| 966 |
+
):
|
| 967 |
+
"""
|
| 968 |
+
Unified API to get any scheduler from its name.
|
| 969 |
+
|
| 970 |
+
Args:
|
| 971 |
+
name (`str` or `SchedulerType`):
|
| 972 |
+
The name of the scheduler to use.
|
| 973 |
+
optimizer (`torch.optim.Optimizer`):
|
| 974 |
+
The optimizer that will be used during training.
|
| 975 |
+
num_warmup_steps (`int`, *optional*):
|
| 976 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
| 977 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
| 978 |
+
num_training_steps (`int``, *optional*):
|
| 979 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
| 980 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
| 981 |
+
scheduler_specific_kwargs (`dict`, *optional*):
|
| 982 |
+
Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
|
| 983 |
+
parameters will cause the scheduler function to raise a TypeError.
|
| 984 |
+
"""
|
| 985 |
+
name = SchedulerType(name)
|
| 986 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
| 987 |
+
|
| 988 |
+
# If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
|
| 989 |
+
# recursively call `get_scheduler` to get the proper schedulers on each parameter
|
| 990 |
+
if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
|
| 991 |
+
optimizer_dict = optimizer.optimizer_dict
|
| 992 |
+
scheduler_dict = {}
|
| 993 |
+
|
| 994 |
+
for param in optimizer_dict:
|
| 995 |
+
scheduler_dict[param] = get_scheduler(
|
| 996 |
+
name,
|
| 997 |
+
optimizer=optimizer_dict[param],
|
| 998 |
+
num_warmup_steps=num_warmup_steps,
|
| 999 |
+
num_training_steps=num_training_steps,
|
| 1000 |
+
scheduler_specific_kwargs=scheduler_specific_kwargs,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
def scheduler_hook(param):
|
| 1004 |
+
# Since the optimizer hook has been already attached we only need to
|
| 1005 |
+
# attach the scheduler hook, the gradients have been zeroed here
|
| 1006 |
+
scheduler_dict[param].step()
|
| 1007 |
+
|
| 1008 |
+
for param in optimizer_dict:
|
| 1009 |
+
if param.requires_grad:
|
| 1010 |
+
param.register_post_accumulate_grad_hook(scheduler_hook)
|
| 1011 |
+
|
| 1012 |
+
return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])
|
| 1013 |
+
|
| 1014 |
+
if name == SchedulerType.CONSTANT:
|
| 1015 |
+
return schedule_func(optimizer)
|
| 1016 |
+
|
| 1017 |
+
if scheduler_specific_kwargs is None:
|
| 1018 |
+
scheduler_specific_kwargs = {}
|
| 1019 |
+
|
| 1020 |
+
if name == SchedulerType.REDUCE_ON_PLATEAU:
|
| 1021 |
+
return schedule_func(optimizer, **scheduler_specific_kwargs)
|
| 1022 |
+
|
| 1023 |
+
if name == SchedulerType.GREEDY:
|
| 1024 |
+
return schedule_func(optimizer, **scheduler_specific_kwargs)
|
| 1025 |
+
|
| 1026 |
+
# All other schedulers require `num_warmup_steps`
|
| 1027 |
+
if num_warmup_steps is None:
|
| 1028 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
| 1029 |
+
|
| 1030 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
| 1031 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
| 1032 |
+
|
| 1033 |
+
if name == SchedulerType.INVERSE_SQRT:
|
| 1034 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **scheduler_specific_kwargs)
|
| 1035 |
+
|
| 1036 |
+
# wsd scheduler requires either num_training_steps or num_stable_steps
|
| 1037 |
+
if name == SchedulerType.WARMUP_STABLE_DECAY:
|
| 1038 |
+
return schedule_func(
|
| 1039 |
+
optimizer,
|
| 1040 |
+
num_warmup_steps=num_warmup_steps,
|
| 1041 |
+
num_training_steps=num_training_steps,
|
| 1042 |
+
**scheduler_specific_kwargs,
|
| 1043 |
+
)
|
| 1044 |
+
|
| 1045 |
+
# All other schedulers require `num_training_steps`
|
| 1046 |
+
if num_training_steps is None:
|
| 1047 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
| 1048 |
+
|
| 1049 |
+
return schedule_func(
|
| 1050 |
+
optimizer,
|
| 1051 |
+
num_warmup_steps=num_warmup_steps,
|
| 1052 |
+
num_training_steps=num_training_steps,
|
| 1053 |
+
**scheduler_specific_kwargs,
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
+
class Adafactor(Optimizer):
|
| 1058 |
+
"""
|
| 1059 |
+
AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
|
| 1060 |
+
https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
| 1061 |
+
|
| 1062 |
+
Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://huggingface.co/papers/1804.04235 Note that
|
| 1063 |
+
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
|
| 1064 |
+
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
| 1065 |
+
`relative_step=False`.
|
| 1066 |
+
|
| 1067 |
+
Arguments:
|
| 1068 |
+
params (`Iterable[nn.parameter.Parameter]`):
|
| 1069 |
+
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
| 1070 |
+
lr (`float`, *optional*):
|
| 1071 |
+
The external learning rate.
|
| 1072 |
+
eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
|
| 1073 |
+
Regularization constants for square gradient and parameter scale respectively
|
| 1074 |
+
clip_threshold (`float`, *optional*, defaults to 1.0):
|
| 1075 |
+
Threshold of root mean square of final gradient update
|
| 1076 |
+
decay_rate (`float`, *optional*, defaults to -0.8):
|
| 1077 |
+
Coefficient used to compute running averages of square
|
| 1078 |
+
beta1 (`float`, *optional*):
|
| 1079 |
+
Coefficient used for computing running averages of gradient
|
| 1080 |
+
weight_decay (`float`, *optional*, defaults to 0.0):
|
| 1081 |
+
Weight decay (L2 penalty)
|
| 1082 |
+
scale_parameter (`bool`, *optional*, defaults to `True`):
|
| 1083 |
+
If True, learning rate is scaled by root mean square
|
| 1084 |
+
relative_step (`bool`, *optional*, defaults to `True`):
|
| 1085 |
+
If True, time-dependent learning rate is computed instead of external learning rate
|
| 1086 |
+
warmup_init (`bool`, *optional*, defaults to `False`):
|
| 1087 |
+
Time-dependent learning rate computation depends on whether warm-up initialization is being used
|
| 1088 |
+
|
| 1089 |
+
This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
|
| 1090 |
+
|
| 1091 |
+
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
|
| 1092 |
+
|
| 1093 |
+
- Training without LR warmup or clip_threshold is not recommended.
|
| 1094 |
+
|
| 1095 |
+
- use scheduled LR warm-up to fixed LR
|
| 1096 |
+
- use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235)
|
| 1097 |
+
- Disable relative updates
|
| 1098 |
+
- Use scale_parameter=False
|
| 1099 |
+
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor
|
| 1100 |
+
|
| 1101 |
+
Example:
|
| 1102 |
+
|
| 1103 |
+
```python
|
| 1104 |
+
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
|
| 1105 |
+
```
|
| 1106 |
+
|
| 1107 |
+
Others reported the following combination to work well:
|
| 1108 |
+
|
| 1109 |
+
```python
|
| 1110 |
+
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
| 1111 |
+
```
|
| 1112 |
+
|
| 1113 |
+
When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
|
| 1114 |
+
scheduler as following:
|
| 1115 |
+
|
| 1116 |
+
```python
|
| 1117 |
+
from transformers.optimization import Adafactor, AdafactorSchedule
|
| 1118 |
+
|
| 1119 |
+
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
| 1120 |
+
lr_scheduler = AdafactorSchedule(optimizer)
|
| 1121 |
+
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
|
| 1122 |
+
```
|
| 1123 |
+
|
| 1124 |
+
Usage:
|
| 1125 |
+
|
| 1126 |
+
```python
|
| 1127 |
+
# replace AdamW with Adafactor
|
| 1128 |
+
optimizer = Adafactor(
|
| 1129 |
+
model.parameters(),
|
| 1130 |
+
lr=1e-3,
|
| 1131 |
+
eps=(1e-30, 1e-3),
|
| 1132 |
+
clip_threshold=1.0,
|
| 1133 |
+
decay_rate=-0.8,
|
| 1134 |
+
beta1=None,
|
| 1135 |
+
weight_decay=0.0,
|
| 1136 |
+
relative_step=False,
|
| 1137 |
+
scale_parameter=False,
|
| 1138 |
+
warmup_init=False,
|
| 1139 |
+
)
|
| 1140 |
+
```"""
|
| 1141 |
+
|
| 1142 |
+
def __init__(
|
| 1143 |
+
self,
|
| 1144 |
+
params,
|
| 1145 |
+
lr=None,
|
| 1146 |
+
eps=(1e-30, 1e-3),
|
| 1147 |
+
clip_threshold=1.0,
|
| 1148 |
+
decay_rate=-0.8,
|
| 1149 |
+
beta1=None,
|
| 1150 |
+
weight_decay=0.0,
|
| 1151 |
+
scale_parameter=True,
|
| 1152 |
+
relative_step=True,
|
| 1153 |
+
warmup_init=False,
|
| 1154 |
+
):
|
| 1155 |
+
if lr is not None and relative_step:
|
| 1156 |
+
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
|
| 1157 |
+
if warmup_init and not relative_step:
|
| 1158 |
+
raise ValueError("`warmup_init=True` requires `relative_step=True`")
|
| 1159 |
+
|
| 1160 |
+
defaults = {
|
| 1161 |
+
"lr": lr,
|
| 1162 |
+
"eps": eps,
|
| 1163 |
+
"clip_threshold": clip_threshold,
|
| 1164 |
+
"decay_rate": decay_rate,
|
| 1165 |
+
"beta1": beta1,
|
| 1166 |
+
"weight_decay": weight_decay,
|
| 1167 |
+
"scale_parameter": scale_parameter,
|
| 1168 |
+
"relative_step": relative_step,
|
| 1169 |
+
"warmup_init": warmup_init,
|
| 1170 |
+
}
|
| 1171 |
+
super().__init__(params, defaults)
|
| 1172 |
+
|
| 1173 |
+
@staticmethod
|
| 1174 |
+
def _get_lr(param_group, param_state):
|
| 1175 |
+
rel_step_sz = param_group["lr"]
|
| 1176 |
+
if param_group["relative_step"]:
|
| 1177 |
+
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
|
| 1178 |
+
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
|
| 1179 |
+
param_scale = 1.0
|
| 1180 |
+
if param_group["scale_parameter"]:
|
| 1181 |
+
param_scale = max(param_group["eps"][1], param_state["RMS"])
|
| 1182 |
+
return param_scale * rel_step_sz
|
| 1183 |
+
|
| 1184 |
+
@staticmethod
|
| 1185 |
+
def _get_options(param_group, param_shape):
|
| 1186 |
+
factored = len(param_shape) >= 2
|
| 1187 |
+
use_first_moment = param_group["beta1"] is not None
|
| 1188 |
+
return factored, use_first_moment
|
| 1189 |
+
|
| 1190 |
+
@staticmethod
|
| 1191 |
+
def _rms(tensor):
|
| 1192 |
+
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
| 1193 |
+
|
| 1194 |
+
@staticmethod
|
| 1195 |
+
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
|
| 1196 |
+
# copy from fairseq's adafactor implementation:
|
| 1197 |
+
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
|
| 1198 |
+
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
| 1199 |
+
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
| 1200 |
+
return torch.mul(r_factor, c_factor)
|
| 1201 |
+
|
| 1202 |
+
@torch.no_grad()
|
| 1203 |
+
def step(self, closure=None):
|
| 1204 |
+
"""
|
| 1205 |
+
Performs a single optimization step
|
| 1206 |
+
|
| 1207 |
+
Arguments:
|
| 1208 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 1209 |
+
and returns the loss.
|
| 1210 |
+
"""
|
| 1211 |
+
loss = None
|
| 1212 |
+
if closure is not None:
|
| 1213 |
+
loss = closure()
|
| 1214 |
+
|
| 1215 |
+
for group in self.param_groups:
|
| 1216 |
+
for p in group["params"]:
|
| 1217 |
+
if p.grad is None:
|
| 1218 |
+
continue
|
| 1219 |
+
grad = p.grad
|
| 1220 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
| 1221 |
+
grad = grad.float()
|
| 1222 |
+
if grad.is_sparse:
|
| 1223 |
+
raise RuntimeError("Adafactor does not support sparse gradients.")
|
| 1224 |
+
|
| 1225 |
+
state = self.state[p]
|
| 1226 |
+
grad_shape = grad.shape
|
| 1227 |
+
|
| 1228 |
+
factored, use_first_moment = self._get_options(group, grad_shape)
|
| 1229 |
+
# State Initialization
|
| 1230 |
+
if len(state) == 0:
|
| 1231 |
+
state["step"] = 0
|
| 1232 |
+
|
| 1233 |
+
if use_first_moment:
|
| 1234 |
+
# Exponential moving average of gradient values
|
| 1235 |
+
state["exp_avg"] = torch.zeros_like(grad)
|
| 1236 |
+
if factored:
|
| 1237 |
+
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
| 1238 |
+
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
| 1239 |
+
else:
|
| 1240 |
+
state["exp_avg_sq"] = torch.zeros_like(grad)
|
| 1241 |
+
|
| 1242 |
+
state["RMS"] = 0
|
| 1243 |
+
else:
|
| 1244 |
+
if use_first_moment:
|
| 1245 |
+
state["exp_avg"] = state["exp_avg"].to(grad)
|
| 1246 |
+
if factored:
|
| 1247 |
+
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
| 1248 |
+
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
| 1249 |
+
else:
|
| 1250 |
+
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
| 1251 |
+
|
| 1252 |
+
p_data_fp32 = p
|
| 1253 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 1254 |
+
p_data_fp32 = p_data_fp32.float()
|
| 1255 |
+
|
| 1256 |
+
state["step"] += 1
|
| 1257 |
+
state["RMS"] = self._rms(p_data_fp32)
|
| 1258 |
+
lr = self._get_lr(group, state)
|
| 1259 |
+
|
| 1260 |
+
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
| 1261 |
+
update = (grad**2) + group["eps"][0]
|
| 1262 |
+
if factored:
|
| 1263 |
+
exp_avg_sq_row = state["exp_avg_sq_row"]
|
| 1264 |
+
exp_avg_sq_col = state["exp_avg_sq_col"]
|
| 1265 |
+
|
| 1266 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
| 1267 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
| 1268 |
+
|
| 1269 |
+
# Approximation of exponential moving average of square of gradient
|
| 1270 |
+
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
| 1271 |
+
update.mul_(grad)
|
| 1272 |
+
else:
|
| 1273 |
+
exp_avg_sq = state["exp_avg_sq"]
|
| 1274 |
+
|
| 1275 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
| 1276 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
| 1277 |
+
|
| 1278 |
+
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
| 1279 |
+
update.mul_(lr)
|
| 1280 |
+
|
| 1281 |
+
if use_first_moment:
|
| 1282 |
+
exp_avg = state["exp_avg"]
|
| 1283 |
+
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
| 1284 |
+
update = exp_avg
|
| 1285 |
+
|
| 1286 |
+
if group["weight_decay"] != 0:
|
| 1287 |
+
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
| 1288 |
+
|
| 1289 |
+
p_data_fp32.add_(-update)
|
| 1290 |
+
|
| 1291 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 1292 |
+
p.copy_(p_data_fp32)
|
| 1293 |
+
|
| 1294 |
+
return loss
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class AdafactorSchedule(LambdaLR):
|
| 1298 |
+
"""
|
| 1299 |
+
Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
|
| 1300 |
+
for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
|
| 1301 |
+
|
| 1302 |
+
It returns `initial_lr` during startup and the actual `lr` during stepping.
|
| 1303 |
+
"""
|
| 1304 |
+
|
| 1305 |
+
def __init__(self, optimizer, initial_lr=0.0):
|
| 1306 |
+
def lr_lambda(_):
|
| 1307 |
+
return initial_lr
|
| 1308 |
+
|
| 1309 |
+
for group in optimizer.param_groups:
|
| 1310 |
+
group["initial_lr"] = initial_lr
|
| 1311 |
+
super().__init__(optimizer, lr_lambda)
|
| 1312 |
+
for group in optimizer.param_groups:
|
| 1313 |
+
del group["initial_lr"]
|
| 1314 |
+
|
| 1315 |
+
def get_lr(self):
|
| 1316 |
+
opt = self.optimizer
|
| 1317 |
+
lrs = [
|
| 1318 |
+
opt._get_lr(group, opt.state[group["params"][0]])
|
| 1319 |
+
for group in opt.param_groups
|
| 1320 |
+
if group["params"][0].grad is not None
|
| 1321 |
+
]
|
| 1322 |
+
if len(lrs) == 0:
|
| 1323 |
+
lrs = self.base_lrs # if called before stepping
|
| 1324 |
+
return lrs
|
| 1325 |
+
|
| 1326 |
+
|
| 1327 |
+
def get_adafactor_schedule(optimizer, initial_lr=0.0):
|
| 1328 |
+
"""
|
| 1329 |
+
Get a proxy schedule for [`~optimization.Adafactor`]
|
| 1330 |
+
|
| 1331 |
+
Args:
|
| 1332 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
| 1333 |
+
The optimizer for which to schedule the learning rate.
|
| 1334 |
+
initial_lr (`float`, *optional*, defaults to 0.0):
|
| 1335 |
+
Initial lr
|
| 1336 |
+
|
| 1337 |
+
Return:
|
| 1338 |
+
[`~optimization.Adafactor`] proxy schedule object.
|
| 1339 |
+
|
| 1340 |
+
|
| 1341 |
+
"""
|
| 1342 |
+
return AdafactorSchedule(optimizer, initial_lr)
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/tokenization_python.py
ADDED
|
@@ -0,0 +1,1420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
|
| 16 |
+
tokenization_utils_tokenizers.py
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import bisect
|
| 20 |
+
import unicodedata
|
| 21 |
+
from collections import OrderedDict
|
| 22 |
+
from typing import Any, overload
|
| 23 |
+
|
| 24 |
+
from .tokenization_utils_base import (
|
| 25 |
+
INIT_TOKENIZER_DOCSTRING,
|
| 26 |
+
AddedToken,
|
| 27 |
+
BatchEncoding,
|
| 28 |
+
EncodedInput,
|
| 29 |
+
PreTokenizedInput,
|
| 30 |
+
PreTrainedTokenizerBase,
|
| 31 |
+
TextInput,
|
| 32 |
+
TruncationStrategy,
|
| 33 |
+
)
|
| 34 |
+
from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.get_logger(__name__)
|
| 38 |
+
|
| 39 |
+
# Slow tokenizers are saved in a vocabulary plus three separated files
|
| 40 |
+
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
|
| 41 |
+
ADDED_TOKENS_FILE = "added_tokens.json"
|
| 42 |
+
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Trie:
|
| 46 |
+
"""
|
| 47 |
+
Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
|
| 48 |
+
Loose reference https://en.wikipedia.org/wiki/Trie
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, *args):
|
| 52 |
+
self.data = {}
|
| 53 |
+
self._tokens = set()
|
| 54 |
+
self._termination_char = ""
|
| 55 |
+
self.update(*args)
|
| 56 |
+
|
| 57 |
+
def update(self, *args):
|
| 58 |
+
"""
|
| 59 |
+
Updates the Trie with new tokens provided as arguments.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
*args: Variable number of words to be added to the Trie.
|
| 63 |
+
"""
|
| 64 |
+
for token in tuple(*args):
|
| 65 |
+
self.add(token)
|
| 66 |
+
|
| 67 |
+
def add(self, word: str):
|
| 68 |
+
"""
|
| 69 |
+
Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
|
| 70 |
+
The special key `""` in `self._termination_char` is used to represent termination.
|
| 71 |
+
|
| 72 |
+
This function is idempotent, adding twice the same word will leave the trie unchanged
|
| 73 |
+
|
| 74 |
+
Example:
|
| 75 |
+
|
| 76 |
+
```python
|
| 77 |
+
>>> trie = Trie()
|
| 78 |
+
>>> trie.add("Hello 友達")
|
| 79 |
+
>>> trie.data
|
| 80 |
+
{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
|
| 81 |
+
|
| 82 |
+
>>> trie.add("Hello")
|
| 83 |
+
>>> trie.data
|
| 84 |
+
{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
|
| 85 |
+
```
|
| 86 |
+
"""
|
| 87 |
+
if not word:
|
| 88 |
+
# Prevent empty string
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
self._tokens.add(word)
|
| 92 |
+
ref = self.data
|
| 93 |
+
for char in word:
|
| 94 |
+
ref[char] = ref.setdefault(char, {})
|
| 95 |
+
ref = ref[char]
|
| 96 |
+
ref[self._termination_char] = 1
|
| 97 |
+
|
| 98 |
+
def split(self, text: str) -> list[str]:
|
| 99 |
+
"""
|
| 100 |
+
Will look for the words added to the trie within `text`. Output is the original string split along the
|
| 101 |
+
boundaries of the words found.
|
| 102 |
+
|
| 103 |
+
This trie will match the longest possible word first !
|
| 104 |
+
|
| 105 |
+
Example:
|
| 106 |
+
|
| 107 |
+
```python
|
| 108 |
+
>>> trie = Trie()
|
| 109 |
+
>>> trie.split("[CLS] This is a extra_id_100")
|
| 110 |
+
["[CLS] This is a extra_id_100"]
|
| 111 |
+
|
| 112 |
+
>>> trie.add("[CLS]")
|
| 113 |
+
>>> trie.add("extra_id_1")
|
| 114 |
+
>>> trie.add("extra_id_100")
|
| 115 |
+
>>> trie.split("[CLS] This is a extra_id_100")
|
| 116 |
+
["[CLS]", " This is a ", "extra_id_100"]
|
| 117 |
+
```
|
| 118 |
+
"""
|
| 119 |
+
# indexes are counted left of the chars index.
|
| 120 |
+
# "hello", index 0, is left of h, index 1 is between h and e.
|
| 121 |
+
# index 5 is right of the "o".
|
| 122 |
+
|
| 123 |
+
# States are going to capture every possible start (indexes as above)
|
| 124 |
+
# as keys, and have as values, a pointer to the position in the trie
|
| 125 |
+
# where we're at. This is a partial match for now.
|
| 126 |
+
# This enables to keep track of multiple matches while we're iterating
|
| 127 |
+
# the string
|
| 128 |
+
# If the trie contains, "blowing", and "lower" and we encounter the
|
| 129 |
+
# string "blower", we need to split into ["b", "lower"].
|
| 130 |
+
# This is where we need to keep track of multiple possible starts.
|
| 131 |
+
states = OrderedDict()
|
| 132 |
+
|
| 133 |
+
# This will contain every indices where we need
|
| 134 |
+
# to cut.
|
| 135 |
+
# We force to cut at offset 0 and len(text) (added later)
|
| 136 |
+
offsets = [0]
|
| 137 |
+
|
| 138 |
+
# This is used by the lookahead which needs to skip over
|
| 139 |
+
# some text where the full match exceeded the place in the initial
|
| 140 |
+
# for loop
|
| 141 |
+
skip = 0
|
| 142 |
+
# Main loop, Giving this algorithm O(n) complexity
|
| 143 |
+
for current, current_char in enumerate(text):
|
| 144 |
+
if skip and current < skip:
|
| 145 |
+
# Prevents the lookahead for matching twice
|
| 146 |
+
# like extra_id_100 and id_100
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
# This will track every state
|
| 150 |
+
# that stop matching, we need to stop tracking them.
|
| 151 |
+
# If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
|
| 152 |
+
# fail on "b", we need to remove 0 from the valid states.
|
| 153 |
+
to_remove = set()
|
| 154 |
+
# Whenever we found a match, we need to drop everything
|
| 155 |
+
# this is a greedy algorithm, it will match on the first found token
|
| 156 |
+
reset = False
|
| 157 |
+
|
| 158 |
+
# In this case, we already have partial matches (But unfinished)
|
| 159 |
+
for start, trie_pointer in states.items():
|
| 160 |
+
if "" in trie_pointer:
|
| 161 |
+
# This is a final match, we need to reset and
|
| 162 |
+
# store the results in `offsets`.
|
| 163 |
+
|
| 164 |
+
# Lookahead to match longest first
|
| 165 |
+
# Important in case of extra_id_1 vs extra_id_100
|
| 166 |
+
# Here we are also actively looking for other earlier partial
|
| 167 |
+
# matches
|
| 168 |
+
# "[CLS]", "L", we need to match CLS even if L is special
|
| 169 |
+
for lookstart, looktrie_pointer in states.items():
|
| 170 |
+
if lookstart > start:
|
| 171 |
+
# This partial match is later, we can stop looking
|
| 172 |
+
break
|
| 173 |
+
elif lookstart < start:
|
| 174 |
+
# This partial match is earlier, the trie pointer
|
| 175 |
+
# was already updated, so index is + 1
|
| 176 |
+
lookahead_index = current + 1
|
| 177 |
+
end = current + 1
|
| 178 |
+
else:
|
| 179 |
+
# Here lookstart == start and
|
| 180 |
+
# looktrie_pointer == trie_pointer
|
| 181 |
+
# It wasn't updated yet so indices are current ones
|
| 182 |
+
lookahead_index = current
|
| 183 |
+
end = current
|
| 184 |
+
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
| 185 |
+
if "" in looktrie_pointer:
|
| 186 |
+
start = lookstart
|
| 187 |
+
end = lookahead_index
|
| 188 |
+
skip = lookahead_index
|
| 189 |
+
|
| 190 |
+
while next_char in looktrie_pointer:
|
| 191 |
+
looktrie_pointer = looktrie_pointer[next_char]
|
| 192 |
+
lookahead_index += 1
|
| 193 |
+
if "" in looktrie_pointer:
|
| 194 |
+
start = lookstart
|
| 195 |
+
end = lookahead_index
|
| 196 |
+
skip = lookahead_index
|
| 197 |
+
|
| 198 |
+
if lookahead_index == len(text):
|
| 199 |
+
# End of string
|
| 200 |
+
break
|
| 201 |
+
next_char = text[lookahead_index]
|
| 202 |
+
# End lookahead
|
| 203 |
+
|
| 204 |
+
# Storing and resetting
|
| 205 |
+
offsets.append(start)
|
| 206 |
+
offsets.append(end)
|
| 207 |
+
reset = True
|
| 208 |
+
break
|
| 209 |
+
elif current_char in trie_pointer:
|
| 210 |
+
# The current character being looked at has a match within the trie
|
| 211 |
+
# update the pointer (it will be stored back into states later).
|
| 212 |
+
trie_pointer = trie_pointer[current_char]
|
| 213 |
+
|
| 214 |
+
# Storing back the new pointer into the states.
|
| 215 |
+
# Partial matches got longer by one.
|
| 216 |
+
states[start] = trie_pointer
|
| 217 |
+
else:
|
| 218 |
+
# The new character has not match in the trie, we need
|
| 219 |
+
# to stop keeping track of this partial match.
|
| 220 |
+
# We can't do it directly within the loop because of how
|
| 221 |
+
# python iteration works
|
| 222 |
+
to_remove.add(start)
|
| 223 |
+
|
| 224 |
+
# Either clearing the full start (we found a real match)
|
| 225 |
+
# Or clearing only the partial matches that didn't work.
|
| 226 |
+
if reset:
|
| 227 |
+
states = {}
|
| 228 |
+
else:
|
| 229 |
+
for start in to_remove:
|
| 230 |
+
del states[start]
|
| 231 |
+
|
| 232 |
+
# If this character is a starting character within the trie
|
| 233 |
+
# start keeping track of this partial match.
|
| 234 |
+
if current >= skip and current_char in self.data:
|
| 235 |
+
states[current] = self.data[current_char]
|
| 236 |
+
|
| 237 |
+
# We have a cut at the end with states.
|
| 238 |
+
for start, trie_pointer in states.items():
|
| 239 |
+
if "" in trie_pointer:
|
| 240 |
+
# This is a final match, we need to reset and
|
| 241 |
+
# store the results in `offsets`.
|
| 242 |
+
end = len(text)
|
| 243 |
+
offsets.append(start)
|
| 244 |
+
offsets.append(end)
|
| 245 |
+
# Longest cut is always the one with lower start so the first
|
| 246 |
+
# item so we need to break.
|
| 247 |
+
break
|
| 248 |
+
|
| 249 |
+
return self.cut_text(text, offsets)
|
| 250 |
+
|
| 251 |
+
def cut_text(self, text, offsets):
|
| 252 |
+
# We have all the offsets now, we just need to do the actual splitting.
|
| 253 |
+
# We need to eventually add the first part of the string and the eventual
|
| 254 |
+
# last part.
|
| 255 |
+
offsets.append(len(text))
|
| 256 |
+
tokens = []
|
| 257 |
+
start = 0
|
| 258 |
+
for end in offsets:
|
| 259 |
+
if start > end:
|
| 260 |
+
logger.error(
|
| 261 |
+
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
|
| 262 |
+
" anyway."
|
| 263 |
+
)
|
| 264 |
+
continue
|
| 265 |
+
elif start == end:
|
| 266 |
+
# This might happen if there's a match at index 0
|
| 267 |
+
# we're also preventing zero-width cuts in case of two
|
| 268 |
+
# consecutive matches
|
| 269 |
+
continue
|
| 270 |
+
tokens.append(text[start:end])
|
| 271 |
+
start = end
|
| 272 |
+
|
| 273 |
+
return tokens
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class ExtensionsTrie(Trie):
|
| 277 |
+
def __init__(self, *args):
|
| 278 |
+
super().__init__(*args)
|
| 279 |
+
|
| 280 |
+
def extensions(self, prefix: str):
|
| 281 |
+
"""
|
| 282 |
+
Generates all extensions of a given prefix token in the Trie.
|
| 283 |
+
|
| 284 |
+
Example:
|
| 285 |
+
|
| 286 |
+
```python
|
| 287 |
+
>>> trie = Trie()
|
| 288 |
+
>>> trie.add("apple")
|
| 289 |
+
>>> trie.add("app")
|
| 290 |
+
>>> trie.add("application")
|
| 291 |
+
>>> trie.extensions("app")
|
| 292 |
+
['app', 'apple', 'application']
|
| 293 |
+
```
|
| 294 |
+
"""
|
| 295 |
+
prefix_node = self._get_node(prefix)
|
| 296 |
+
ret = self._collect_tokens(prefix_node)
|
| 297 |
+
return [prefix + token for token in ret]
|
| 298 |
+
|
| 299 |
+
def _get_node(self, token: str) -> dict:
|
| 300 |
+
"""
|
| 301 |
+
Retrieves the node corresponding to the given token in the Trie.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
token (str): The token for which the corresponding node needs to be retrieved.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
dict: The node in the Trie corresponding to the given token.
|
| 308 |
+
"""
|
| 309 |
+
node = self.data
|
| 310 |
+
for char in token:
|
| 311 |
+
if char not in node:
|
| 312 |
+
break
|
| 313 |
+
|
| 314 |
+
node = node[char]
|
| 315 |
+
return node
|
| 316 |
+
|
| 317 |
+
def _collect_tokens(self, node: dict) -> list:
|
| 318 |
+
"""
|
| 319 |
+
Generates all tokens in the Trie starting from a given node.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
node (dict): The node in the Trie from which tokens need to be generated.
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
list: List of tokens generated from the given node.
|
| 326 |
+
"""
|
| 327 |
+
tokens = [self._termination_char] if self._termination_char in node else []
|
| 328 |
+
for token, subtrie_head in node.items():
|
| 329 |
+
if token != self._termination_char:
|
| 330 |
+
subtokens = self._collect_tokens(subtrie_head)
|
| 331 |
+
tokens.extend([token + subtoken for subtoken in subtokens])
|
| 332 |
+
return tokens
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _is_whitespace(char):
|
| 336 |
+
"""Checks whether `char` is a whitespace character."""
|
| 337 |
+
# \t, \n, and \r are technically control characters but we treat them
|
| 338 |
+
# as whitespace since they are generally considered as such.
|
| 339 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
| 340 |
+
return True
|
| 341 |
+
cat = unicodedata.category(char)
|
| 342 |
+
if cat == "Zs":
|
| 343 |
+
return True
|
| 344 |
+
return False
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def _is_control(char):
|
| 348 |
+
"""Checks whether `char` is a control character."""
|
| 349 |
+
# These are technically control characters but we count them as whitespace
|
| 350 |
+
# characters.
|
| 351 |
+
if char == "\t" or char == "\n" or char == "\r":
|
| 352 |
+
return False
|
| 353 |
+
cat = unicodedata.category(char)
|
| 354 |
+
if cat.startswith("C"):
|
| 355 |
+
return True
|
| 356 |
+
return False
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _is_punctuation(char):
|
| 360 |
+
"""Checks whether `char` is a punctuation character."""
|
| 361 |
+
cp = ord(char)
|
| 362 |
+
# We treat all non-letter/number ASCII as punctuation.
|
| 363 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
| 364 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
| 365 |
+
# consistency.
|
| 366 |
+
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
|
| 367 |
+
return True
|
| 368 |
+
cat = unicodedata.category(char)
|
| 369 |
+
if cat.startswith("P"):
|
| 370 |
+
return True
|
| 371 |
+
return False
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _is_end_of_word(text):
|
| 375 |
+
"""Checks whether the last character in text is one of a punctuation, control or whitespace character."""
|
| 376 |
+
last_char = text[-1]
|
| 377 |
+
return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def _is_start_of_word(text):
|
| 381 |
+
"""Checks whether the first character in text is one of a punctuation, control or whitespace character."""
|
| 382 |
+
first_char = text[0]
|
| 383 |
+
return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str):
|
| 387 |
+
"""
|
| 388 |
+
Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
|
| 389 |
+
"""
|
| 390 |
+
insertion_idx = bisect.bisect_left(token_list, new_token)
|
| 391 |
+
# Checks if new_token is already in the ordered token_list
|
| 392 |
+
if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
|
| 393 |
+
# new_token is in token_list, don't add
|
| 394 |
+
return
|
| 395 |
+
else:
|
| 396 |
+
token_list.insert(insertion_idx, new_token)
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
|
| 400 |
+
class PythonBackend(PreTrainedTokenizerBase):
|
| 401 |
+
"""
|
| 402 |
+
Base class for all slow tokenizers.
|
| 403 |
+
|
| 404 |
+
Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
|
| 405 |
+
|
| 406 |
+
Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading
|
| 407 |
+
pretrained tokenizers as well as adding tokens to the vocabulary.
|
| 408 |
+
|
| 409 |
+
This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the
|
| 410 |
+
specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
def __init__(self, **kwargs):
|
| 414 |
+
# 1. Init the parent class
|
| 415 |
+
|
| 416 |
+
self.tokens_trie = Trie()
|
| 417 |
+
|
| 418 |
+
# Initialize total_vocab_size early to avoid issues if get_vocab() is called early (custom tokenizers)
|
| 419 |
+
self.total_vocab_size = 0
|
| 420 |
+
|
| 421 |
+
# 2. init `_added_tokens_decoder` if child class did not
|
| 422 |
+
if not hasattr(self, "_added_tokens_decoder"):
|
| 423 |
+
self._added_tokens_decoder: dict[int, AddedToken] = {}
|
| 424 |
+
|
| 425 |
+
# 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
|
| 426 |
+
self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {}))
|
| 427 |
+
self._added_tokens_encoder: dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()}
|
| 428 |
+
|
| 429 |
+
# 4. Token type ID configuration for dynamic mask building
|
| 430 |
+
# These can be overridden by subclasses to avoid overriding create_token_type_ids_from_sequences
|
| 431 |
+
self.token_type_ids_pattern = kwargs.pop("token_type_ids_pattern", "bert_style") # "all_zeros" or "bert_style"
|
| 432 |
+
self.token_type_ids_include_special_tokens = kwargs.pop("token_type_ids_include_special_tokens", True)
|
| 433 |
+
|
| 434 |
+
# 5. Special tokens mask configuration
|
| 435 |
+
# Patterns: "none", "cls_sep", "eos", "bos", "bos_eos", "cls_double_sep", "prefix_suffix"
|
| 436 |
+
self.special_tokens_pattern = kwargs.pop("special_tokens_pattern", None)
|
| 437 |
+
|
| 438 |
+
# 6. Set backend to "custom" if not already set (for direct PreTrainedTokenizer subclasses)
|
| 439 |
+
if "backend" not in kwargs:
|
| 440 |
+
kwargs["backend"] = "custom"
|
| 441 |
+
|
| 442 |
+
# 7. init the parent class
|
| 443 |
+
super().__init__(**kwargs)
|
| 444 |
+
|
| 445 |
+
# 4. If some of the special tokens are not part of the vocab, we add them, at the end.
|
| 446 |
+
# V5: the order of addition follows self.SPECIAL_TOKENS_ATTRIBUTES, then extra special tokens
|
| 447 |
+
# Note: _add_tokens will automatically skip tokens that are already in the base vocab
|
| 448 |
+
self._add_tokens(
|
| 449 |
+
[token for token in self.all_special_tokens if token not in self._added_tokens_encoder],
|
| 450 |
+
special_tokens=True,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
@property
|
| 454 |
+
def is_fast(self) -> bool:
|
| 455 |
+
return False
|
| 456 |
+
|
| 457 |
+
@property
|
| 458 |
+
def added_tokens_encoder(self) -> dict[str, int]:
|
| 459 |
+
"""
|
| 460 |
+
Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
|
| 461 |
+
optimisation in `self._added_tokens_encoder` for the slow tokenizers.
|
| 462 |
+
"""
|
| 463 |
+
return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
|
| 464 |
+
|
| 465 |
+
@property
|
| 466 |
+
def added_tokens_decoder(self) -> dict[int, AddedToken]:
|
| 467 |
+
"""
|
| 468 |
+
Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
`dict[str, int]`: The added tokens.
|
| 472 |
+
"""
|
| 473 |
+
return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0]))
|
| 474 |
+
|
| 475 |
+
@added_tokens_decoder.setter
|
| 476 |
+
def added_tokens_decoder(self, value: dict[int, AddedToken | str]) -> dict[int, AddedToken]:
|
| 477 |
+
# Always raise an error if string because users should define the behavior
|
| 478 |
+
for index, token in value.items():
|
| 479 |
+
if not isinstance(token, (str, AddedToken)) or not isinstance(index, int):
|
| 480 |
+
raise TypeError(
|
| 481 |
+
f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, AddedToken | str}"
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
|
| 485 |
+
self._added_tokens_encoder[str(token)] = index
|
| 486 |
+
self._update_total_vocab_size()
|
| 487 |
+
|
| 488 |
+
def get_added_vocab(self) -> dict[str, int]:
|
| 489 |
+
"""
|
| 490 |
+
Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
|
| 491 |
+
the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
|
| 492 |
+
something we should change.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
`dict[str, int]`: The added tokens.
|
| 496 |
+
"""
|
| 497 |
+
return self._added_tokens_encoder
|
| 498 |
+
|
| 499 |
+
def __len__(self):
|
| 500 |
+
"""
|
| 501 |
+
Size of the full vocabulary with the added tokens.
|
| 502 |
+
"""
|
| 503 |
+
# Lazy evaluation: compute if not already set (e.g., during initialization)
|
| 504 |
+
if self.total_vocab_size == 0:
|
| 505 |
+
self._update_total_vocab_size()
|
| 506 |
+
return self.total_vocab_size
|
| 507 |
+
|
| 508 |
+
def _update_total_vocab_size(self):
|
| 509 |
+
"""
|
| 510 |
+
Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
|
| 511 |
+
otherwise if there is a hole in the vocab, we will add tokenizers at a wrong index. This operation is slow and
|
| 512 |
+
is only updated when adding tokens.
|
| 513 |
+
"""
|
| 514 |
+
self.total_vocab_size = len(self.get_vocab())
|
| 515 |
+
|
| 516 |
+
def _add_tokens(self, new_tokens: list[str] | list[AddedToken], special_tokens: bool = False) -> int:
|
| 517 |
+
"""
|
| 518 |
+
Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
|
| 519 |
+
it with indices starting from length of the current vocabulary. Special tokens are sometimes already in the
|
| 520 |
+
vocab which is why they have to be handled specifically.
|
| 521 |
+
|
| 522 |
+
Args:
|
| 523 |
+
new_tokens (`list[str]`or `list[tokenizers.AddedToken]`):
|
| 524 |
+
Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary
|
| 525 |
+
(tested by checking if the tokenizer assign the index of the `unk_token` to them). If a token is part
|
| 526 |
+
of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the
|
| 527 |
+
stripping and normalization of this token. This is NOT possible in `tokenizers`.
|
| 528 |
+
special_tokens (`bool`, *optional*, defaults to `False`):
|
| 529 |
+
Whether or not the tokens should be added as special tokens.
|
| 530 |
+
|
| 531 |
+
Returns:
|
| 532 |
+
`int`: The number of tokens actually added to the vocabulary.
|
| 533 |
+
|
| 534 |
+
Examples:
|
| 535 |
+
|
| 536 |
+
```python
|
| 537 |
+
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
| 538 |
+
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
| 539 |
+
model = BertModel.from_pretrained("google-bert/bert-base-uncased")
|
| 540 |
+
|
| 541 |
+
num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
|
| 542 |
+
print("We have added", num_added_toks, "tokens")
|
| 543 |
+
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
| 544 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 545 |
+
```"""
|
| 546 |
+
added_tokens = 0
|
| 547 |
+
if new_tokens is None:
|
| 548 |
+
return added_tokens
|
| 549 |
+
# TODO this is fairly slow to improve!
|
| 550 |
+
current_vocab = self.get_vocab().copy()
|
| 551 |
+
new_idx = len(current_vocab) # only call this once, len gives the last index + 1
|
| 552 |
+
for token in new_tokens:
|
| 553 |
+
if not isinstance(token, (str, AddedToken)):
|
| 554 |
+
raise TypeError(f"Token {token} is not a string but a {type(token)}.")
|
| 555 |
+
if str(token) == "":
|
| 556 |
+
continue
|
| 557 |
+
if isinstance(token, str):
|
| 558 |
+
if token in self._added_tokens_encoder:
|
| 559 |
+
continue
|
| 560 |
+
else:
|
| 561 |
+
# very important for fast and slow equivalence!
|
| 562 |
+
is_special = token in self.all_special_tokens or special_tokens
|
| 563 |
+
token = AddedToken(
|
| 564 |
+
token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special
|
| 565 |
+
)
|
| 566 |
+
elif special_tokens:
|
| 567 |
+
# doing token.special=True changes the normalization! will fix in rust
|
| 568 |
+
# this is important and the only reason why the AddedTokens in each class are normalized by default
|
| 569 |
+
token.__setstate__({"special": True, "normalized": token.normalized})
|
| 570 |
+
if token in self._added_tokens_decoder:
|
| 571 |
+
continue
|
| 572 |
+
if not token.special and token.normalized and getattr(self, "do_lower_case", False):
|
| 573 |
+
# Normalize if requested
|
| 574 |
+
token.content = token.content.lower()
|
| 575 |
+
if token.content not in current_vocab:
|
| 576 |
+
token_index = new_idx + added_tokens
|
| 577 |
+
current_vocab[token.content] = token_index
|
| 578 |
+
added_tokens += 1
|
| 579 |
+
else:
|
| 580 |
+
token_index = current_vocab[token.content]
|
| 581 |
+
|
| 582 |
+
if token.special and str(token) not in self.all_special_tokens:
|
| 583 |
+
self._extra_special_tokens.append(token)
|
| 584 |
+
# the setter automatically updates the reverse map
|
| 585 |
+
self._added_tokens_decoder[token_index] = token
|
| 586 |
+
self._added_tokens_encoder[token.content] = token_index
|
| 587 |
+
if self.verbose:
|
| 588 |
+
logger.info(f"Adding {token} to the vocabulary")
|
| 589 |
+
|
| 590 |
+
self._update_trie()
|
| 591 |
+
self._update_total_vocab_size()
|
| 592 |
+
return added_tokens
|
| 593 |
+
|
| 594 |
+
def _update_trie(self, unique_no_split_tokens: list[str] | None = None):
|
| 595 |
+
for token in self._added_tokens_decoder.values():
|
| 596 |
+
if token.content not in self.tokens_trie._tokens:
|
| 597 |
+
self.tokens_trie.add(token.content)
|
| 598 |
+
for token in unique_no_split_tokens or []:
|
| 599 |
+
if token not in self.tokens_trie._tokens:
|
| 600 |
+
self.tokens_trie.add(token)
|
| 601 |
+
|
| 602 |
+
def num_special_tokens_to_add(self, pair: bool = False) -> int:
|
| 603 |
+
"""
|
| 604 |
+
Returns the number of added tokens when encoding a sequence with special tokens.
|
| 605 |
+
|
| 606 |
+
<Tip>
|
| 607 |
+
|
| 608 |
+
This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
|
| 609 |
+
this inside your training loop.
|
| 610 |
+
|
| 611 |
+
</Tip>
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
pair (`bool`, *optional*, defaults to `False`):
|
| 615 |
+
Whether the number of added tokens should be computed in the case of a sequence pair or a single
|
| 616 |
+
sequence.
|
| 617 |
+
|
| 618 |
+
Returns:
|
| 619 |
+
`int`: Number of special tokens added to sequences.
|
| 620 |
+
"""
|
| 621 |
+
token_ids_0 = []
|
| 622 |
+
token_ids_1 = []
|
| 623 |
+
return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
|
| 624 |
+
|
| 625 |
+
def tokenize(self, text: TextInput, **kwargs) -> list[str]:
|
| 626 |
+
"""
|
| 627 |
+
Converts a string into a sequence of tokens, using the tokenizer.
|
| 628 |
+
|
| 629 |
+
Args:
|
| 630 |
+
text: The sequence to be encoded.
|
| 631 |
+
**kwargs: Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
|
| 632 |
+
|
| 633 |
+
Returns:
|
| 634 |
+
The list of tokens.
|
| 635 |
+
"""
|
| 636 |
+
split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
|
| 637 |
+
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
| 638 |
+
|
| 639 |
+
if split_special_tokens:
|
| 640 |
+
# Don't split on any tokens - just tokenize directly
|
| 641 |
+
return self._tokenize(text)
|
| 642 |
+
|
| 643 |
+
# Split on added tokens
|
| 644 |
+
tokens = self.tokens_trie.split(text)
|
| 645 |
+
no_split_token = self._added_tokens_encoder.keys()
|
| 646 |
+
|
| 647 |
+
# Handle added token properties (lstrip, rstrip, single_word)
|
| 648 |
+
for i, token in enumerate(tokens):
|
| 649 |
+
if token in no_split_token:
|
| 650 |
+
tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token])
|
| 651 |
+
left = tokens[i - 1] if i > 0 else None
|
| 652 |
+
right = tokens[i + 1] if i < len(tokens) - 1 else None
|
| 653 |
+
|
| 654 |
+
if isinstance(tok_extended, AddedToken):
|
| 655 |
+
if tok_extended.rstrip and right:
|
| 656 |
+
tokens[i + 1] = right.lstrip()
|
| 657 |
+
if tok_extended.lstrip and left:
|
| 658 |
+
tokens[i - 1] = left.rstrip()
|
| 659 |
+
if tok_extended.single_word:
|
| 660 |
+
if left and left[-1] != " ":
|
| 661 |
+
tokens[i - 1] += token
|
| 662 |
+
tokens[i] = ""
|
| 663 |
+
elif right and right[0] != " ":
|
| 664 |
+
tokens[i + 1] = token + tokens[i + 1]
|
| 665 |
+
tokens[i] = ""
|
| 666 |
+
|
| 667 |
+
# Tokenize non-added tokens
|
| 668 |
+
result = []
|
| 669 |
+
all_special_tokens_set = set(self.all_special_tokens)
|
| 670 |
+
for token in tokens:
|
| 671 |
+
if not token:
|
| 672 |
+
continue
|
| 673 |
+
if token in no_split_token or token in all_special_tokens_set:
|
| 674 |
+
result.append(token)
|
| 675 |
+
else:
|
| 676 |
+
result.extend(self._tokenize(token))
|
| 677 |
+
|
| 678 |
+
return result
|
| 679 |
+
|
| 680 |
+
def _tokenize(self, text, **kwargs):
|
| 681 |
+
"""
|
| 682 |
+
Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
|
| 683 |
+
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
|
| 684 |
+
|
| 685 |
+
Do NOT take care of added tokens.
|
| 686 |
+
"""
|
| 687 |
+
raise NotImplementedError
|
| 688 |
+
|
| 689 |
+
def _convert_token_to_id_with_added_voc(self, token):
|
| 690 |
+
if token in self.added_tokens_encoder:
|
| 691 |
+
return self.added_tokens_encoder[token]
|
| 692 |
+
return self._convert_token_to_id(token)
|
| 693 |
+
|
| 694 |
+
def _convert_token_to_id(self, token):
|
| 695 |
+
raise NotImplementedError
|
| 696 |
+
|
| 697 |
+
def _encode_plus(
|
| 698 |
+
self,
|
| 699 |
+
text: TextInput | PreTokenizedInput | EncodedInput,
|
| 700 |
+
text_pair: TextInput | PreTokenizedInput | EncodedInput | None = None,
|
| 701 |
+
add_special_tokens: bool = True,
|
| 702 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 703 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 704 |
+
max_length: int | None = None,
|
| 705 |
+
stride: int = 0,
|
| 706 |
+
is_split_into_words: bool = False,
|
| 707 |
+
pad_to_multiple_of: int | None = None,
|
| 708 |
+
padding_side: str | None = None,
|
| 709 |
+
return_tensors: str | TensorType | None = None,
|
| 710 |
+
return_token_type_ids: bool | None = None,
|
| 711 |
+
return_attention_mask: bool | None = None,
|
| 712 |
+
return_overflowing_tokens: bool = False,
|
| 713 |
+
return_special_tokens_mask: bool = False,
|
| 714 |
+
return_length: bool = False,
|
| 715 |
+
verbose: bool = True,
|
| 716 |
+
**kwargs,
|
| 717 |
+
) -> BatchEncoding:
|
| 718 |
+
# Detect batched inputs (list of sequences)
|
| 719 |
+
is_batched = isinstance(text, (list, tuple)) and (
|
| 720 |
+
(not text and not is_split_into_words)
|
| 721 |
+
or (text and is_split_into_words and isinstance(text[0], (list, tuple)))
|
| 722 |
+
or (text and not is_split_into_words and isinstance(text[0], (str, list, tuple)))
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
if is_batched:
|
| 726 |
+
if text_pair is not None:
|
| 727 |
+
if not isinstance(text_pair, (list, tuple)) or len(text_pair) != len(text):
|
| 728 |
+
raise ValueError("If `text` is a batch, `text_pair` must also be a batch of the same length.")
|
| 729 |
+
pairs = text_pair if text_pair is not None else [None] * len(text)
|
| 730 |
+
|
| 731 |
+
batch_outputs = {}
|
| 732 |
+
for current_text, current_pair in zip(text, pairs):
|
| 733 |
+
# Handle tuples/lists as sequence pairs like ("text1", "text2")
|
| 734 |
+
# For is_split_into_words=True: only unpack if it's a tuple of exactly 2 sequences (pair)
|
| 735 |
+
# Otherwise, treat the list as a single pretokenized sequence
|
| 736 |
+
if (
|
| 737 |
+
isinstance(current_text, (list, tuple))
|
| 738 |
+
and current_text
|
| 739 |
+
and not isinstance(current_text[0], int)
|
| 740 |
+
and current_pair is None
|
| 741 |
+
):
|
| 742 |
+
# Check if this looks like a pair: tuple/list of length 2 where elements are strings or lists/tuples
|
| 743 |
+
is_pair = (
|
| 744 |
+
len(current_text) == 2
|
| 745 |
+
and (isinstance(current_text[0], str) or isinstance(current_text[0], (list, tuple)))
|
| 746 |
+
and (isinstance(current_text[1], str) or isinstance(current_text[1], (list, tuple)))
|
| 747 |
+
)
|
| 748 |
+
if is_pair:
|
| 749 |
+
current_text, current_pair = current_text
|
| 750 |
+
elif len(current_text) == 1:
|
| 751 |
+
current_text = current_text[0]
|
| 752 |
+
elif not is_split_into_words:
|
| 753 |
+
# Only raise error for non-pretokenized input
|
| 754 |
+
raise ValueError(f"Expected a pair of sequences, got {len(current_text)} sequences.")
|
| 755 |
+
|
| 756 |
+
current_output = self._encode_plus(
|
| 757 |
+
text=current_text,
|
| 758 |
+
text_pair=current_pair,
|
| 759 |
+
add_special_tokens=add_special_tokens,
|
| 760 |
+
padding_strategy=PaddingStrategy.DO_NOT_PAD, # we pad in batch afterward
|
| 761 |
+
truncation_strategy=truncation_strategy,
|
| 762 |
+
max_length=max_length,
|
| 763 |
+
stride=stride,
|
| 764 |
+
is_split_into_words=is_split_into_words,
|
| 765 |
+
pad_to_multiple_of=None, # we pad in batch afterward
|
| 766 |
+
padding_side=None, # we pad in batch afterward
|
| 767 |
+
return_tensors=None, # We convert the whole batch to tensors at the end
|
| 768 |
+
return_token_type_ids=return_token_type_ids,
|
| 769 |
+
return_attention_mask=False, # we pad in batch afterward
|
| 770 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 771 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 772 |
+
return_length=return_length,
|
| 773 |
+
verbose=verbose,
|
| 774 |
+
**kwargs,
|
| 775 |
+
)
|
| 776 |
+
for key, value in current_output.items():
|
| 777 |
+
batch_outputs.setdefault(key, []).append(value)
|
| 778 |
+
|
| 779 |
+
# Remove overflow-related keys before tensor conversion if return_tensors is set
|
| 780 |
+
# Slow tokenizers don't support returning these as tensors
|
| 781 |
+
if return_tensors and return_overflowing_tokens:
|
| 782 |
+
batch_outputs.pop("overflowing_tokens", None)
|
| 783 |
+
batch_outputs.pop("num_truncated_tokens", None)
|
| 784 |
+
|
| 785 |
+
batch_outputs = self.pad(
|
| 786 |
+
batch_outputs,
|
| 787 |
+
padding=padding_strategy.value,
|
| 788 |
+
max_length=max_length,
|
| 789 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 790 |
+
padding_side=padding_side,
|
| 791 |
+
return_attention_mask=return_attention_mask,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
| 795 |
+
|
| 796 |
+
# Single sequence handling
|
| 797 |
+
def get_input_ids(text):
|
| 798 |
+
if isinstance(text, str):
|
| 799 |
+
# Normal case: tokenize string
|
| 800 |
+
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
| 801 |
+
if isinstance(text, (list, tuple)) and text:
|
| 802 |
+
if isinstance(text[0], int):
|
| 803 |
+
return text
|
| 804 |
+
# Pre-tokenized strings
|
| 805 |
+
if isinstance(text[0], str):
|
| 806 |
+
if is_split_into_words:
|
| 807 |
+
return self.convert_tokens_to_ids(
|
| 808 |
+
[tok for word in text for tok in self.tokenize(word, **kwargs)]
|
| 809 |
+
)
|
| 810 |
+
return self.convert_tokens_to_ids(text)
|
| 811 |
+
raise ValueError(f"Input must be a string, list of strings, or list of ints, got: {type(text)}")
|
| 812 |
+
|
| 813 |
+
first_ids = get_input_ids(text)
|
| 814 |
+
second_ids = get_input_ids(text_pair) if text_pair is not None else None
|
| 815 |
+
|
| 816 |
+
return self.prepare_for_model(
|
| 817 |
+
first_ids,
|
| 818 |
+
pair_ids=second_ids,
|
| 819 |
+
add_special_tokens=add_special_tokens,
|
| 820 |
+
padding=padding_strategy.value,
|
| 821 |
+
truncation=truncation_strategy.value,
|
| 822 |
+
max_length=max_length,
|
| 823 |
+
stride=stride,
|
| 824 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 825 |
+
padding_side=padding_side,
|
| 826 |
+
return_tensors=return_tensors,
|
| 827 |
+
prepend_batch_axis=True,
|
| 828 |
+
return_attention_mask=return_attention_mask,
|
| 829 |
+
return_token_type_ids=return_token_type_ids,
|
| 830 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 831 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 832 |
+
return_length=return_length,
|
| 833 |
+
verbose=verbose,
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
def prepare_for_tokenization(
|
| 837 |
+
self, text: str, is_split_into_words: bool = False, **kwargs
|
| 838 |
+
) -> tuple[str, dict[str, Any]]:
|
| 839 |
+
"""
|
| 840 |
+
Performs any necessary transformations before tokenization.
|
| 841 |
+
|
| 842 |
+
This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
|
| 843 |
+
`kwargs` at the end of the encoding process to be sure all the arguments have been used.
|
| 844 |
+
|
| 845 |
+
Args:
|
| 846 |
+
text (`str`):
|
| 847 |
+
The text to prepare.
|
| 848 |
+
is_split_into_words (`bool`, *optional*, defaults to `False`):
|
| 849 |
+
Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
|
| 850 |
+
tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
|
| 851 |
+
which it will tokenize. This is useful for NER or token classification.
|
| 852 |
+
kwargs (`dict[str, Any]`, *optional*):
|
| 853 |
+
Keyword arguments to use for the tokenization.
|
| 854 |
+
|
| 855 |
+
Returns:
|
| 856 |
+
`tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs.
|
| 857 |
+
"""
|
| 858 |
+
return (text, kwargs)
|
| 859 |
+
|
| 860 |
+
def build_inputs_with_special_tokens(
|
| 861 |
+
self, token_ids_0: list[int], token_ids_1: list[int] | None = None
|
| 862 |
+
) -> list[int]:
|
| 863 |
+
"""
|
| 864 |
+
Build model inputs from a sequence or a pair of sequences by adding special tokens.
|
| 865 |
+
|
| 866 |
+
This method dynamically builds inputs based on the tokenizer's `special_tokens_pattern`:
|
| 867 |
+
- `"none"`: No special tokens
|
| 868 |
+
- `"cls_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
|
| 869 |
+
- `"eos"`: seq0 [EOS] or seq0 [EOS] seq1 [EOS]
|
| 870 |
+
- `"bos"`: [BOS] seq0 or [BOS] seq0 [BOS] seq1
|
| 871 |
+
- `"bos_eos"`: [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
|
| 872 |
+
- `"cls_double_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
|
| 873 |
+
- `"prefix_suffix"`: `<prefix_tokens> seq0 [seq1] <suffix_tokens>` (custom prefix/suffix stored on the tokenizer)
|
| 874 |
+
|
| 875 |
+
Args:
|
| 876 |
+
token_ids_0 (`list[int]`):
|
| 877 |
+
List of IDs to which the special tokens will be added.
|
| 878 |
+
token_ids_1 (`list[int]`, *optional*):
|
| 879 |
+
Optional second list of IDs for sequence pairs.
|
| 880 |
+
|
| 881 |
+
Returns:
|
| 882 |
+
`list[int]`: List of input IDs with the appropriate special tokens.
|
| 883 |
+
"""
|
| 884 |
+
if self.special_tokens_pattern == "cls_sep":
|
| 885 |
+
# [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
|
| 886 |
+
if self.cls_token_id is None and self.sep_token_id is None:
|
| 887 |
+
raise ValueError(
|
| 888 |
+
"Cannot add special tokens following 'cls_sep' pattern because one or several special tokens "
|
| 889 |
+
f"are not defined (cls_token_id={self.cls_token_id}; sep_token_id={self.sep_token_id})"
|
| 890 |
+
"Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
|
| 891 |
+
)
|
| 892 |
+
if token_ids_1 is None:
|
| 893 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 894 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
|
| 895 |
+
|
| 896 |
+
elif self.special_tokens_pattern == "eos":
|
| 897 |
+
# seq0 [EOS] or seq0 [EOS] seq1 [EOS]
|
| 898 |
+
if self.eos_token_id is None:
|
| 899 |
+
raise ValueError(
|
| 900 |
+
"Cannot add special tokens following 'eos' pattern because eos token is not defined "
|
| 901 |
+
f"(eos_token_id={self.eos_token_id})."
|
| 902 |
+
"Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
|
| 903 |
+
)
|
| 904 |
+
if token_ids_1 is None:
|
| 905 |
+
return token_ids_0 + [self.eos_token_id]
|
| 906 |
+
return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
|
| 907 |
+
|
| 908 |
+
elif self.special_tokens_pattern == "bos":
|
| 909 |
+
# [BOS] seq0 or [BOS] seq0 [BOS] seq1
|
| 910 |
+
if self.bos_token_id is None:
|
| 911 |
+
raise ValueError(
|
| 912 |
+
"Cannot add special tokens following 'bos' pattern because bos token is not defined "
|
| 913 |
+
f"(bos_token_id={self.bos_token_id})."
|
| 914 |
+
"Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
|
| 915 |
+
)
|
| 916 |
+
if token_ids_1 is None:
|
| 917 |
+
return [self.bos_token_id] + token_ids_0
|
| 918 |
+
return [self.bos_token_id] + token_ids_0 + [self.bos_token_id] + token_ids_1
|
| 919 |
+
|
| 920 |
+
elif self.special_tokens_pattern == "bos_eos":
|
| 921 |
+
# [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
|
| 922 |
+
if self.bos_token_id is None and self.eos_token_id is None:
|
| 923 |
+
raise ValueError(
|
| 924 |
+
"Cannot add special tokens following 'bos_eos' pattern because one or several special tokens "
|
| 925 |
+
f"are not defined (bos_token_id={self.bos_token_id}; eos_token_id={self.eos_token_id})"
|
| 926 |
+
"Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
|
| 927 |
+
)
|
| 928 |
+
return token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
|
| 929 |
+
|
| 930 |
+
if token_ids_1 is None:
|
| 931 |
+
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
| 932 |
+
return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
|
| 933 |
+
|
| 934 |
+
elif self.special_tokens_pattern == "cls_double_sep":
|
| 935 |
+
# [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
|
| 936 |
+
if self.cls_token_id is None and self.sep_token_id is None:
|
| 937 |
+
raise ValueError(
|
| 938 |
+
"Cannot add special tokens following 'cls_double_sep' pattern because one or several special tokens "
|
| 939 |
+
f"are not defined (cls_token_id={self.cls_token_id}; sep_token_id={self.sep_token_id})"
|
| 940 |
+
"Set the required special tokens in tokenizer or update `tokenizer.special_tokens_pattern`"
|
| 941 |
+
)
|
| 942 |
+
if token_ids_1 is None:
|
| 943 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 944 |
+
return (
|
| 945 |
+
[self.cls_token_id]
|
| 946 |
+
+ token_ids_0
|
| 947 |
+
+ [self.sep_token_id, self.sep_token_id]
|
| 948 |
+
+ token_ids_1
|
| 949 |
+
+ [self.sep_token_id]
|
| 950 |
+
)
|
| 951 |
+
|
| 952 |
+
elif self.special_tokens_pattern == "prefix_suffix":
|
| 953 |
+
prefix_tokens = getattr(self, "prefix_tokens", [])
|
| 954 |
+
suffix_tokens = getattr(self, "suffix_tokens", [])
|
| 955 |
+
if token_ids_1 is None:
|
| 956 |
+
return prefix_tokens + token_ids_0 + suffix_tokens
|
| 957 |
+
return prefix_tokens + token_ids_0 + token_ids_1 + suffix_tokens
|
| 958 |
+
|
| 959 |
+
else: # "none" or any other value
|
| 960 |
+
# No special tokens
|
| 961 |
+
if token_ids_1 is None:
|
| 962 |
+
return token_ids_0
|
| 963 |
+
return token_ids_0 + token_ids_1
|
| 964 |
+
|
| 965 |
+
def get_special_tokens_mask(
|
| 966 |
+
self, token_ids_0: list, token_ids_1: list | None = None, already_has_special_tokens: bool = False
|
| 967 |
+
) -> list[int]:
|
| 968 |
+
"""
|
| 969 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 970 |
+
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
|
| 971 |
+
|
| 972 |
+
This method dynamically builds the special tokens mask based on the tokenizer's `special_tokens_pattern`:
|
| 973 |
+
- `"none"`: No special tokens (default, returns all 0s)
|
| 974 |
+
- `"cls_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
|
| 975 |
+
- `"eos"`: seq0 [EOS] or seq0 [EOS] seq1 [EOS]
|
| 976 |
+
- `"bos"`: [BOS] seq0 or [BOS] seq0 [BOS] seq1
|
| 977 |
+
- `"bos_eos"`: [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
|
| 978 |
+
- `"cls_double_sep"`: [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
|
| 979 |
+
- `"prefix_suffix"`: `<prefix_tokens> seq0 [seq1] <suffix_tokens>`
|
| 980 |
+
|
| 981 |
+
Args:
|
| 982 |
+
token_ids_0 (`list[int]`):
|
| 983 |
+
List of ids of the first sequence.
|
| 984 |
+
token_ids_1 (`list[int]`, *optional*):
|
| 985 |
+
List of ids of the second sequence.
|
| 986 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 987 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 991 |
+
"""
|
| 992 |
+
if already_has_special_tokens:
|
| 993 |
+
if token_ids_1 is not None:
|
| 994 |
+
raise ValueError(
|
| 995 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 996 |
+
"ids is already formatted with special tokens for the model."
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
return super().get_special_tokens_mask(
|
| 1000 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
if self.special_tokens_pattern == "cls_sep":
|
| 1004 |
+
# [CLS] seq0 [SEP] or [CLS] seq0 [SEP] seq1 [SEP]
|
| 1005 |
+
if token_ids_1 is None:
|
| 1006 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 1007 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 1008 |
+
|
| 1009 |
+
elif self.special_tokens_pattern == "eos":
|
| 1010 |
+
# seq0 [EOS] or seq0 [EOS] seq1 [EOS]
|
| 1011 |
+
if token_ids_1 is None:
|
| 1012 |
+
return ([0] * len(token_ids_0)) + [1]
|
| 1013 |
+
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 1014 |
+
|
| 1015 |
+
elif self.special_tokens_pattern == "bos":
|
| 1016 |
+
# [BOS] seq0 or [BOS] seq0 [BOS] seq1
|
| 1017 |
+
if token_ids_1 is None:
|
| 1018 |
+
return [1] + ([0] * len(token_ids_0))
|
| 1019 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
|
| 1020 |
+
|
| 1021 |
+
elif self.special_tokens_pattern == "bos_eos":
|
| 1022 |
+
# [BOS] seq0 [EOS] or [BOS] seq0 [EOS] seq1 [EOS]
|
| 1023 |
+
if token_ids_1 is None:
|
| 1024 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 1025 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 1026 |
+
|
| 1027 |
+
elif self.special_tokens_pattern == "cls_double_sep":
|
| 1028 |
+
# [CLS] seq0 [SEP] or [CLS] seq0 [SEP] [SEP] seq1 [SEP]
|
| 1029 |
+
if token_ids_1 is None:
|
| 1030 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 1031 |
+
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
|
| 1032 |
+
|
| 1033 |
+
elif self.special_tokens_pattern == "prefix_suffix":
|
| 1034 |
+
prefix_len = len(getattr(self, "prefix_tokens", []))
|
| 1035 |
+
suffix_len = len(getattr(self, "suffix_tokens", []))
|
| 1036 |
+
mask = [1] * prefix_len + ([0] * len(token_ids_0))
|
| 1037 |
+
if token_ids_1 is not None:
|
| 1038 |
+
mask += [0] * len(token_ids_1)
|
| 1039 |
+
mask += [1] * suffix_len
|
| 1040 |
+
return mask
|
| 1041 |
+
|
| 1042 |
+
else:
|
| 1043 |
+
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
| 1044 |
+
|
| 1045 |
+
@overload
|
| 1046 |
+
def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
|
| 1047 |
+
|
| 1048 |
+
@overload
|
| 1049 |
+
def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ...
|
| 1050 |
+
|
| 1051 |
+
def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]:
|
| 1052 |
+
"""
|
| 1053 |
+
Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
|
| 1054 |
+
added tokens.
|
| 1055 |
+
|
| 1056 |
+
Args:
|
| 1057 |
+
ids (`int` or `list[int]`):
|
| 1058 |
+
The token id (or token ids) to convert to tokens.
|
| 1059 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 1060 |
+
Whether or not to remove special tokens in the decoding.
|
| 1061 |
+
|
| 1062 |
+
Returns:
|
| 1063 |
+
`str` or `list[str]`: The decoded token(s).
|
| 1064 |
+
"""
|
| 1065 |
+
if isinstance(ids, int):
|
| 1066 |
+
return (
|
| 1067 |
+
self._added_tokens_decoder[ids].content
|
| 1068 |
+
if ids in self._added_tokens_decoder
|
| 1069 |
+
else self._convert_id_to_token(ids)
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
tokens = []
|
| 1073 |
+
# self.all_special_ids is an @property which may be slow, so only compute it once before the loop
|
| 1074 |
+
ids_to_skip = set(self.all_special_ids) if skip_special_tokens else set()
|
| 1075 |
+
for index in ids:
|
| 1076 |
+
index = int(index)
|
| 1077 |
+
if index in ids_to_skip:
|
| 1078 |
+
continue
|
| 1079 |
+
tokens.append(
|
| 1080 |
+
self._added_tokens_decoder[index].content
|
| 1081 |
+
if index in self._added_tokens_decoder
|
| 1082 |
+
else self._convert_id_to_token(index)
|
| 1083 |
+
)
|
| 1084 |
+
return tokens
|
| 1085 |
+
|
| 1086 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 1087 |
+
raise NotImplementedError
|
| 1088 |
+
|
| 1089 |
+
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
| 1090 |
+
return " ".join(tokens)
|
| 1091 |
+
|
| 1092 |
+
def _decode(
|
| 1093 |
+
self,
|
| 1094 |
+
token_ids: int | list[int],
|
| 1095 |
+
skip_special_tokens: bool = False,
|
| 1096 |
+
clean_up_tokenization_spaces: bool | None = None,
|
| 1097 |
+
**kwargs,
|
| 1098 |
+
) -> str:
|
| 1099 |
+
"""Decode token ids to string."""
|
| 1100 |
+
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
| 1101 |
+
if isinstance(filtered_tokens, str):
|
| 1102 |
+
filtered_tokens = [filtered_tokens]
|
| 1103 |
+
|
| 1104 |
+
text = self.convert_tokens_to_string(filtered_tokens)
|
| 1105 |
+
|
| 1106 |
+
# Apply tokenizer-specific cleanup if available and requested
|
| 1107 |
+
clean_up_tokenization_spaces = (
|
| 1108 |
+
clean_up_tokenization_spaces
|
| 1109 |
+
if clean_up_tokenization_spaces is not None
|
| 1110 |
+
else self.clean_up_tokenization_spaces
|
| 1111 |
+
)
|
| 1112 |
+
if clean_up_tokenization_spaces:
|
| 1113 |
+
text = self.clean_up_tokenization(text)
|
| 1114 |
+
|
| 1115 |
+
return text
|
| 1116 |
+
|
| 1117 |
+
def prepare_for_model(
|
| 1118 |
+
self,
|
| 1119 |
+
ids: list[int],
|
| 1120 |
+
pair_ids: list[int] | None = None,
|
| 1121 |
+
add_special_tokens: bool = True,
|
| 1122 |
+
padding: bool | str | PaddingStrategy = False,
|
| 1123 |
+
truncation: bool | str | TruncationStrategy = False,
|
| 1124 |
+
max_length: int | None = None,
|
| 1125 |
+
stride: int = 0,
|
| 1126 |
+
pad_to_multiple_of: int | None = None,
|
| 1127 |
+
padding_side: str | None = None,
|
| 1128 |
+
return_tensors: str | TensorType | None = None,
|
| 1129 |
+
return_token_type_ids: bool | None = None,
|
| 1130 |
+
return_attention_mask: bool | None = None,
|
| 1131 |
+
return_overflowing_tokens: bool = False,
|
| 1132 |
+
return_special_tokens_mask: bool = False,
|
| 1133 |
+
return_length: bool = False,
|
| 1134 |
+
verbose: bool = True,
|
| 1135 |
+
prepend_batch_axis: bool = False,
|
| 1136 |
+
**kwargs,
|
| 1137 |
+
) -> BatchEncoding:
|
| 1138 |
+
"""
|
| 1139 |
+
Prepares a sequence of input ids so it can be used by the model. Adds special tokens, truncates, and pads.
|
| 1140 |
+
|
| 1141 |
+
Args:
|
| 1142 |
+
ids: Tokenized input ids of the first sequence.
|
| 1143 |
+
pair_ids: Tokenized input ids of the second sequence (optional).
|
| 1144 |
+
"""
|
| 1145 |
+
# Get padding/truncation strategies
|
| 1146 |
+
padding_strategy, truncation_strategy, max_length, _ = self._get_padding_truncation_strategies(
|
| 1147 |
+
padding=padding,
|
| 1148 |
+
truncation=truncation,
|
| 1149 |
+
max_length=max_length,
|
| 1150 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 1151 |
+
verbose=verbose,
|
| 1152 |
+
**kwargs,
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
# Validation
|
| 1156 |
+
if (
|
| 1157 |
+
return_overflowing_tokens
|
| 1158 |
+
and truncation_strategy == TruncationStrategy.LONGEST_FIRST
|
| 1159 |
+
and pair_ids is not None
|
| 1160 |
+
):
|
| 1161 |
+
raise ValueError(
|
| 1162 |
+
"Not possible to return overflowing tokens for pair of sequences with the "
|
| 1163 |
+
"`longest_first`. Please select another truncation strategy than `longest_first`, "
|
| 1164 |
+
"for instance `only_second` or `only_first`."
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
# Defaults
|
| 1168 |
+
if return_token_type_ids is None:
|
| 1169 |
+
return_token_type_ids = "token_type_ids" in self.model_input_names
|
| 1170 |
+
if return_attention_mask is None:
|
| 1171 |
+
return_attention_mask = "attention_mask" in self.model_input_names
|
| 1172 |
+
|
| 1173 |
+
# Truncation
|
| 1174 |
+
pair = pair_ids is not None
|
| 1175 |
+
num_special = self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0
|
| 1176 |
+
total_len = len(ids) + len(pair_ids or []) + num_special
|
| 1177 |
+
|
| 1178 |
+
overflowing_tokens = []
|
| 1179 |
+
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
| 1180 |
+
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
| 1181 |
+
ids,
|
| 1182 |
+
pair_ids=pair_ids,
|
| 1183 |
+
num_tokens_to_remove=total_len - max_length,
|
| 1184 |
+
truncation_strategy=truncation_strategy,
|
| 1185 |
+
stride=stride,
|
| 1186 |
+
)
|
| 1187 |
+
|
| 1188 |
+
# Add special tokens
|
| 1189 |
+
if add_special_tokens:
|
| 1190 |
+
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
| 1191 |
+
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
| 1192 |
+
else:
|
| 1193 |
+
sequence = ids + (pair_ids if pair_ids else [])
|
| 1194 |
+
token_type_ids = [0] * len(sequence)
|
| 1195 |
+
|
| 1196 |
+
# Build output
|
| 1197 |
+
encoded_inputs = {"input_ids": sequence}
|
| 1198 |
+
if return_token_type_ids:
|
| 1199 |
+
encoded_inputs["token_type_ids"] = token_type_ids
|
| 1200 |
+
if return_special_tokens_mask:
|
| 1201 |
+
encoded_inputs["special_tokens_mask"] = (
|
| 1202 |
+
self.get_special_tokens_mask(ids, pair_ids) if add_special_tokens else [0] * len(sequence)
|
| 1203 |
+
)
|
| 1204 |
+
if return_overflowing_tokens and not return_tensors and overflowing_tokens:
|
| 1205 |
+
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
| 1206 |
+
encoded_inputs["num_truncated_tokens"] = total_len - max_length if max_length else 0
|
| 1207 |
+
|
| 1208 |
+
# Check sequence length and warn if needed
|
| 1209 |
+
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
| 1210 |
+
|
| 1211 |
+
# Pad
|
| 1212 |
+
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
| 1213 |
+
encoded_inputs = self.pad(
|
| 1214 |
+
encoded_inputs,
|
| 1215 |
+
max_length=max_length,
|
| 1216 |
+
padding=padding_strategy.value,
|
| 1217 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 1218 |
+
padding_side=padding_side,
|
| 1219 |
+
return_attention_mask=return_attention_mask,
|
| 1220 |
+
)
|
| 1221 |
+
|
| 1222 |
+
if return_length:
|
| 1223 |
+
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
|
| 1224 |
+
|
| 1225 |
+
return BatchEncoding(encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis)
|
| 1226 |
+
|
| 1227 |
+
def truncate_sequences(
|
| 1228 |
+
self,
|
| 1229 |
+
ids: list[int],
|
| 1230 |
+
pair_ids: list[int] | None = None,
|
| 1231 |
+
num_tokens_to_remove: int = 0,
|
| 1232 |
+
truncation_strategy: str | TruncationStrategy = "longest_first",
|
| 1233 |
+
stride: int = 0,
|
| 1234 |
+
) -> tuple[list[int], list[int], list[int]]:
|
| 1235 |
+
"""Truncates sequences according to the specified strategy."""
|
| 1236 |
+
if num_tokens_to_remove <= 0:
|
| 1237 |
+
return ids, pair_ids, []
|
| 1238 |
+
|
| 1239 |
+
if not isinstance(truncation_strategy, TruncationStrategy):
|
| 1240 |
+
truncation_strategy = TruncationStrategy(truncation_strategy)
|
| 1241 |
+
|
| 1242 |
+
overflowing_tokens = []
|
| 1243 |
+
|
| 1244 |
+
# ONLY_FIRST or LONGEST_FIRST with single sequence
|
| 1245 |
+
if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
|
| 1246 |
+
truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
|
| 1247 |
+
):
|
| 1248 |
+
window_len = min(len(ids), stride + num_tokens_to_remove)
|
| 1249 |
+
if self.truncation_side == "left":
|
| 1250 |
+
overflowing_tokens = ids[:window_len]
|
| 1251 |
+
ids = ids[num_tokens_to_remove:]
|
| 1252 |
+
else:
|
| 1253 |
+
overflowing_tokens = ids[-window_len:]
|
| 1254 |
+
ids = ids[:-num_tokens_to_remove]
|
| 1255 |
+
|
| 1256 |
+
# LONGEST_FIRST with pair
|
| 1257 |
+
elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
|
| 1258 |
+
logger.warning(
|
| 1259 |
+
"Be aware, overflowing tokens are not returned for the setting you have chosen,"
|
| 1260 |
+
f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
|
| 1261 |
+
"truncation strategy. So the returned list will always be empty even if some "
|
| 1262 |
+
"tokens have been removed."
|
| 1263 |
+
)
|
| 1264 |
+
len_ids, len_pair = len(ids), len(pair_ids) if pair_ids else 0
|
| 1265 |
+
first_remove = min(abs(len_pair - len_ids), num_tokens_to_remove)
|
| 1266 |
+
second_remove = num_tokens_to_remove - first_remove
|
| 1267 |
+
|
| 1268 |
+
if len_ids > len_pair:
|
| 1269 |
+
ids_to_move = first_remove + second_remove // 2
|
| 1270 |
+
pair_ids_to_move = second_remove - second_remove // 2
|
| 1271 |
+
else:
|
| 1272 |
+
ids_to_move = second_remove // 2
|
| 1273 |
+
pair_ids_to_move = first_remove + second_remove - (second_remove // 2)
|
| 1274 |
+
|
| 1275 |
+
if self.truncation_side == "right":
|
| 1276 |
+
ids = ids[:-ids_to_move] if ids_to_move > 0 else ids
|
| 1277 |
+
pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids and pair_ids_to_move > 0 else pair_ids
|
| 1278 |
+
else:
|
| 1279 |
+
ids = ids[ids_to_move:]
|
| 1280 |
+
pair_ids = pair_ids[pair_ids_to_move:] if pair_ids else None
|
| 1281 |
+
|
| 1282 |
+
# ONLY_SECOND
|
| 1283 |
+
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids:
|
| 1284 |
+
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
|
| 1285 |
+
if self.truncation_side == "right":
|
| 1286 |
+
overflowing_tokens = pair_ids[-window_len:]
|
| 1287 |
+
pair_ids = pair_ids[:-num_tokens_to_remove]
|
| 1288 |
+
else:
|
| 1289 |
+
overflowing_tokens = pair_ids[:window_len]
|
| 1290 |
+
pair_ids = pair_ids[num_tokens_to_remove:]
|
| 1291 |
+
|
| 1292 |
+
return ids, pair_ids, overflowing_tokens
|
| 1293 |
+
|
| 1294 |
+
def create_token_type_ids_from_sequences(
|
| 1295 |
+
self, token_ids_0: list[int], token_ids_1: list[int] | None = None
|
| 1296 |
+
) -> list[int]:
|
| 1297 |
+
"""
|
| 1298 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 1299 |
+
|
| 1300 |
+
This method dynamically builds the token type IDs based on the tokenizer's configuration attributes:
|
| 1301 |
+
- `token_type_ids_pattern`: Pattern to use ("all_zeros" or "bert_style")
|
| 1302 |
+
- `token_type_ids_include_special_tokens`: Whether to account for special tokens in length calculation
|
| 1303 |
+
|
| 1304 |
+
Args:
|
| 1305 |
+
token_ids_0 (`list[int]`):
|
| 1306 |
+
List of IDs.
|
| 1307 |
+
token_ids_1 (`list[int]`, *optional*):
|
| 1308 |
+
Optional second list of IDs for sequence pairs.
|
| 1309 |
+
|
| 1310 |
+
Returns:
|
| 1311 |
+
`list[int]`: Token type IDs according to the configured pattern.
|
| 1312 |
+
|
| 1313 |
+
Examples:
|
| 1314 |
+
```python
|
| 1315 |
+
# All zeros pattern (default, used by RoBERTa, BART, etc.)
|
| 1316 |
+
tokenizer.token_type_ids_pattern = "all_zeros"
|
| 1317 |
+
# Returns: [0, 0, 0, ...] for both sequences
|
| 1318 |
+
|
| 1319 |
+
# BERT-style pattern (first sequence gets 0s, second gets 1s)
|
| 1320 |
+
tokenizer.token_type_ids_pattern = "bert_style"
|
| 1321 |
+
# Returns: [0, 0, 0, ..., 1, 1, 1, ...] for sequence pairs
|
| 1322 |
+
```
|
| 1323 |
+
"""
|
| 1324 |
+
# Calculate lengths - account for special tokens if configured
|
| 1325 |
+
if self.token_type_ids_include_special_tokens:
|
| 1326 |
+
# Build the full sequence to get accurate length
|
| 1327 |
+
if token_ids_1 is None:
|
| 1328 |
+
sequence = self.build_inputs_with_special_tokens(token_ids_0)
|
| 1329 |
+
seq0_len = len(sequence)
|
| 1330 |
+
seq1_len = 0
|
| 1331 |
+
else:
|
| 1332 |
+
full_sequence = self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
|
| 1333 |
+
# Approximate split - this works for most tokenizers
|
| 1334 |
+
# For more complex cases, subclasses should still override
|
| 1335 |
+
seq0_with_special = self.build_inputs_with_special_tokens(token_ids_0)
|
| 1336 |
+
seq0_len = len(seq0_with_special)
|
| 1337 |
+
seq1_len = len(full_sequence) - seq0_len
|
| 1338 |
+
else:
|
| 1339 |
+
# Use raw token lengths
|
| 1340 |
+
seq0_len = len(token_ids_0)
|
| 1341 |
+
seq1_len = len(token_ids_1) if token_ids_1 is not None else 0
|
| 1342 |
+
|
| 1343 |
+
# Build token type IDs based on pattern
|
| 1344 |
+
if self.special_tokens_pattern == "prefix_suffix":
|
| 1345 |
+
total_len = len(getattr(self, "prefix_tokens", [])) + len(token_ids_0)
|
| 1346 |
+
if token_ids_1 is not None:
|
| 1347 |
+
total_len += len(token_ids_1)
|
| 1348 |
+
total_len += len(getattr(self, "suffix_tokens", []))
|
| 1349 |
+
return [0] * total_len
|
| 1350 |
+
|
| 1351 |
+
if self.token_type_ids_pattern == "bert_style" and token_ids_1 is not None:
|
| 1352 |
+
# BERT-style: first sequence gets 0s, second sequence gets 1s
|
| 1353 |
+
return [0] * seq0_len + [1] * seq1_len
|
| 1354 |
+
else:
|
| 1355 |
+
# All zeros pattern (default): everything gets 0s
|
| 1356 |
+
return [0] * (seq0_len + seq1_len)
|
| 1357 |
+
|
| 1358 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str, ...]:
|
| 1359 |
+
"""
|
| 1360 |
+
Default implementation for common vocabulary saving patterns.
|
| 1361 |
+
Saves self.encoder/self.vocab as JSON, optionally with self.bpe_ranks as merges.
|
| 1362 |
+
Returns empty tuple if no vocabulary exists.
|
| 1363 |
+
|
| 1364 |
+
Override this method if your tokenizer needs custom saving logic (e.g., SentencePiece models,
|
| 1365 |
+
multiple vocabulary files, or special file formats).
|
| 1366 |
+
|
| 1367 |
+
Args:
|
| 1368 |
+
save_directory (`str`):
|
| 1369 |
+
The directory in which to save the vocabulary.
|
| 1370 |
+
filename_prefix (`str`, *optional*):
|
| 1371 |
+
An optional prefix to add to the named of the saved files.
|
| 1372 |
+
|
| 1373 |
+
Returns:
|
| 1374 |
+
`tuple[str, ...]`: Paths to the files saved, or empty tuple if no files saved.
|
| 1375 |
+
"""
|
| 1376 |
+
import json
|
| 1377 |
+
import os
|
| 1378 |
+
|
| 1379 |
+
vocab_attr = getattr(self, "encoder", None) or getattr(self, "vocab", None)
|
| 1380 |
+
if vocab_attr is None:
|
| 1381 |
+
return ()
|
| 1382 |
+
|
| 1383 |
+
if not os.path.isdir(save_directory):
|
| 1384 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 1385 |
+
return ()
|
| 1386 |
+
|
| 1387 |
+
vocab_files_names = getattr(self, "vocab_files_names", {})
|
| 1388 |
+
prefix = f"{filename_prefix}-" if filename_prefix else ""
|
| 1389 |
+
|
| 1390 |
+
# Save vocabulary
|
| 1391 |
+
vocab_file = os.path.join(save_directory, prefix + vocab_files_names.get("vocab_file", "vocab.json"))
|
| 1392 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
| 1393 |
+
f.write(json.dumps(vocab_attr, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
| 1394 |
+
|
| 1395 |
+
# Save BPE merges if present
|
| 1396 |
+
bpe_ranks = getattr(self, "bpe_ranks", None)
|
| 1397 |
+
if bpe_ranks is None:
|
| 1398 |
+
return (vocab_file,)
|
| 1399 |
+
|
| 1400 |
+
merge_file = os.path.join(save_directory, prefix + vocab_files_names.get("merges_file", "merges.txt"))
|
| 1401 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
| 1402 |
+
if getattr(self, "add_bpe_version_header", False):
|
| 1403 |
+
writer.write("#version: 0.2\n")
|
| 1404 |
+
|
| 1405 |
+
index = 0
|
| 1406 |
+
for bpe_tokens, token_index in sorted(bpe_ranks.items(), key=lambda kv: kv[1]):
|
| 1407 |
+
if index != token_index:
|
| 1408 |
+
logger.warning(
|
| 1409 |
+
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
| 1410 |
+
" Please check that the tokenizer is not corrupted!"
|
| 1411 |
+
)
|
| 1412 |
+
index = token_index
|
| 1413 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
| 1414 |
+
index += 1
|
| 1415 |
+
|
| 1416 |
+
return (vocab_file, merge_file)
|
| 1417 |
+
|
| 1418 |
+
|
| 1419 |
+
# Backward compatibility alias
|
| 1420 |
+
PreTrainedTokenizer = PythonBackend
|
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/video_utils.py
ADDED
|
@@ -0,0 +1,893 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import warnings
|
| 17 |
+
from collections.abc import Callable, Iterable, Mapping
|
| 18 |
+
from contextlib import redirect_stdout
|
| 19 |
+
from dataclasses import dataclass, fields
|
| 20 |
+
from io import BytesIO
|
| 21 |
+
from typing import NewType, Union
|
| 22 |
+
from urllib.parse import urlparse
|
| 23 |
+
|
| 24 |
+
import httpx
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
from .image_transforms import PaddingMode, to_channel_dimension_format
|
| 28 |
+
from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
|
| 29 |
+
from .utils import (
|
| 30 |
+
is_av_available,
|
| 31 |
+
is_cv2_available,
|
| 32 |
+
is_decord_available,
|
| 33 |
+
is_numpy_array,
|
| 34 |
+
is_torch_available,
|
| 35 |
+
is_torch_tensor,
|
| 36 |
+
is_torchcodec_available,
|
| 37 |
+
is_torchvision_available,
|
| 38 |
+
is_vision_available,
|
| 39 |
+
is_yt_dlp_available,
|
| 40 |
+
logging,
|
| 41 |
+
requires_backends,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if is_vision_available():
|
| 46 |
+
import PIL.Image
|
| 47 |
+
|
| 48 |
+
if is_torchvision_available():
|
| 49 |
+
from torchvision import io as torchvision_io
|
| 50 |
+
|
| 51 |
+
if is_torch_available():
|
| 52 |
+
import torch
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
logger = logging.get_logger(__name__)
|
| 56 |
+
|
| 57 |
+
URL = NewType("URL", str)
|
| 58 |
+
Path = NewType("Path", str)
|
| 59 |
+
|
| 60 |
+
VideoInput = Union[
|
| 61 |
+
list["PIL.Image.Image"],
|
| 62 |
+
np.ndarray,
|
| 63 |
+
"torch.Tensor",
|
| 64 |
+
list[np.ndarray],
|
| 65 |
+
list["torch.Tensor"],
|
| 66 |
+
list[list["PIL.Image.Image"]],
|
| 67 |
+
list[list[np.ndarray]],
|
| 68 |
+
list[list["torch.Tensor"]],
|
| 69 |
+
URL,
|
| 70 |
+
list[URL],
|
| 71 |
+
list[list[URL]],
|
| 72 |
+
Path,
|
| 73 |
+
list[Path],
|
| 74 |
+
list[list[Path]],
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass
|
| 79 |
+
class VideoMetadata(Mapping):
|
| 80 |
+
total_num_frames: int
|
| 81 |
+
fps: float | None = None
|
| 82 |
+
width: int | None = None
|
| 83 |
+
height: int | None = None
|
| 84 |
+
duration: float | None = None
|
| 85 |
+
video_backend: str | None = None
|
| 86 |
+
frames_indices: list[int] | None = None
|
| 87 |
+
|
| 88 |
+
def __iter__(self):
|
| 89 |
+
return (f.name for f in fields(self))
|
| 90 |
+
|
| 91 |
+
def __len__(self):
|
| 92 |
+
return len(fields(self))
|
| 93 |
+
|
| 94 |
+
def __getitem__(self, item):
|
| 95 |
+
return getattr(self, item)
|
| 96 |
+
|
| 97 |
+
def __setitem__(self, key, value):
|
| 98 |
+
return setattr(self, key, value)
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def timestamps(self) -> list[float]:
|
| 102 |
+
"Timestamps of the sampled frames in seconds."
|
| 103 |
+
if self.fps is None or self.frames_indices is None:
|
| 104 |
+
raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
|
| 105 |
+
return [frame_idx / self.fps for frame_idx in self.frames_indices]
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def sampled_fps(self) -> float:
|
| 109 |
+
"FPS of the sampled video."
|
| 110 |
+
if self.frames_indices is None or self.total_num_frames is None or self.fps is None:
|
| 111 |
+
return self.fps or 24
|
| 112 |
+
return len(self.frames_indices) / self.total_num_frames * self.fps
|
| 113 |
+
|
| 114 |
+
def update(self, dictionary):
|
| 115 |
+
for key, value in dictionary.items():
|
| 116 |
+
if hasattr(self, key):
|
| 117 |
+
setattr(self, key, value)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
VideoMetadataType = VideoMetadata | dict | list[dict | VideoMetadata] | list[list[dict | VideoMetadata]]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def is_valid_video_frame(frame):
|
| 124 |
+
return isinstance(frame, PIL.Image.Image) or (
|
| 125 |
+
(is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def is_valid_video(video):
|
| 130 |
+
if not isinstance(video, (list, tuple)):
|
| 131 |
+
return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
|
| 132 |
+
return video and all(is_valid_video_frame(frame) for frame in video)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def valid_videos(videos):
|
| 136 |
+
# If we have a list of videos, it could be either one video as list of frames or a batch
|
| 137 |
+
if isinstance(videos, (list, tuple)):
|
| 138 |
+
for video_or_frame in videos:
|
| 139 |
+
if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)):
|
| 140 |
+
return False
|
| 141 |
+
# If not a list, then we have a single 4D video or 5D batched tensor
|
| 142 |
+
elif not is_valid_video(videos) or videos.ndim == 5:
|
| 143 |
+
return False
|
| 144 |
+
return True
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def is_batched_video(videos):
|
| 148 |
+
if isinstance(videos, (list, tuple)):
|
| 149 |
+
return is_valid_video(videos[0])
|
| 150 |
+
elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
|
| 151 |
+
return True
|
| 152 |
+
return False
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def is_scaled_video(video: np.ndarray) -> bool:
|
| 156 |
+
"""
|
| 157 |
+
Checks to see whether the pixel values have already been rescaled to [0, 1].
|
| 158 |
+
"""
|
| 159 |
+
# It's possible the video has pixel values in [0, 255] but is of floating type
|
| 160 |
+
return np.min(video) >= 0 and np.max(video) <= 1
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union[np.ndarray, "torch.Tensor"]]:
|
| 164 |
+
"""
|
| 165 |
+
Given a batch of videos, converts each video to a 4D array. If video is already in array type,
|
| 166 |
+
it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
videos (`VideoInput`):
|
| 170 |
+
Video inputs to turn into a list of videos.
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])):
|
| 174 |
+
return videos
|
| 175 |
+
|
| 176 |
+
video_converted = []
|
| 177 |
+
for video in videos:
|
| 178 |
+
video = [np.array(frame) for frame in video]
|
| 179 |
+
video = np.stack(video)
|
| 180 |
+
video_converted.append(video)
|
| 181 |
+
return video_converted
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def make_batched_videos(videos) -> list[Union[np.ndarray, "torch.Tensor", "URL", "Path"]]:
|
| 185 |
+
"""
|
| 186 |
+
Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
|
| 187 |
+
If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image`
|
| 188 |
+
frames are converted to 4D arrays.
|
| 189 |
+
|
| 190 |
+
We assume that all inputs in the list are in the same format, based on the type of the first element.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
videos (`VideoInput`):
|
| 194 |
+
Video inputs to turn into a list of videos.
|
| 195 |
+
"""
|
| 196 |
+
# Early exit for deeply nested list of image frame paths. We shouldn't flatten them
|
| 197 |
+
try:
|
| 198 |
+
if isinstance(videos[0][0], list) and isinstance(videos[0][0][0], str):
|
| 199 |
+
return [image_paths for sublist in videos for image_paths in sublist]
|
| 200 |
+
except (IndexError, TypeError):
|
| 201 |
+
pass
|
| 202 |
+
|
| 203 |
+
if is_batched_video(videos):
|
| 204 |
+
return convert_pil_frames_to_video(list(videos))
|
| 205 |
+
elif isinstance(videos, str) or is_valid_video(videos):
|
| 206 |
+
return convert_pil_frames_to_video([videos])
|
| 207 |
+
# only one frame passed, thus we unsqueeze time dim
|
| 208 |
+
elif is_valid_image(videos):
|
| 209 |
+
if isinstance(videos, PIL.Image.Image):
|
| 210 |
+
videos = np.array(videos)
|
| 211 |
+
return [videos[None, ...]]
|
| 212 |
+
elif not isinstance(videos, list):
|
| 213 |
+
raise ValueError(
|
| 214 |
+
f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
|
| 215 |
+
f" type {type(videos)}."
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Recursively flatten any nested structure
|
| 219 |
+
flat_videos_list = []
|
| 220 |
+
for item in videos:
|
| 221 |
+
if isinstance(item, str) or is_valid_video(item):
|
| 222 |
+
flat_videos_list.append(item)
|
| 223 |
+
elif isinstance(item, list) and item:
|
| 224 |
+
flat_videos_list.extend(make_batched_videos(item))
|
| 225 |
+
|
| 226 |
+
flat_videos_list = convert_pil_frames_to_video(flat_videos_list)
|
| 227 |
+
return flat_videos_list
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def make_batched_metadata(videos: VideoInput, video_metadata: VideoMetadataType) -> list[VideoMetadata]:
|
| 231 |
+
if video_metadata is None:
|
| 232 |
+
# Create default metadata and fill attributes we can infer from given video
|
| 233 |
+
video_metadata = [
|
| 234 |
+
{
|
| 235 |
+
"total_num_frames": len(video),
|
| 236 |
+
"fps": None,
|
| 237 |
+
"duration": None,
|
| 238 |
+
"frames_indices": list(range(len(video))),
|
| 239 |
+
"height": get_video_size(video)[0] if is_valid_video(video) else None,
|
| 240 |
+
"width": get_video_size(video)[1] if is_valid_video(video) else None,
|
| 241 |
+
}
|
| 242 |
+
for video in videos
|
| 243 |
+
]
|
| 244 |
+
|
| 245 |
+
if isinstance(video_metadata, list):
|
| 246 |
+
# Flatten if nested list
|
| 247 |
+
if isinstance(video_metadata[0], list):
|
| 248 |
+
video_metadata = [
|
| 249 |
+
VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list
|
| 250 |
+
]
|
| 251 |
+
# Simply wrap in VideoMetadata if simple dict
|
| 252 |
+
elif isinstance(video_metadata[0], dict):
|
| 253 |
+
video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata]
|
| 254 |
+
else:
|
| 255 |
+
# Create a batched list from single object
|
| 256 |
+
video_metadata = [VideoMetadata(**video_metadata)]
|
| 257 |
+
return video_metadata
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def get_video_size(video: np.ndarray, channel_dim: ChannelDimension | None = None) -> tuple[int, int]:
|
| 261 |
+
"""
|
| 262 |
+
Returns the (height, width) dimensions of the video.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
video (`np.ndarray`):
|
| 266 |
+
The video to get the dimensions of.
|
| 267 |
+
channel_dim (`ChannelDimension`, *optional*):
|
| 268 |
+
Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video.
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
A tuple of the video's height and width.
|
| 272 |
+
"""
|
| 273 |
+
if channel_dim is None:
|
| 274 |
+
channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4))
|
| 275 |
+
|
| 276 |
+
if channel_dim == ChannelDimension.FIRST:
|
| 277 |
+
return video.shape[-2], video.shape[-1]
|
| 278 |
+
elif channel_dim == ChannelDimension.LAST:
|
| 279 |
+
return video.shape[-3], video.shape[-2]
|
| 280 |
+
else:
|
| 281 |
+
raise ValueError(f"Unsupported data format: {channel_dim}")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def get_uniform_frame_indices(total_num_frames: int, num_frames: int | None = None):
|
| 285 |
+
"""
|
| 286 |
+
Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
|
| 287 |
+
when loading a video.
|
| 288 |
+
|
| 289 |
+
Args:
|
| 290 |
+
total_num_frames (`int`):
|
| 291 |
+
Total number of frames that a video has.
|
| 292 |
+
num_frames (`int`, *optional*):
|
| 293 |
+
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
np.ndarray: np array of frame indices that will be sampled.
|
| 297 |
+
"""
|
| 298 |
+
if num_frames is not None:
|
| 299 |
+
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
|
| 300 |
+
else:
|
| 301 |
+
indices = np.arange(0, total_num_frames).astype(int)
|
| 302 |
+
return indices
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
|
| 306 |
+
"""
|
| 307 |
+
A default sampling function that replicates the logic used in get_uniform_frame_indices,
|
| 308 |
+
while optionally handling `fps` if `num_frames` is not provided.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
metadata (`VideoMetadata`):
|
| 312 |
+
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
|
| 313 |
+
num_frames (`int`, *optional*):
|
| 314 |
+
Number of frames to sample uniformly.
|
| 315 |
+
fps (`int` or `float`, *optional*):
|
| 316 |
+
Desired frames per second. Takes priority over num_frames if both are provided.
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
`np.ndarray`: Array of frame indices to sample.
|
| 320 |
+
"""
|
| 321 |
+
total_num_frames = metadata.total_num_frames
|
| 322 |
+
video_fps = metadata.fps
|
| 323 |
+
|
| 324 |
+
# If num_frames is not given but fps is, calculate num_frames from fps
|
| 325 |
+
if num_frames is None and fps is not None:
|
| 326 |
+
num_frames = int(total_num_frames / video_fps * fps)
|
| 327 |
+
if num_frames > total_num_frames:
|
| 328 |
+
raise ValueError(
|
| 329 |
+
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
|
| 330 |
+
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
if num_frames is not None:
|
| 334 |
+
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
|
| 335 |
+
else:
|
| 336 |
+
indices = np.arange(0, total_num_frames, dtype=int)
|
| 337 |
+
return indices
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def read_video_opencv(
|
| 341 |
+
video_path: Union["URL", "Path"],
|
| 342 |
+
sample_indices_fn: Callable,
|
| 343 |
+
**kwargs,
|
| 344 |
+
) -> tuple[np.ndarray, VideoMetadata]:
|
| 345 |
+
"""
|
| 346 |
+
Decode a video using the OpenCV backend.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
video_path (`str`):
|
| 350 |
+
Path to the video file.
|
| 351 |
+
sample_indices_fn (`Callable`):
|
| 352 |
+
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
| 353 |
+
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
| 354 |
+
If not provided, simple uniform sampling with fps is performed.
|
| 355 |
+
Example:
|
| 356 |
+
def sample_indices_fn(metadata, **kwargs):
|
| 357 |
+
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
| 358 |
+
|
| 359 |
+
Returns:
|
| 360 |
+
tuple[`np.ndarray`, `VideoMetadata`]: A tuple containing:
|
| 361 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 362 |
+
- `VideoMetadata` object.
|
| 363 |
+
"""
|
| 364 |
+
# Lazy import cv2
|
| 365 |
+
requires_backends(read_video_opencv, ["cv2"])
|
| 366 |
+
import cv2
|
| 367 |
+
|
| 368 |
+
video = cv2.VideoCapture(video_path)
|
| 369 |
+
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 370 |
+
video_fps = video.get(cv2.CAP_PROP_FPS)
|
| 371 |
+
duration = total_num_frames / video_fps if video_fps else 0
|
| 372 |
+
metadata = VideoMetadata(
|
| 373 |
+
total_num_frames=int(total_num_frames),
|
| 374 |
+
fps=float(video_fps),
|
| 375 |
+
duration=float(duration),
|
| 376 |
+
video_backend="opencv",
|
| 377 |
+
height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
| 378 |
+
width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
| 382 |
+
index = 0
|
| 383 |
+
frames = []
|
| 384 |
+
while video.isOpened():
|
| 385 |
+
success, frame = video.read()
|
| 386 |
+
if not success:
|
| 387 |
+
break
|
| 388 |
+
if index in indices:
|
| 389 |
+
height, width, channel = frame.shape
|
| 390 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 391 |
+
frames.append(frame[0:height, 0:width, 0:channel])
|
| 392 |
+
if success:
|
| 393 |
+
index += 1
|
| 394 |
+
if index >= total_num_frames:
|
| 395 |
+
break
|
| 396 |
+
|
| 397 |
+
video.release()
|
| 398 |
+
metadata.frames_indices = indices
|
| 399 |
+
return np.stack(frames), metadata
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def read_video_decord(
|
| 403 |
+
video_path: Union["URL", "Path"],
|
| 404 |
+
sample_indices_fn: Callable,
|
| 405 |
+
**kwargs,
|
| 406 |
+
):
|
| 407 |
+
"""
|
| 408 |
+
Decode a video using the Decord backend.
|
| 409 |
+
|
| 410 |
+
Args:
|
| 411 |
+
video_path (`str`):
|
| 412 |
+
Path to the video file.
|
| 413 |
+
sample_indices_fn (`Callable`):
|
| 414 |
+
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
| 415 |
+
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
| 416 |
+
If not provided, simple uniform sampling with fps is performed.
|
| 417 |
+
Example:
|
| 418 |
+
def sample_indices_fn(metadata, **kwargs):
|
| 419 |
+
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
| 420 |
+
|
| 421 |
+
Returns:
|
| 422 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 423 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 424 |
+
- `VideoMetadata` object.
|
| 425 |
+
"""
|
| 426 |
+
# Lazy import from decord
|
| 427 |
+
requires_backends(read_video_decord, ["decord"])
|
| 428 |
+
from decord import VideoReader, cpu
|
| 429 |
+
|
| 430 |
+
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
| 431 |
+
video_fps = vr.get_avg_fps()
|
| 432 |
+
total_num_frames = len(vr)
|
| 433 |
+
duration = total_num_frames / video_fps if video_fps else 0
|
| 434 |
+
metadata = VideoMetadata(
|
| 435 |
+
total_num_frames=int(total_num_frames),
|
| 436 |
+
fps=float(video_fps),
|
| 437 |
+
duration=float(duration),
|
| 438 |
+
video_backend="decord",
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
| 442 |
+
video = vr.get_batch(indices).asnumpy()
|
| 443 |
+
|
| 444 |
+
metadata.update(
|
| 445 |
+
{
|
| 446 |
+
"frames_indices": indices,
|
| 447 |
+
"height": video.shape[1],
|
| 448 |
+
"width": video.shape[2],
|
| 449 |
+
}
|
| 450 |
+
)
|
| 451 |
+
return video, metadata
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def read_video_pyav(
|
| 455 |
+
video_path: Union["URL", "Path"],
|
| 456 |
+
sample_indices_fn: Callable,
|
| 457 |
+
**kwargs,
|
| 458 |
+
):
|
| 459 |
+
"""
|
| 460 |
+
Decode the video with PyAV decoder.
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
video_path (`str`):
|
| 464 |
+
Path to the video file.
|
| 465 |
+
sample_indices_fn (`Callable`, *optional*):
|
| 466 |
+
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
| 467 |
+
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
| 468 |
+
If not provided, simple uniform sampling with fps is performed.
|
| 469 |
+
Example:
|
| 470 |
+
def sample_indices_fn(metadata, **kwargs):
|
| 471 |
+
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
| 475 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 476 |
+
- `VideoMetadata` object.
|
| 477 |
+
"""
|
| 478 |
+
# Lazy import av
|
| 479 |
+
requires_backends(read_video_pyav, ["av"])
|
| 480 |
+
import av
|
| 481 |
+
|
| 482 |
+
container = av.open(video_path)
|
| 483 |
+
total_num_frames = container.streams.video[0].frames
|
| 484 |
+
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
| 485 |
+
duration = total_num_frames / video_fps if video_fps else 0
|
| 486 |
+
metadata = VideoMetadata(
|
| 487 |
+
total_num_frames=int(total_num_frames),
|
| 488 |
+
fps=float(video_fps),
|
| 489 |
+
duration=float(duration),
|
| 490 |
+
video_backend="pyav",
|
| 491 |
+
height=container.streams.video[0].height,
|
| 492 |
+
width=container.streams.video[0].width,
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
| 496 |
+
frames = []
|
| 497 |
+
container.seek(0)
|
| 498 |
+
end_index = indices[-1]
|
| 499 |
+
for i, frame in enumerate(container.decode(video=0)):
|
| 500 |
+
if i > end_index:
|
| 501 |
+
break
|
| 502 |
+
if i >= 0 and i in indices:
|
| 503 |
+
frames.append(frame)
|
| 504 |
+
|
| 505 |
+
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
| 506 |
+
metadata.frames_indices = indices
|
| 507 |
+
return video, metadata
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def read_video_torchvision(
|
| 511 |
+
video_path: Union["URL", "Path"],
|
| 512 |
+
sample_indices_fn: Callable,
|
| 513 |
+
**kwargs,
|
| 514 |
+
):
|
| 515 |
+
"""
|
| 516 |
+
Decode the video with torchvision decoder.
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
video_path (`str`):
|
| 520 |
+
Path to the video file.
|
| 521 |
+
sample_indices_fn (`Callable`, *optional*):
|
| 522 |
+
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
| 523 |
+
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
| 524 |
+
If not provided, simple uniform sampling with fps is performed.
|
| 525 |
+
Example:
|
| 526 |
+
def sample_indices_fn(metadata, **kwargs):
|
| 527 |
+
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
| 528 |
+
|
| 529 |
+
Returns:
|
| 530 |
+
tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
|
| 531 |
+
- Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 532 |
+
- `VideoMetadata` object.
|
| 533 |
+
"""
|
| 534 |
+
warnings.warn(
|
| 535 |
+
"Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
|
| 536 |
+
"Please use `torchcodec` instead."
|
| 537 |
+
)
|
| 538 |
+
video, _, info = torchvision_io.read_video(
|
| 539 |
+
video_path,
|
| 540 |
+
start_pts=0.0,
|
| 541 |
+
end_pts=None,
|
| 542 |
+
pts_unit="sec",
|
| 543 |
+
output_format="TCHW",
|
| 544 |
+
)
|
| 545 |
+
video_fps = info["video_fps"]
|
| 546 |
+
total_num_frames = video.size(0)
|
| 547 |
+
duration = total_num_frames / video_fps if video_fps else 0
|
| 548 |
+
metadata = VideoMetadata(
|
| 549 |
+
total_num_frames=int(total_num_frames),
|
| 550 |
+
fps=float(video_fps),
|
| 551 |
+
duration=float(duration),
|
| 552 |
+
video_backend="torchvision",
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
| 556 |
+
video = video[indices].contiguous()
|
| 557 |
+
metadata.update(
|
| 558 |
+
{
|
| 559 |
+
"frames_indices": indices,
|
| 560 |
+
"height": video.shape[2],
|
| 561 |
+
"width": video.shape[3],
|
| 562 |
+
}
|
| 563 |
+
)
|
| 564 |
+
return video, metadata
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def read_video_torchcodec(
|
| 568 |
+
video_path: Union["URL", "Path"],
|
| 569 |
+
sample_indices_fn: Callable,
|
| 570 |
+
**kwargs,
|
| 571 |
+
):
|
| 572 |
+
"""
|
| 573 |
+
Decode the video with torchcodec decoder.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
video_path (`str`):
|
| 577 |
+
Path to the video file.
|
| 578 |
+
sample_indices_fn (`Callable`):
|
| 579 |
+
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
| 580 |
+
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
| 581 |
+
If not provided, simple uniform sampling with fps is performed.
|
| 582 |
+
Example:
|
| 583 |
+
def sample_indices_fn(metadata, **kwargs):
|
| 584 |
+
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
|
| 588 |
+
- Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 589 |
+
- `VideoMetadata` object.
|
| 590 |
+
"""
|
| 591 |
+
# Lazy import torchcodec
|
| 592 |
+
requires_backends(read_video_torchcodec, ["torchcodec"])
|
| 593 |
+
from torchcodec.decoders import VideoDecoder
|
| 594 |
+
|
| 595 |
+
# VideoDecoder expects a string for device, default to "cpu" if None
|
| 596 |
+
|
| 597 |
+
decoder = VideoDecoder(
|
| 598 |
+
video_path,
|
| 599 |
+
# Interestingly `exact` mode takes less than approximate when we load the whole video
|
| 600 |
+
seek_mode="exact",
|
| 601 |
+
# Allow FFmpeg decide on the number of threads for efficiency
|
| 602 |
+
num_ffmpeg_threads=0,
|
| 603 |
+
device=kwargs.get("device", "cpu"),
|
| 604 |
+
)
|
| 605 |
+
total_num_frames = decoder.metadata.num_frames
|
| 606 |
+
video_fps = decoder.metadata.average_fps
|
| 607 |
+
metadata = VideoMetadata(
|
| 608 |
+
total_num_frames=total_num_frames,
|
| 609 |
+
fps=video_fps,
|
| 610 |
+
duration=decoder.metadata.duration_seconds,
|
| 611 |
+
video_backend="torchcodec",
|
| 612 |
+
height=decoder.metadata.height,
|
| 613 |
+
width=decoder.metadata.width,
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
| 617 |
+
video = decoder.get_frames_at(indices=indices).data.contiguous()
|
| 618 |
+
metadata.frames_indices = indices
|
| 619 |
+
return video, metadata
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
VIDEO_DECODERS = {
|
| 623 |
+
"decord": read_video_decord,
|
| 624 |
+
"opencv": read_video_opencv,
|
| 625 |
+
"pyav": read_video_pyav,
|
| 626 |
+
"torchvision": read_video_torchvision,
|
| 627 |
+
"torchcodec": read_video_torchcodec,
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def load_video(
|
| 632 |
+
video: VideoInput,
|
| 633 |
+
num_frames: int | None = None,
|
| 634 |
+
fps: int | float | None = None,
|
| 635 |
+
backend: str = "pyav",
|
| 636 |
+
sample_indices_fn: Callable | None = None,
|
| 637 |
+
**kwargs,
|
| 638 |
+
) -> np.ndarray:
|
| 639 |
+
"""
|
| 640 |
+
Loads `video` to a numpy array.
|
| 641 |
+
|
| 642 |
+
Args:
|
| 643 |
+
video (`VideoInput`):
|
| 644 |
+
The video to convert to the numpy array format. Can be a link to video or local path.
|
| 645 |
+
num_frames (`int`, *optional*):
|
| 646 |
+
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
| 647 |
+
fps (`int` or `float`, *optional*):
|
| 648 |
+
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
| 649 |
+
If not specified and `num_frames==None`, all frames are sampled.
|
| 650 |
+
backend (`str`, *optional*, defaults to `"pyav"`):
|
| 651 |
+
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
|
| 652 |
+
sample_indices_fn (`Callable`, *optional*):
|
| 653 |
+
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
| 654 |
+
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
| 655 |
+
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
|
| 656 |
+
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
|
| 657 |
+
indices at which the video should be sampled. For example:
|
| 658 |
+
|
| 659 |
+
Example:
|
| 660 |
+
def sample_indices_fn(metadata, **kwargs):
|
| 661 |
+
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
| 662 |
+
|
| 663 |
+
Returns:
|
| 664 |
+
tuple[`np.ndarray`, Dict]: A tuple containing:
|
| 665 |
+
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
| 666 |
+
- Metadata dictionary.
|
| 667 |
+
"""
|
| 668 |
+
|
| 669 |
+
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
|
| 670 |
+
if fps is not None and num_frames is not None and sample_indices_fn is None:
|
| 671 |
+
raise ValueError(
|
| 672 |
+
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
# If user didn't pass a sampling function, create one on the fly with default logic
|
| 676 |
+
if sample_indices_fn is None:
|
| 677 |
+
|
| 678 |
+
def sample_indices_fn_func(metadata, **fn_kwargs):
|
| 679 |
+
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
|
| 680 |
+
|
| 681 |
+
sample_indices_fn = sample_indices_fn_func
|
| 682 |
+
|
| 683 |
+
# Early exit if provided an array or `PIL` frames
|
| 684 |
+
if not isinstance(video, str):
|
| 685 |
+
metadata = [None] * len(video)
|
| 686 |
+
return video, metadata
|
| 687 |
+
|
| 688 |
+
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
| 689 |
+
if not is_yt_dlp_available():
|
| 690 |
+
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
| 691 |
+
# Lazy import from yt_dlp
|
| 692 |
+
requires_backends(load_video, ["yt_dlp"])
|
| 693 |
+
from yt_dlp import YoutubeDL
|
| 694 |
+
|
| 695 |
+
buffer = BytesIO()
|
| 696 |
+
with redirect_stdout(buffer), YoutubeDL() as f:
|
| 697 |
+
f.download([video])
|
| 698 |
+
bytes_obj = buffer.getvalue()
|
| 699 |
+
file_obj = BytesIO(bytes_obj)
|
| 700 |
+
elif video.startswith("http://") or video.startswith("https://"):
|
| 701 |
+
file_obj = BytesIO(httpx.get(video, follow_redirects=True).content)
|
| 702 |
+
elif os.path.isfile(video):
|
| 703 |
+
file_obj = video
|
| 704 |
+
else:
|
| 705 |
+
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
| 706 |
+
|
| 707 |
+
# can also load with decord, but not cv2/torchvision
|
| 708 |
+
# both will fail in case of url links
|
| 709 |
+
video_is_url = video.startswith("http://") or video.startswith("https://")
|
| 710 |
+
if video_is_url and backend == "opencv":
|
| 711 |
+
raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
|
| 712 |
+
|
| 713 |
+
if (
|
| 714 |
+
(not is_decord_available() and backend == "decord")
|
| 715 |
+
or (not is_av_available() and backend == "pyav")
|
| 716 |
+
or (not is_cv2_available() and backend == "opencv")
|
| 717 |
+
or (not is_torchvision_available() and backend == "torchvision")
|
| 718 |
+
or (not is_torchcodec_available() and backend == "torchcodec")
|
| 719 |
+
):
|
| 720 |
+
raise ImportError(
|
| 721 |
+
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
| 722 |
+
f"Make sure to install {backend} before loading the video."
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
video_decoder = VIDEO_DECODERS[backend]
|
| 726 |
+
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
|
| 727 |
+
return video, metadata
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
def convert_to_rgb(
|
| 731 |
+
video: np.ndarray,
|
| 732 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 733 |
+
) -> np.ndarray:
|
| 734 |
+
"""
|
| 735 |
+
Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.
|
| 736 |
+
|
| 737 |
+
Args:
|
| 738 |
+
video (`np.ndarray`):
|
| 739 |
+
The video to convert.
|
| 740 |
+
input_data_format (`ChannelDimension`, *optional*):
|
| 741 |
+
The channel dimension format of the input video. If unset, will use the inferred format from the input.
|
| 742 |
+
"""
|
| 743 |
+
if not isinstance(video, np.ndarray):
|
| 744 |
+
raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")
|
| 745 |
+
|
| 746 |
+
# np.array usually comes with ChannelDimension.LAST so let's convert it
|
| 747 |
+
if input_data_format is None:
|
| 748 |
+
input_data_format = infer_channel_dimension_format(video)
|
| 749 |
+
video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
|
| 750 |
+
|
| 751 |
+
# 3 channels for RGB already
|
| 752 |
+
if video.shape[-3] == 3:
|
| 753 |
+
return video
|
| 754 |
+
|
| 755 |
+
# Grayscale video so we repeat it 3 times for each channel
|
| 756 |
+
if video.shape[-3] == 1:
|
| 757 |
+
return video.repeat(3, -3)
|
| 758 |
+
|
| 759 |
+
if not (video[..., 3, :, :] < 255).any():
|
| 760 |
+
return video
|
| 761 |
+
|
| 762 |
+
# There is a transparency layer, blend it with a white background.
|
| 763 |
+
# Calculate the alpha proportion for blending.
|
| 764 |
+
alpha = video[..., 3, :, :] / 255.0
|
| 765 |
+
video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., 3, :, :]
|
| 766 |
+
return video
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
def pad(
|
| 770 |
+
video: np.ndarray,
|
| 771 |
+
padding: int | tuple[int, int] | Iterable[tuple[int, int]],
|
| 772 |
+
mode: PaddingMode = PaddingMode.CONSTANT,
|
| 773 |
+
constant_values: float | Iterable[float] = 0.0,
|
| 774 |
+
data_format: str | ChannelDimension | None = None,
|
| 775 |
+
input_data_format: str | ChannelDimension | None = None,
|
| 776 |
+
) -> np.ndarray:
|
| 777 |
+
"""
|
| 778 |
+
Pads the `video` with the specified (height, width) `padding` and `mode`.
|
| 779 |
+
|
| 780 |
+
Args:
|
| 781 |
+
video (`np.ndarray`):
|
| 782 |
+
The video to pad.
|
| 783 |
+
padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
|
| 784 |
+
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
| 785 |
+
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
| 786 |
+
- `((before, after),)` yields same before and after pad for height and width.
|
| 787 |
+
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
| 788 |
+
mode (`PaddingMode`):
|
| 789 |
+
The padding mode to use. Can be one of:
|
| 790 |
+
- `"constant"`: pads with a constant value.
|
| 791 |
+
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
| 792 |
+
vector along each axis.
|
| 793 |
+
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
| 794 |
+
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
| 795 |
+
constant_values (`float` or `Iterable[float]`, *optional*):
|
| 796 |
+
The value to use for the padding if `mode` is `"constant"`.
|
| 797 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 798 |
+
The channel dimension format for the output video. Can be one of:
|
| 799 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
|
| 800 |
+
- `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
|
| 801 |
+
If unset, will use same as the input video.
|
| 802 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 803 |
+
The channel dimension format for the input video. Can be one of:
|
| 804 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
|
| 805 |
+
- `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
|
| 806 |
+
If unset, will use the inferred format of the input video.
|
| 807 |
+
|
| 808 |
+
Returns:
|
| 809 |
+
`np.ndarray`: The padded video.
|
| 810 |
+
|
| 811 |
+
"""
|
| 812 |
+
if input_data_format is None:
|
| 813 |
+
input_data_format = infer_channel_dimension_format(video)
|
| 814 |
+
|
| 815 |
+
def _expand_for_data_format(values):
|
| 816 |
+
"""
|
| 817 |
+
Convert values to be in the format expected by np.pad based on the data format.
|
| 818 |
+
"""
|
| 819 |
+
if isinstance(values, (int, float)):
|
| 820 |
+
values = ((values, values), (values, values))
|
| 821 |
+
elif isinstance(values, tuple) and len(values) == 1:
|
| 822 |
+
values = ((values[0], values[0]), (values[0], values[0]))
|
| 823 |
+
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
|
| 824 |
+
values = (values, values)
|
| 825 |
+
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
|
| 826 |
+
pass
|
| 827 |
+
else:
|
| 828 |
+
raise ValueError(f"Unsupported format: {values}")
|
| 829 |
+
|
| 830 |
+
# add 0 for channel dimension
|
| 831 |
+
values = (
|
| 832 |
+
((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
# Add additional padding if there's a batch dimension
|
| 836 |
+
values = (0, *values) if video.ndim == 5 else values
|
| 837 |
+
return values
|
| 838 |
+
|
| 839 |
+
padding_map = {
|
| 840 |
+
PaddingMode.CONSTANT: "constant",
|
| 841 |
+
PaddingMode.REFLECT: "reflect",
|
| 842 |
+
PaddingMode.REPLICATE: "replicate",
|
| 843 |
+
PaddingMode.SYMMETRIC: "symmetric",
|
| 844 |
+
}
|
| 845 |
+
padding = _expand_for_data_format(padding)
|
| 846 |
+
|
| 847 |
+
pad_kwargs = {}
|
| 848 |
+
if mode not in padding_map:
|
| 849 |
+
raise ValueError(f"Invalid padding mode: {mode}")
|
| 850 |
+
elif mode == PaddingMode.CONSTANT:
|
| 851 |
+
pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)
|
| 852 |
+
|
| 853 |
+
video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
|
| 854 |
+
video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
|
| 855 |
+
return video
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
def group_videos_by_shape(
|
| 859 |
+
videos: list["torch.Tensor"],
|
| 860 |
+
) -> tuple[dict[tuple[int, int], "torch.Tensor"], dict[int, tuple[tuple[int, int], int]]]:
|
| 861 |
+
"""
|
| 862 |
+
Groups videos by shape.
|
| 863 |
+
Returns a dictionary with the shape as key and a list of videos with that shape as value,
|
| 864 |
+
and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value.
|
| 865 |
+
"""
|
| 866 |
+
grouped_videos = {}
|
| 867 |
+
|
| 868 |
+
grouped_videos_index = {}
|
| 869 |
+
for i, video in enumerate(videos):
|
| 870 |
+
shape = video.shape[-2::]
|
| 871 |
+
num_frames = video.shape[-4] # video format BTCHW
|
| 872 |
+
shape = (num_frames, *shape)
|
| 873 |
+
if shape not in grouped_videos:
|
| 874 |
+
grouped_videos[shape] = []
|
| 875 |
+
grouped_videos[shape].append(video)
|
| 876 |
+
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
|
| 877 |
+
|
| 878 |
+
# stack videos with the same size and number of frames
|
| 879 |
+
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
|
| 880 |
+
return grouped_videos, grouped_videos_index
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
def reorder_videos(
|
| 884 |
+
processed_videos: dict[tuple[int, int], "torch.Tensor"],
|
| 885 |
+
grouped_videos_index: dict[int, tuple[tuple[int, int], int]],
|
| 886 |
+
) -> list["torch.Tensor"]:
|
| 887 |
+
"""
|
| 888 |
+
Reconstructs a list of videos in the original order.
|
| 889 |
+
"""
|
| 890 |
+
return [
|
| 891 |
+
processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]]
|
| 892 |
+
for i in range(len(grouped_videos_index))
|
| 893 |
+
]
|